Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1648:3059f31b2dfa
more performance work
author | drewp@bigasterisk.com |
---|---|
date | Fri, 17 Sep 2021 11:10:18 -0700 |
parents | 5403c6343fa4 |
children | bb5d2b5370ac |
comparison
equal
deleted
inserted
replaced
1647:34eb87f68ab8 | 1648:3059f31b2dfa |
---|---|
59 prev: Optional['StmtLooper'] | 59 prev: Optional['StmtLooper'] |
60 workingSet: ReadOnlyWorkingSet | 60 workingSet: ReadOnlyWorkingSet |
61 parent: 'Lhs' # just for lhs.graph, really | 61 parent: 'Lhs' # just for lhs.graph, really |
62 | 62 |
63 def __repr__(self): | 63 def __repr__(self): |
64 return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' | 64 return f'StmtLooper{self._shortId}{"<pastEnd>" if self.pastEnd() else ""})' |
65 | 65 |
66 def __post_init__(self): | 66 def __post_init__(self): |
67 self._shortId = next(_stmtLooperShortId) | 67 self._shortId = next(_stmtLooperShortId) |
68 self._myWorkingSetMatches = self._myMatches(self.workingSet) | 68 self._myWorkingSetMatches = self._myMatches(self.workingSet) |
69 | 69 |
70 self._current = CandidateBinding({}) | 70 self._current = CandidateBinding({}) |
71 self._pastEnd = False | 71 self._pastEnd = False |
72 self._seenBindings: List[CandidateBinding] = [] | 72 self._seenBindings: List[CandidateBinding] = [] |
73 | |
74 log.debug(f'introducing {self!r}({graphDump([self.lhsStmt])})') | |
75 | |
73 self.restart() | 76 self.restart() |
74 | 77 |
75 def _myMatches(self, g: Graph) -> List[Triple]: | 78 def _myMatches(self, g: Graph) -> List[Triple]: |
76 template = stmtTemplate(self.lhsStmt) | 79 template = stmtTemplate(self.lhsStmt) |
77 | 80 |
117 | 120 |
118 for stmt in augmentedWorkingSet: | 121 for stmt in augmentedWorkingSet: |
119 try: | 122 try: |
120 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) | 123 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) |
121 except Inconsistent: | 124 except Inconsistent: |
122 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') | 125 log.debug(f'{INDENT*7} StmtLooper{self._shortId} - {stmt} would be inconsistent with prev bindings') |
123 continue | 126 continue |
124 | 127 |
125 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') | 128 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') |
126 if outBinding not in self._seenBindings: | 129 if outBinding not in self._seenBindings: |
127 self._seenBindings.append(outBinding.copy()) | 130 self._seenBindings.append(outBinding.copy()) |
167 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: | 170 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: |
168 outBinding = self._prevBindings().copy() | 171 outBinding = self._prevBindings().copy() |
169 for rt, ct in zip(self.lhsStmt, newStmt): | 172 for rt, ct in zip(self.lhsStmt, newStmt): |
170 if isinstance(rt, (Variable, BNode)): | 173 if isinstance(rt, (Variable, BNode)): |
171 if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct: | 174 if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct: |
172 raise Inconsistent(f'{rt=} {ct=} {outBinding=}') | 175 msg = f'{rt=} {ct=} {outBinding=}' if log.isEnabledFor(logging.DEBUG) else '' |
176 raise Inconsistent(msg) | |
173 outBinding.addNewBindings(CandidateBinding({rt: ct})) | 177 outBinding.addNewBindings(CandidateBinding({rt: ct})) |
174 return outBinding | 178 return outBinding |
175 | 179 |
176 def currentBinding(self) -> CandidateBinding: | 180 def currentBinding(self) -> CandidateBinding: |
177 if self.pastEnd(): | 181 if self.pastEnd(): |
189 raise NoOptions() | 193 raise NoOptions() |
190 | 194 |
191 | 195 |
192 @dataclass | 196 @dataclass |
193 class Lhs: | 197 class Lhs: |
194 graph: Graph | 198 graph: Graph # our full LHS graph, as input. See below for the statements partitioned into groups. |
195 | 199 |
196 def __post_init__(self): | 200 def __post_init__(self): |
197 pass | 201 |
202 usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) | |
203 | |
204 stmtsToMatch = list(self.graph - usedByFuncs) | |
205 self.staticStmts = [] | |
206 self.patternStmts = [] | |
207 for st in stmtsToMatch: | |
208 if all(isinstance(term, (URIRef, Literal)) for term in st): | |
209 self.staticStmts.append(st) | |
210 else: | |
211 self.patternStmts.append(st) | |
212 | |
213 # sort them by variable dependencies; don't just try all perms! | |
214 def lightSortKey(stmt): # Not this. | |
215 (s, p, o) = stmt | |
216 return p in rulePredicates(), p, s, o | |
217 | |
218 self.patternStmts.sort(key=lightSortKey) | |
219 | |
220 self.myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef)) | |
221 self.myPreds -= rulePredicates() | |
222 self.myPreds -= {RDF.first, RDF.rest} | |
223 self.myPreds = set(self.myPreds) | |
198 | 224 |
199 def __repr__(self): | 225 def __repr__(self): |
200 return f"Lhs({graphDump(self.graph)})" | 226 return f"Lhs({graphDump(self.graph)})" |
201 | 227 |
202 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: | 228 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']: |
203 """bindings that fit the LHS of a rule, using statements from workingSet and functions | 229 """bindings that fit the LHS of a rule, using statements from workingSet and functions |
204 from LHS""" | 230 from LHS""" |
205 if self.graph.__len__() == 0: | 231 if self.graph.__len__() == 0: |
206 # special case- no LHS! | 232 # special case- no LHS! |
207 yield BoundLhs(self, CandidateBinding({})) | 233 yield BoundLhs(self, CandidateBinding({})) |
209 | 235 |
210 if self._checkPredicateCounts(knownTrue): | 236 if self._checkPredicateCounts(knownTrue): |
211 stats['_checkPredicateCountsCulls'] += 1 | 237 stats['_checkPredicateCountsCulls'] += 1 |
212 return | 238 return |
213 | 239 |
240 if not all(st in knownTrue for st in self.staticStmts): | |
241 stats['staticStmtCulls'] += 1 | |
242 return | |
243 | |
244 if len(self.patternStmts) == 0: | |
245 # static only | |
246 yield BoundLhs(self, CandidateBinding({})) | |
247 return | |
248 | |
214 log.debug(f'{INDENT*4} build new StmtLooper stack') | 249 log.debug(f'{INDENT*4} build new StmtLooper stack') |
215 | 250 |
216 try: | 251 try: |
217 stmtStack = self._assembleRings(knownTrue) | 252 stmtStack = self._assembleRings(knownTrue, stats) |
218 except NoOptions: | 253 except NoOptions: |
219 log.debug(f'{INDENT*5} start up with no options; 0 bindings') | 254 log.debug(f'{INDENT*5} start up with no options; 0 bindings') |
220 return | 255 return |
221 self._debugStmtStack('initial odometer', stmtStack) | 256 self._debugStmtStack('initial odometer', stmtStack) |
222 self._assertAllRingsAreValid(stmtStack) | 257 self._assertAllRingsAreValid(stmtStack) |
223 | 258 |
224 lastRing = stmtStack[-1] | 259 lastRing = stmtStack[-1] |
225 iterCount = 0 | 260 iterCount = 0 |
226 while True: | 261 while True: |
227 iterCount += 1 | 262 iterCount += 1 |
228 if iterCount > 10: | 263 if iterCount > ruleStatementsIterationLimit: |
229 raise ValueError('stuck') | 264 raise ValueError('rule too complex') |
230 | 265 |
231 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') | 266 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') |
232 | 267 |
233 yield BoundLhs(self, lastRing.currentBinding()) | 268 yield BoundLhs(self, lastRing.currentBinding()) |
234 | 269 |
247 for l in stmtStack: | 282 for l in stmtStack: |
248 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') | 283 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') |
249 | 284 |
250 def _checkPredicateCounts(self, knownTrue): | 285 def _checkPredicateCounts(self, knownTrue): |
251 """raise NoOptions quickly in some cases""" | 286 """raise NoOptions quickly in some cases""" |
252 myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef)) | 287 |
253 myPreds -= rulePredicates() | 288 if any((None, p, None) not in knownTrue for p in self.myPreds): |
254 myPreds -= {RDF.first, RDF.rest} | |
255 if any((None, p, None) not in knownTrue for p in set(myPreds)): | |
256 return True | 289 return True |
290 log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue') | |
257 return False | 291 return False |
258 | 292 |
259 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: | 293 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet, stats) -> List[StmtLooper]: |
260 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all | 294 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all |
261 start out valid (or else raise NoOptions)""" | 295 start out valid (or else raise NoOptions)""" |
262 | 296 |
263 usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) | 297 log.info(f'{INDENT*2} stats={dict(stats)}') |
264 | 298 log.info(f'{INDENT*2} taking permutations of {len(self.patternStmts)=}') |
265 stmtsToAdd = list(self.graph - usedByFuncs) | 299 for i, perm in enumerate(itertools.permutations(self.patternStmts)): |
266 | |
267 # sort them by variable dependencies; don't just try all perms! | |
268 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. | |
269 (s, p, o) = stmt | |
270 return p == MATH['sum'], p, s, o | |
271 | |
272 stmtsToAdd.sort(key=lightSortKey) | |
273 | |
274 for i, perm in enumerate(itertools.permutations(stmtsToAdd)): | |
275 stmtStack: List[StmtLooper] = [] | 300 stmtStack: List[StmtLooper] = [] |
276 prev: Optional[StmtLooper] = None | 301 prev: Optional[StmtLooper] = None |
277 log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') | 302 if log.isEnabledFor(logging.DEBUG): |
303 log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') | |
278 | 304 |
279 for s in perm: | 305 for s in perm: |
280 try: | 306 try: |
281 elem = StmtLooper(s, prev, knownTrue, parent=self) | 307 elem = StmtLooper(s, prev, knownTrue, parent=self) |
282 except NoOptions: | 308 except NoOptions: |
284 break | 310 break |
285 stmtStack.append(elem) | 311 stmtStack.append(elem) |
286 prev = stmtStack[-1] | 312 prev = stmtStack[-1] |
287 else: | 313 else: |
288 return stmtStack | 314 return stmtStack |
315 if i > 5000: | |
316 raise NotImplementedError(f'trying too many permutations {len(self.patternStmts)=}') | |
317 | |
289 log.debug(f'{INDENT*6} no perms worked- rule cannot match anything') | 318 log.debug(f'{INDENT*6} no perms worked- rule cannot match anything') |
290 | |
291 raise NoOptions() | 319 raise NoOptions() |
292 | 320 |
293 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: | 321 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: |
294 carry = True # 1st elem always must advance | 322 carry = True # 1st elem always must advance |
295 for i, ring in enumerate(stmtStack): | 323 for i, ring in enumerate(stmtStack): |
327 def __post_init__(self): | 355 def __post_init__(self): |
328 self.lhs = Lhs(self.lhsGraph) | 356 self.lhs = Lhs(self.lhsGraph) |
329 # | 357 # |
330 self.rhsBnodeMap = {} | 358 self.rhsBnodeMap = {} |
331 | 359 |
332 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): | 360 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit): |
333 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): | 361 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats, ruleStatementsIterationLimit): |
334 log.debug(f'{INDENT*5} +rule has a working binding: {bound}') | 362 log.debug(f'{INDENT*5} +rule has a working binding: {bound}') |
335 | 363 |
336 # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do | 364 # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do |
337 existingRhsBnodes = set() | 365 existingRhsBnodes = set() |
338 for stmt in self.rhsGraph: | 366 for stmt in self.rhsGraph: |
360 # log.debug(f'{INDENT*6} adding {newStmt=}') | 388 # log.debug(f'{INDENT*6} adding {newStmt=}') |
361 workingSet.add(newStmt) | 389 workingSet.add(newStmt) |
362 implied.add(newStmt) | 390 implied.add(newStmt) |
363 | 391 |
364 | 392 |
393 @dataclass | |
365 class Inference: | 394 class Inference: |
395 rulesIterationLimit = 3 | |
396 ruleStatementsIterationLimit = 3 | |
366 | 397 |
367 def __init__(self) -> None: | 398 def __init__(self) -> None: |
399 self.rules: List[Rule] = [] | |
400 self._nonRuleStmts: List[Triple] = [] | |
401 | |
402 def setRules(self, g: ConjunctiveGraph): | |
368 self.rules = [] | 403 self.rules = [] |
369 | 404 self._nonRuleStmts = [] |
370 def setRules(self, g: ConjunctiveGraph): | |
371 self.rules: List[Rule] = [] | |
372 for stmt in g: | 405 for stmt in g: |
373 if stmt[1] == LOG['implies']: | 406 if stmt[1] == LOG['implies']: |
374 self.rules.append(Rule(stmt[0], stmt[2])) | 407 self.rules.append(Rule(stmt[0], stmt[2])) |
375 # other stmts should go to a default working set? | 408 # other stmts should go to a default working set? |
376 | 409 |
389 workingSet += graph | 422 workingSet += graph |
390 | 423 |
391 # just the statements that came from RHS's of rules that fired. | 424 # just the statements that came from RHS's of rules that fired. |
392 implied = ConjunctiveGraph() | 425 implied = ConjunctiveGraph() |
393 | 426 |
394 bailout_iterations = 100 | 427 rulesIterations = 0 |
395 delta = 1 | 428 delta = 1 |
396 stats['initWorkingSet'] = cast(int, workingSet.__len__()) | 429 stats['initWorkingSet'] = cast(int, workingSet.__len__()) |
397 while delta > 0 and bailout_iterations > 0: | 430 while delta > 0 and rulesIterations <= self.rulesIterationLimit: |
398 log.debug('') | 431 log.debug('') |
399 log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)') | 432 log.info(f'{INDENT*1}*iteration {rulesIterations}') |
400 bailout_iterations -= 1 | 433 |
401 delta = -len(implied) | 434 delta = -len(implied) |
402 self._iterateAllRules(workingSet, implied, stats) | 435 self._iterateAllRules(workingSet, implied, stats) |
403 delta += len(implied) | 436 delta += len(implied) |
404 stats['iterations'] += 1 | 437 rulesIterations += 1 |
405 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts') | 438 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts') |
439 stats['iterations'] = rulesIterations | |
406 stats['timeSpent'] = round(time.time() - startTime, 3) | 440 stats['timeSpent'] = round(time.time() - startTime, 3) |
407 stats['impliedStmts'] = len(implied) | 441 stats['impliedStmts'] = len(implied) |
408 log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:') | 442 log.info(f'{INDENT*0} Inference done {dict(stats)}.') |
409 log.info(graphDump(implied)) | 443 log.debug('Implied:') |
444 log.debug(graphDump(implied)) | |
410 return implied | 445 return implied |
411 | 446 |
412 def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats): | 447 def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats): |
413 for i, rule in enumerate(self.rules): | 448 for i, rule in enumerate(self.rules): |
414 self._logRuleApplicationHeader(workingSet, i, rule) | 449 self._logRuleApplicationHeader(workingSet, i, rule) |
415 rule.applyRule(workingSet, implied, stats) | 450 rule.applyRule(workingSet, implied, stats, self.ruleStatementsIterationLimit) |
416 | 451 |
417 def _logRuleApplicationHeader(self, workingSet, i, r: Rule): | 452 def _logRuleApplicationHeader(self, workingSet, i, r: Rule): |
418 if not log.isEnabledFor(logging.DEBUG): | 453 if not log.isEnabledFor(logging.DEBUG): |
419 return | 454 return |
420 | 455 |
421 log.debug('') | 456 log.debug('') |
422 log.debug(f'{INDENT*2} workingSet:') | 457 log.debug(f'{INDENT*2} workingSet:') |
423 for j, stmt in enumerate(sorted(workingSet)): | 458 # for j, stmt in enumerate(sorted(workingSet)): |
424 log.debug(f'{INDENT*3} ({j}) {stmt}') | 459 # log.debug(f'{INDENT*3} ({j}) {stmt}') |
460 log.debug(f'{INDENT*3} {graphDump(workingSet, oneLine=False)}') | |
425 | 461 |
426 log.debug('') | 462 log.debug('') |
427 log.debug(f'{INDENT*2}-applying rule {i}') | 463 log.debug(f'{INDENT*2}-applying rule {i}') |
428 log.debug(f'{INDENT*3} rule def lhs:') | 464 log.debug(f'{INDENT*3} rule def lhs:') |
429 for stmt in sorted(r.lhsGraph, reverse=True): | 465 for stmt in sorted(r.lhsGraph, reverse=True): |
430 log.debug(f'{INDENT*4} {stmt}') | 466 log.debug(f'{INDENT*4} {stmt}') |
431 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') | 467 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') |
432 | 468 |
433 | 469 |
434 def graphDump(g: Union[Graph, List[Triple]]): | 470 def graphDump(g: Union[Graph, List[Triple]], oneLine=True): |
435 # this is very slow- debug only! | 471 # this is very slow- debug only! |
436 if not log.isEnabledFor(logging.DEBUG): | 472 if not log.isEnabledFor(logging.DEBUG): |
437 return "(skipped dump)" | 473 return "(skipped dump)" |
438 if not isinstance(g, Graph): | 474 if not isinstance(g, Graph): |
439 g2 = Graph() | 475 g2 = Graph() |
440 g2 += g | 476 g2 += g |
441 g = g2 | 477 g = g2 |
442 g.bind('', ROOM) | 478 g.bind('', ROOM) |
443 g.bind('ex', Namespace('http://example.com/')) | 479 g.bind('ex', Namespace('http://example.com/')) |
444 lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() | 480 lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() |
445 lines = [line.strip() for line in lines if not line.startswith('@prefix')] | 481 lines = [line for line in lines if not line.startswith('@prefix')] |
482 if oneLine: | |
483 lines = [line.strip() for line in lines] | |
446 return ' '.join(lines) | 484 return ' '.join(lines) |