Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1636:3252bdc284bc
rm dead code from previous tries
author | drewp@bigasterisk.com |
---|---|
date | Mon, 13 Sep 2021 00:18:47 -0700 |
parents | 22d481f0a924 |
children | ec3f98d0c1d8 |
comparison
equal
deleted
inserted
replaced
1635:22d481f0a924 | 1636:3252bdc284bc |
---|---|
13 from rdflib import RDF, BNode, Graph, Namespace | 13 from rdflib import RDF, BNode, Graph, Namespace |
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate | 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate |
15 from rdflib.term import Literal, Node, Variable | 15 from rdflib.term import Literal, Node, Variable |
16 | 16 |
17 from candidate_binding import CandidateBinding | 17 from candidate_binding import CandidateBinding |
18 from inference_types import (BindableTerm, BindingUnknown, EvaluationFailed, ReadOnlyWorkingSet, Triple) | 18 from inference_types import (BindableTerm, BindingUnknown, ReadOnlyWorkingSet, Triple) |
19 from lhs_evaluation import Decimal, Evaluation, numericNode, parseList | 19 from lhs_evaluation import Decimal, numericNode, parseList |
20 | 20 |
21 log = logging.getLogger('infer') | 21 log = logging.getLogger('infer') |
22 INDENT = ' ' | 22 INDENT = ' ' |
23 | 23 |
24 INFER_CALLS = Summary('inference_infer_calls', 'calls') | 24 INFER_CALLS = Summary('inference_infer_calls', 'calls') |
231 def currentBinding(self) -> CandidateBinding: | 231 def currentBinding(self) -> CandidateBinding: |
232 if self.pastEnd(): | 232 if self.pastEnd(): |
233 raise NotImplementedError() | 233 raise NotImplementedError() |
234 return self._current | 234 return self._current |
235 | 235 |
236 def newLhsStmts(self) -> List[Triple]: | |
237 """under the curent bindings, what new stmts beyond workingSet are also true? includes all `prev`""" | |
238 return [] | |
239 | |
240 def pastEnd(self) -> bool: | 236 def pastEnd(self) -> bool: |
241 return self._pastEnd | 237 return self._pastEnd |
242 | 238 |
243 def restart(self): | 239 def restart(self): |
244 self._pastEnd = False | 240 self._pastEnd = False |
251 @dataclass | 247 @dataclass |
252 class Lhs: | 248 class Lhs: |
253 graph: Graph | 249 graph: Graph |
254 | 250 |
255 def __post_init__(self): | 251 def __post_init__(self): |
256 # do precomputation in here that's not specific to the workingSet | 252 pass |
257 # self.staticRuleStmts = Graph() | |
258 # self.nonStaticRuleStmts = Graph() | |
259 | |
260 # self.lhsBindables: Set[BindableTerm] = set() | |
261 # self.lhsBnodes: Set[BNode] = set() | |
262 # for ruleStmt in self.graph: | |
263 # varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))] | |
264 # self.lhsBindables.update(varsAndBnodesInStmt) | |
265 # self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode)) | |
266 # if not varsAndBnodesInStmt: | |
267 # self.staticRuleStmts.add(ruleStmt) | |
268 # else: | |
269 # self.nonStaticRuleStmts.add(ruleStmt) | |
270 | |
271 # self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts) | |
272 | |
273 self.evaluations = list(Evaluation.findEvals(self.graph)) | |
274 | 253 |
275 def __repr__(self): | 254 def __repr__(self): |
276 return f"Lhs({graphDump(self.graph)})" | 255 return f"Lhs({graphDump(self.graph)})" |
277 | 256 |
278 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: | 257 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: |
378 def _assertAllRingsAreValid(self, stmtStack): | 357 def _assertAllRingsAreValid(self, stmtStack): |
379 if any(ring.pastEnd() for ring in stmtStack): # this is an unexpected debug assertion | 358 if any(ring.pastEnd() for ring in stmtStack): # this is an unexpected debug assertion |
380 log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}') | 359 log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}') |
381 raise NoOptions() | 360 raise NoOptions() |
382 | 361 |
383 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: | |
384 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's | |
385 # static stmt is matched by a non-static stmt in the rule itself | |
386 for ruleStmt in self.staticRuleStmts: | |
387 if ruleStmt not in knownTrue: | |
388 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule') | |
389 return False | |
390 return True | |
391 | |
392 def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']: | |
393 """this yields at least the working bindings, and possibly others""" | |
394 candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet) | |
395 for bindRow in self._product(candidateTermMatches): | |
396 try: | |
397 yield BoundLhs(self, bindRow) | |
398 except EvaluationFailed: | |
399 stats['permCountFailingEval'] += 1 | |
400 | |
401 def _allCandidateTermMatches(self, workingSet: ReadOnlyWorkingSet) -> Dict[BindableTerm, Set[Node]]: | |
402 """the total set of terms each variable could possibly match""" | |
403 | |
404 candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set) | |
405 for lhsStmt in self.graph: | |
406 log.debug(f'{INDENT*4} possibles for this lhs stmt: {lhsStmt}') | |
407 for i, trueStmt in enumerate(workingSet): | |
408 # log.debug(f'{INDENT*5} consider this true stmt ({i}): {trueStmt}') | |
409 | |
410 for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt): | |
411 candidateTermMatches[v].update(vals) | |
412 | |
413 return candidateTermMatches | |
414 | |
415 def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]: | |
416 """if these stmts match otherwise, what BNode or Variable mappings do we learn? | |
417 | |
418 e.g. stmt1=(?x B ?y) and stmt2=(A B C), then we yield (?x, {A}) and (?y, {C}) | |
419 or stmt1=(_:x B C) and stmt2=(A B C), then we yield (_:x, {A}) | |
420 or stmt1=(?x B C) and stmt2=(A B D), then we yield nothing | |
421 """ | |
422 bindingsFromStatement = {} | |
423 for term1, term2 in zip(stmt1, stmt2): | |
424 if isinstance(term1, (BNode, Variable)): | |
425 bindingsFromStatement.setdefault(term1, set()).add(term2) | |
426 elif term1 != term2: | |
427 break | |
428 else: | |
429 for v, vals in bindingsFromStatement.items(): | |
430 log.debug(f'{INDENT*5} {v=} {vals=}') | |
431 yield v, vals | |
432 | |
433 def _product(self, candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Iterator[CandidateBinding]: | |
434 orderedVars, orderedValueSets = _organize(candidateTermMatches) | |
435 self._logCandidates(orderedVars, orderedValueSets) | |
436 log.debug(f'{INDENT*3} trying all permutations:') | |
437 if not orderedValueSets: | |
438 yield CandidateBinding({}) | |
439 return | |
440 | |
441 if not orderedValueSets or not all(orderedValueSets): | |
442 # some var or bnode has no options at all | |
443 return | |
444 rings: List[Iterator[Node]] = [itertools.cycle(valSet) for valSet in orderedValueSets] | |
445 currentSet: List[Node] = [next(ring) for ring in rings] | |
446 starts = [valSet[-1] for valSet in orderedValueSets] | |
447 while True: | |
448 for col, curr in enumerate(currentSet): | |
449 currentSet[col] = next(rings[col]) | |
450 log.debug(f'{INDENT*4} currentSet: {repr(currentSet)}') | |
451 yield CandidateBinding(dict(zip(orderedVars, currentSet))) | |
452 if curr is not starts[col]: | |
453 break | |
454 if col == len(orderedValueSets) - 1: | |
455 return | |
456 | |
457 def _logCandidates(self, orderedVars, orderedValueSets): | |
458 if not log.isEnabledFor(logging.DEBUG): | |
459 return | |
460 log.debug(f'{INDENT*3} resulting candidate terms:') | |
461 for v, vals in zip(orderedVars, orderedValueSets): | |
462 log.debug(f'{INDENT*4} {v!r} could be:') | |
463 for val in vals: | |
464 log.debug(f'{INDENT*5}{val!r}') | |
465 | |
466 | 362 |
467 @dataclass | 363 @dataclass |
468 class BoundLhs: | 364 class BoundLhs: |
469 lhs: Lhs | 365 lhs: Lhs |
470 binding: CandidateBinding | 366 binding: CandidateBinding |
471 | |
472 def __post_init__(self): | |
473 self.usedByFuncs = Graph() | |
474 # self._applyFunctions() | |
475 | |
476 def lhsStmtsWithoutEvals(self): | |
477 for stmt in self.lhs.graph: | |
478 if stmt in self.usedByFuncs: | |
479 continue | |
480 yield stmt | |
481 | |
482 def _applyFunctions(self): | |
483 """may grow the binding with some results""" | |
484 while True: | |
485 delta = self._applyFunctionsIteration() | |
486 if delta == 0: | |
487 break | |
488 | |
489 def _applyFunctionsIteration(self): | |
490 before = len(self.binding.binding) | |
491 delta = 0 | |
492 for ev in self.lhs.evaluations: | |
493 newBindings, usedGraph = ev.resultBindings(self.binding) | |
494 self.usedByFuncs += usedGraph | |
495 self.binding.addNewBindings(newBindings) | |
496 delta = len(self.binding.binding) - before | |
497 log.debug(f'{INDENT*4} eval rules made {delta} new bindings') | |
498 return delta | |
499 | |
500 def verify(self, workingSet: ReadOnlyWorkingSet) -> bool: | |
501 """Can this bound lhs be true all at once in workingSet?""" | |
502 rem = cast(Set[Triple], self.lhs.nonStaticRuleStmtsSet.difference(self.usedByFuncs)) | |
503 boundLhs = self.binding.apply(rem) | |
504 | |
505 if log.isEnabledFor(logging.DEBUG): | |
506 boundLhs = list(boundLhs) | |
507 self._logVerifyBanner(boundLhs, workingSet) | |
508 | |
509 for stmt in boundLhs: | |
510 log.debug(f'{INDENT*4} check for %s', stmt) | |
511 | |
512 if stmt not in workingSet: | |
513 log.debug(f'{INDENT*5} stmt not known to be true') | |
514 return False | |
515 return True | |
516 | |
517 def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet): | |
518 log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:') | |
519 for stmt in sorted(boundLhs): | |
520 log.debug(f'{INDENT*4}|{INDENT} {stmt}') | |
521 | |
522 # log.debug(f'{INDENT*4}| and against this workingSet:') | |
523 # for stmt in sorted(workingSet): | |
524 # log.debug(f'{INDENT*4}|{INDENT} {stmt}') | |
525 | |
526 log.debug(f'{INDENT*4}\\') | |
527 | 367 |
528 | 368 |
529 @dataclass | 369 @dataclass |
530 class Rule: | 370 class Rule: |
531 lhsGraph: Graph | 371 lhsGraph: Graph |
645 g.bind('', ROOM) | 485 g.bind('', ROOM) |
646 g.bind('ex', Namespace('http://example.com/')) | 486 g.bind('ex', Namespace('http://example.com/')) |
647 lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() | 487 lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() |
648 lines = [line.strip() for line in lines if not line.startswith('@prefix')] | 488 lines = [line.strip() for line in lines if not line.startswith('@prefix')] |
649 return ' '.join(lines) | 489 return ' '.join(lines) |
650 | |
651 | |
652 def _organize(candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Tuple[List[BindableTerm], List[List[Node]]]: | |
653 items = list(candidateTermMatches.items()) | |
654 items.sort() | |
655 orderedVars: List[BindableTerm] = [] | |
656 orderedValueSets: List[List[Node]] = [] | |
657 for v, vals in items: | |
658 orderedVars.append(v) | |
659 orderedValues: List[Node] = list(vals) | |
660 orderedValues.sort(key=str) | |
661 orderedValueSets.append(orderedValues) | |
662 | |
663 return orderedVars, orderedValueSets |