comparison service/mqtt_to_rdf/inference.py @ 1648:3059f31b2dfa

more performance work
author drewp@bigasterisk.com
date Fri, 17 Sep 2021 11:10:18 -0700
parents 5403c6343fa4
children bb5d2b5370ac
comparison
equal deleted inserted replaced
1647:34eb87f68ab8 1648:3059f31b2dfa
59 prev: Optional['StmtLooper'] 59 prev: Optional['StmtLooper']
60 workingSet: ReadOnlyWorkingSet 60 workingSet: ReadOnlyWorkingSet
61 parent: 'Lhs' # just for lhs.graph, really 61 parent: 'Lhs' # just for lhs.graph, really
62 62
63 def __repr__(self): 63 def __repr__(self):
64 return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' 64 return f'StmtLooper{self._shortId}{"<pastEnd>" if self.pastEnd() else ""})'
65 65
66 def __post_init__(self): 66 def __post_init__(self):
67 self._shortId = next(_stmtLooperShortId) 67 self._shortId = next(_stmtLooperShortId)
68 self._myWorkingSetMatches = self._myMatches(self.workingSet) 68 self._myWorkingSetMatches = self._myMatches(self.workingSet)
69 69
70 self._current = CandidateBinding({}) 70 self._current = CandidateBinding({})
71 self._pastEnd = False 71 self._pastEnd = False
72 self._seenBindings: List[CandidateBinding] = [] 72 self._seenBindings: List[CandidateBinding] = []
73
74 log.debug(f'introducing {self!r}({graphDump([self.lhsStmt])})')
75
73 self.restart() 76 self.restart()
74 77
75 def _myMatches(self, g: Graph) -> List[Triple]: 78 def _myMatches(self, g: Graph) -> List[Triple]:
76 template = stmtTemplate(self.lhsStmt) 79 template = stmtTemplate(self.lhsStmt)
77 80
117 120
118 for stmt in augmentedWorkingSet: 121 for stmt in augmentedWorkingSet:
119 try: 122 try:
120 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) 123 outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
121 except Inconsistent: 124 except Inconsistent:
122 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') 125 log.debug(f'{INDENT*7} StmtLooper{self._shortId} - {stmt} would be inconsistent with prev bindings')
123 continue 126 continue
124 127
125 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') 128 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
126 if outBinding not in self._seenBindings: 129 if outBinding not in self._seenBindings:
127 self._seenBindings.append(outBinding.copy()) 130 self._seenBindings.append(outBinding.copy())
167 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: 170 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
168 outBinding = self._prevBindings().copy() 171 outBinding = self._prevBindings().copy()
169 for rt, ct in zip(self.lhsStmt, newStmt): 172 for rt, ct in zip(self.lhsStmt, newStmt):
170 if isinstance(rt, (Variable, BNode)): 173 if isinstance(rt, (Variable, BNode)):
171 if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct: 174 if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct:
172 raise Inconsistent(f'{rt=} {ct=} {outBinding=}') 175 msg = f'{rt=} {ct=} {outBinding=}' if log.isEnabledFor(logging.DEBUG) else ''
176 raise Inconsistent(msg)
173 outBinding.addNewBindings(CandidateBinding({rt: ct})) 177 outBinding.addNewBindings(CandidateBinding({rt: ct}))
174 return outBinding 178 return outBinding
175 179
176 def currentBinding(self) -> CandidateBinding: 180 def currentBinding(self) -> CandidateBinding:
177 if self.pastEnd(): 181 if self.pastEnd():
189 raise NoOptions() 193 raise NoOptions()
190 194
191 195
192 @dataclass 196 @dataclass
193 class Lhs: 197 class Lhs:
194 graph: Graph 198 graph: Graph # our full LHS graph, as input. See below for the statements partitioned into groups.
195 199
196 def __post_init__(self): 200 def __post_init__(self):
197 pass 201
202 usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
203
204 stmtsToMatch = list(self.graph - usedByFuncs)
205 self.staticStmts = []
206 self.patternStmts = []
207 for st in stmtsToMatch:
208 if all(isinstance(term, (URIRef, Literal)) for term in st):
209 self.staticStmts.append(st)
210 else:
211 self.patternStmts.append(st)
212
213 # sort them by variable dependencies; don't just try all perms!
214 def lightSortKey(stmt): # Not this.
215 (s, p, o) = stmt
216 return p in rulePredicates(), p, s, o
217
218 self.patternStmts.sort(key=lightSortKey)
219
220 self.myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
221 self.myPreds -= rulePredicates()
222 self.myPreds -= {RDF.first, RDF.rest}
223 self.myPreds = set(self.myPreds)
198 224
199 def __repr__(self): 225 def __repr__(self):
200 return f"Lhs({graphDump(self.graph)})" 226 return f"Lhs({graphDump(self.graph)})"
201 227
202 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: 228 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
203 """bindings that fit the LHS of a rule, using statements from workingSet and functions 229 """bindings that fit the LHS of a rule, using statements from workingSet and functions
204 from LHS""" 230 from LHS"""
205 if self.graph.__len__() == 0: 231 if self.graph.__len__() == 0:
206 # special case- no LHS! 232 # special case- no LHS!
207 yield BoundLhs(self, CandidateBinding({})) 233 yield BoundLhs(self, CandidateBinding({}))
209 235
210 if self._checkPredicateCounts(knownTrue): 236 if self._checkPredicateCounts(knownTrue):
211 stats['_checkPredicateCountsCulls'] += 1 237 stats['_checkPredicateCountsCulls'] += 1
212 return 238 return
213 239
240 if not all(st in knownTrue for st in self.staticStmts):
241 stats['staticStmtCulls'] += 1
242 return
243
244 if len(self.patternStmts) == 0:
245 # static only
246 yield BoundLhs(self, CandidateBinding({}))
247 return
248
214 log.debug(f'{INDENT*4} build new StmtLooper stack') 249 log.debug(f'{INDENT*4} build new StmtLooper stack')
215 250
216 try: 251 try:
217 stmtStack = self._assembleRings(knownTrue) 252 stmtStack = self._assembleRings(knownTrue, stats)
218 except NoOptions: 253 except NoOptions:
219 log.debug(f'{INDENT*5} start up with no options; 0 bindings') 254 log.debug(f'{INDENT*5} start up with no options; 0 bindings')
220 return 255 return
221 self._debugStmtStack('initial odometer', stmtStack) 256 self._debugStmtStack('initial odometer', stmtStack)
222 self._assertAllRingsAreValid(stmtStack) 257 self._assertAllRingsAreValid(stmtStack)
223 258
224 lastRing = stmtStack[-1] 259 lastRing = stmtStack[-1]
225 iterCount = 0 260 iterCount = 0
226 while True: 261 while True:
227 iterCount += 1 262 iterCount += 1
228 if iterCount > 10: 263 if iterCount > ruleStatementsIterationLimit:
229 raise ValueError('stuck') 264 raise ValueError('rule too complex')
230 265
231 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') 266 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')
232 267
233 yield BoundLhs(self, lastRing.currentBinding()) 268 yield BoundLhs(self, lastRing.currentBinding())
234 269
247 for l in stmtStack: 282 for l in stmtStack:
248 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') 283 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')
249 284
250 def _checkPredicateCounts(self, knownTrue): 285 def _checkPredicateCounts(self, knownTrue):
251 """raise NoOptions quickly in some cases""" 286 """raise NoOptions quickly in some cases"""
252 myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef)) 287
253 myPreds -= rulePredicates() 288 if any((None, p, None) not in knownTrue for p in self.myPreds):
254 myPreds -= {RDF.first, RDF.rest}
255 if any((None, p, None) not in knownTrue for p in set(myPreds)):
256 return True 289 return True
290 log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue')
257 return False 291 return False
258 292
259 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: 293 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet, stats) -> List[StmtLooper]:
260 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all 294 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
261 start out valid (or else raise NoOptions)""" 295 start out valid (or else raise NoOptions)"""
262 296
263 usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) 297 log.info(f'{INDENT*2} stats={dict(stats)}')
264 298 log.info(f'{INDENT*2} taking permutations of {len(self.patternStmts)=}')
265 stmtsToAdd = list(self.graph - usedByFuncs) 299 for i, perm in enumerate(itertools.permutations(self.patternStmts)):
266
267 # sort them by variable dependencies; don't just try all perms!
268 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases.
269 (s, p, o) = stmt
270 return p == MATH['sum'], p, s, o
271
272 stmtsToAdd.sort(key=lightSortKey)
273
274 for i, perm in enumerate(itertools.permutations(stmtsToAdd)):
275 stmtStack: List[StmtLooper] = [] 300 stmtStack: List[StmtLooper] = []
276 prev: Optional[StmtLooper] = None 301 prev: Optional[StmtLooper] = None
277 log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') 302 if log.isEnabledFor(logging.DEBUG):
303 log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
278 304
279 for s in perm: 305 for s in perm:
280 try: 306 try:
281 elem = StmtLooper(s, prev, knownTrue, parent=self) 307 elem = StmtLooper(s, prev, knownTrue, parent=self)
282 except NoOptions: 308 except NoOptions:
284 break 310 break
285 stmtStack.append(elem) 311 stmtStack.append(elem)
286 prev = stmtStack[-1] 312 prev = stmtStack[-1]
287 else: 313 else:
288 return stmtStack 314 return stmtStack
315 if i > 5000:
316 raise NotImplementedError(f'trying too many permutations {len(self.patternStmts)=}')
317
289 log.debug(f'{INDENT*6} no perms worked- rule cannot match anything') 318 log.debug(f'{INDENT*6} no perms worked- rule cannot match anything')
290
291 raise NoOptions() 319 raise NoOptions()
292 320
293 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: 321 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
294 carry = True # 1st elem always must advance 322 carry = True # 1st elem always must advance
295 for i, ring in enumerate(stmtStack): 323 for i, ring in enumerate(stmtStack):
327 def __post_init__(self): 355 def __post_init__(self):
328 self.lhs = Lhs(self.lhsGraph) 356 self.lhs = Lhs(self.lhsGraph)
329 # 357 #
330 self.rhsBnodeMap = {} 358 self.rhsBnodeMap = {}
331 359
332 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): 360 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit):
333 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): 361 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats, ruleStatementsIterationLimit):
334 log.debug(f'{INDENT*5} +rule has a working binding: {bound}') 362 log.debug(f'{INDENT*5} +rule has a working binding: {bound}')
335 363
336 # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do 364 # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do
337 existingRhsBnodes = set() 365 existingRhsBnodes = set()
338 for stmt in self.rhsGraph: 366 for stmt in self.rhsGraph:
360 # log.debug(f'{INDENT*6} adding {newStmt=}') 388 # log.debug(f'{INDENT*6} adding {newStmt=}')
361 workingSet.add(newStmt) 389 workingSet.add(newStmt)
362 implied.add(newStmt) 390 implied.add(newStmt)
363 391
364 392
393 @dataclass
365 class Inference: 394 class Inference:
395 rulesIterationLimit = 3
396 ruleStatementsIterationLimit = 3
366 397
367 def __init__(self) -> None: 398 def __init__(self) -> None:
399 self.rules: List[Rule] = []
400 self._nonRuleStmts: List[Triple] = []
401
402 def setRules(self, g: ConjunctiveGraph):
368 self.rules = [] 403 self.rules = []
369 404 self._nonRuleStmts = []
370 def setRules(self, g: ConjunctiveGraph):
371 self.rules: List[Rule] = []
372 for stmt in g: 405 for stmt in g:
373 if stmt[1] == LOG['implies']: 406 if stmt[1] == LOG['implies']:
374 self.rules.append(Rule(stmt[0], stmt[2])) 407 self.rules.append(Rule(stmt[0], stmt[2]))
375 # other stmts should go to a default working set? 408 # other stmts should go to a default working set?
376 409
389 workingSet += graph 422 workingSet += graph
390 423
391 # just the statements that came from RHS's of rules that fired. 424 # just the statements that came from RHS's of rules that fired.
392 implied = ConjunctiveGraph() 425 implied = ConjunctiveGraph()
393 426
394 bailout_iterations = 100 427 rulesIterations = 0
395 delta = 1 428 delta = 1
396 stats['initWorkingSet'] = cast(int, workingSet.__len__()) 429 stats['initWorkingSet'] = cast(int, workingSet.__len__())
397 while delta > 0 and bailout_iterations > 0: 430 while delta > 0 and rulesIterations <= self.rulesIterationLimit:
398 log.debug('') 431 log.debug('')
399 log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)') 432 log.info(f'{INDENT*1}*iteration {rulesIterations}')
400 bailout_iterations -= 1 433
401 delta = -len(implied) 434 delta = -len(implied)
402 self._iterateAllRules(workingSet, implied, stats) 435 self._iterateAllRules(workingSet, implied, stats)
403 delta += len(implied) 436 delta += len(implied)
404 stats['iterations'] += 1 437 rulesIterations += 1
405 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts') 438 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
439 stats['iterations'] = rulesIterations
406 stats['timeSpent'] = round(time.time() - startTime, 3) 440 stats['timeSpent'] = round(time.time() - startTime, 3)
407 stats['impliedStmts'] = len(implied) 441 stats['impliedStmts'] = len(implied)
408 log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:') 442 log.info(f'{INDENT*0} Inference done {dict(stats)}.')
409 log.info(graphDump(implied)) 443 log.debug('Implied:')
444 log.debug(graphDump(implied))
410 return implied 445 return implied
411 446
412 def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats): 447 def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats):
413 for i, rule in enumerate(self.rules): 448 for i, rule in enumerate(self.rules):
414 self._logRuleApplicationHeader(workingSet, i, rule) 449 self._logRuleApplicationHeader(workingSet, i, rule)
415 rule.applyRule(workingSet, implied, stats) 450 rule.applyRule(workingSet, implied, stats, self.ruleStatementsIterationLimit)
416 451
417 def _logRuleApplicationHeader(self, workingSet, i, r: Rule): 452 def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
418 if not log.isEnabledFor(logging.DEBUG): 453 if not log.isEnabledFor(logging.DEBUG):
419 return 454 return
420 455
421 log.debug('') 456 log.debug('')
422 log.debug(f'{INDENT*2} workingSet:') 457 log.debug(f'{INDENT*2} workingSet:')
423 for j, stmt in enumerate(sorted(workingSet)): 458 # for j, stmt in enumerate(sorted(workingSet)):
424 log.debug(f'{INDENT*3} ({j}) {stmt}') 459 # log.debug(f'{INDENT*3} ({j}) {stmt}')
460 log.debug(f'{INDENT*3} {graphDump(workingSet, oneLine=False)}')
425 461
426 log.debug('') 462 log.debug('')
427 log.debug(f'{INDENT*2}-applying rule {i}') 463 log.debug(f'{INDENT*2}-applying rule {i}')
428 log.debug(f'{INDENT*3} rule def lhs:') 464 log.debug(f'{INDENT*3} rule def lhs:')
429 for stmt in sorted(r.lhsGraph, reverse=True): 465 for stmt in sorted(r.lhsGraph, reverse=True):
430 log.debug(f'{INDENT*4} {stmt}') 466 log.debug(f'{INDENT*4} {stmt}')
431 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') 467 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
432 468
433 469
434 def graphDump(g: Union[Graph, List[Triple]]): 470 def graphDump(g: Union[Graph, List[Triple]], oneLine=True):
435 # this is very slow- debug only! 471 # this is very slow- debug only!
436 if not log.isEnabledFor(logging.DEBUG): 472 if not log.isEnabledFor(logging.DEBUG):
437 return "(skipped dump)" 473 return "(skipped dump)"
438 if not isinstance(g, Graph): 474 if not isinstance(g, Graph):
439 g2 = Graph() 475 g2 = Graph()
440 g2 += g 476 g2 += g
441 g = g2 477 g = g2
442 g.bind('', ROOM) 478 g.bind('', ROOM)
443 g.bind('ex', Namespace('http://example.com/')) 479 g.bind('ex', Namespace('http://example.com/'))
444 lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() 480 lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines()
445 lines = [line.strip() for line in lines if not line.startswith('@prefix')] 481 lines = [line for line in lines if not line.startswith('@prefix')]
482 if oneLine:
483 lines = [line.strip() for line in lines]
446 return ' '.join(lines) 484 return ' '.join(lines)