diff service/mqtt_to_rdf/inference.py @ 1618:48bf62008c82
attempted to rewrite with CandidateTermMatches but it broke
author      drewp@bigasterisk.com
date        Wed, 08 Sep 2021 18:32:11 -0700
parents     3a6ed545357f
children
--- a/service/mqtt_to_rdf/inference.py	Mon Sep 06 23:26:07 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Wed Sep 08 18:32:11 2021 -0700
@@ -7,7 +7,7 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Dict, Iterator, List, Set, Tuple, Union, cast
+from typing import Dict, Iterator, List, Literal, Optional, Set, Tuple, Union, cast
 
 from prometheus_client import Summary
 from rdflib import BNode, Graph, Namespace, URIRef
@@ -55,53 +55,64 @@
     def __repr__(self):
         return f"Lhs({graphDump(self.graph)})"
 
-    def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
+    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions from LHS"""
         log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}')
         stats['findCandidateBindingsCalls'] += 1
 
-        if not self._allStaticStatementsMatch(workingSet):
+        if not self._allStaticStatementsMatch(knownTrue):
            stats['findCandidateBindingEarlyExits'] += 1
            return
 
-        for binding in self._possibleBindings(workingSet, stats):
+        boundSoFar = CandidateBinding({})
+        for binding in self._possibleBindings(knownTrue, boundSoFar, stats):
             log.debug('')
             log.debug(f'{INDENT*4}*trying {binding.binding}')
 
-            if not binding.verify(workingSet):
+            if not binding.verify(knownTrue):
                 log.debug(f'{INDENT*4} this binding did not verify')
                 stats['permCountFailingVerify'] += 1
                 continue
 
             stats['permCountSucceeding'] += 1
             yield binding
+            boundSoFar.addNewBindings(binding.binding)
 
-    def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool:
+    def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool:
         # bug: see TestSelfFulfillingRule.test3 for a case where this rule's
         # static stmt is matched by a non-static stmt in the rule itself
         for ruleStmt in self.staticRuleStmts:
-            if ruleStmt not in workingSet:
+            if ruleStmt not in knownTrue:
                 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule')
                 return False
         return True
 
-    def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']:
+    def _possibleBindings(self, workingSet, boundSoFar, stats) -> Iterator['BoundLhs']:
         """this yields at least the working bindings, and possibly others"""
-        candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet)
-        for bindRow in self._product(candidateTermMatches):
+        for bindRow in self._product(workingSet, boundSoFar):
             try:
                 yield BoundLhs(self, bindRow)
             except EvaluationFailed:
                 stats['permCountFailingEval'] += 1
 
-    def _product(self, candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Iterator[CandidateBinding]:
-        orderedVars, orderedValueSets = _organize(candidateTermMatches)
+    def _product(self, workingSet, boundSoFar: CandidateBinding) -> Iterator[CandidateBinding]:
+        orderedVars = []
+        for stmt in self.graph:
+            for t in stmt:
+                if isinstance(t, (Variable, BNode)):
+                    orderedVars.append(t)
+        orderedVars = sorted(set(orderedVars))
+
+        orderedValueSets = []
+        for v in orderedVars:
+            orderedValueSets.append(CandidateTermMatches(v, self, workingSet, boundSoFar).results)
 
         self._logCandidates(orderedVars, orderedValueSets)
         log.debug(f'{INDENT*3} trying all permutations:')
-        if not orderedValueSets:
+        if not orderedVars:
             yield CandidateBinding({})
             return
+
         if not orderedValueSets or not all(orderedValueSets):
             # some var or bnode has no options at all
             return
@@ -111,7 +122,7 @@
         while True:
             for col, curr in enumerate(currentSet):
                 currentSet[col] = next(rings[col])
-                log.debug(repr(currentSet))
+                log.debug(f'{INDENT*4} currentSet: {repr(currentSet)}')
                 yield CandidateBinding(dict(zip(orderedVars, currentSet)))
                 if curr is not starts[col]:
                     break
@@ -124,17 +135,17 @@
         candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
         for lhsStmt in self.graph:
             log.debug(f'{INDENT*4} possibles for this lhs stmt: {lhsStmt}')
-            for i, trueStmt in enumerate(workingSet):
+            for trueStmt in workingSet:
                 # log.debug(f'{INDENT*5} consider this true stmt ({i}): {trueStmt}')
                 for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt):
                     candidateTermMatches[v].update(vals)
 
-        for trueStmt in itertools.chain(workingSet, self.graph):
-            for b in self.lhsBnodes:
-                for t in [trueStmt[0], trueStmt[2]]:
-                    if isinstance(t, (URIRef, BNode)):
-                        candidateTermMatches[b].add(t)
+        # for trueStmt in itertools.chain(workingSet, self.graph):
+        #     for b in self.lhsBnodes:
+        #         for t in [trueStmt[0], trueStmt[2]]:
+        #             if isinstance(t, (URIRef, BNode)):
+        #                 candidateTermMatches[b].add(t)
 
         return candidateTermMatches
 
     def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]:
@@ -166,6 +177,57 @@
 
 
 @dataclass
+class CandidateTermMatches:
+    """lazily find the possible matches for this term"""
+    term: BindableTerm
+    lhs: Lhs
+    workingSet: Graph
+    boundSoFar: CandidateBinding
+
+    def __post_init__(self):
+        self.results: List[Node] = []  # we have to be able to repeat the results
+
+        res: Set[Node] = set()
+        for trueStmt in self.workingSet:  # all bound
+            lStmts = list(self.lhsStmtsContainingTerm())
+            log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}')
+            for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False):
+                log.debug(f'{INDENT*4} {pat=}')
+                implied = self._stmtImplies(pat, trueStmt)
+                if implied is not None:
+                    res.add(implied)
+        self.results = list(res)
+        # self.results.sort()
+
+        log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}')
+
+    def _stmtImplies(self, pat: Triple, trueStmt: Triple) -> Optional[Node]:
+        """what value, if any, do we learn for our term from this LHS pattern statement and this known-true stmt"""
+        r = None
+        for p, t in zip(pat, trueStmt):
+            if isinstance(p, (Variable, BNode)):
+                if p != self.term:
+                    # stmt is unbound in more than just our term
+                    continue  # unsure what to do - err on the side of too many bindings, since they get rechecked later
+                if r is None:
+                    r = t
+                    log.debug(f'{INDENT*4} implied term value {p=} {t=}')
+                elif r != t:
+                    # (?x c ?x) matched with (a b c) doesn't work
+                    return None
+        return r
+
+    def lhsStmtsContainingTerm(self):
+        # lhs could precompute this
+        for lhsStmt in self.lhs.graph:
+            if self.term in lhsStmt:
+                yield lhsStmt
+
+    def __iter__(self):
+        return iter(self.results)
+
+
+@dataclass
 class BoundLhs:
     lhs: Lhs
     binding: CandidateBinding
@@ -240,10 +302,10 @@
             log.debug(f'{INDENT*3} rule has a working binding:')
 
             for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()):
-                log.debug(f'{INDENT*5} adding {lhsBoundStmt=}')
+                log.debug(f'{INDENT*4} adding {lhsBoundStmt=}')
                 workingSet.add(lhsBoundStmt)
             for newStmt in bound.binding.apply(self.rhsGraph):
-                log.debug(f'{INDENT*5} adding {newStmt=}')
+                log.debug(f'{INDENT*4} adding {newStmt=}')
                 workingSet.add(newStmt)
                 implied.add(newStmt)
 
@@ -279,6 +341,7 @@
         delta = 1
         stats['initWorkingSet'] = cast(int, workingSet.__len__())
         while delta > 0 and bailout_iterations > 0:
+            log.debug('')
            log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
            bailout_iterations -= 1
            delta = -len(implied)
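For context, here is a rough standalone sketch of the per-term matching idea that CandidateTermMatches is going for: for one bindable term, align each LHS statement containing that term against each known-true statement and collect the values the term would have to take. This is not the patch's code; the names implied_value and candidate_values and the example namespace are illustrative only, and unlike the patched _stmtImplies this sketch also rejects statements whose constant positions disagree.

# Illustrative sketch only -- not the classes from this patch.
from typing import Optional, Set, Tuple

from rdflib import BNode, Graph, Namespace, Variable
from rdflib.term import Node

Triple = Tuple[Node, Node, Node]


def implied_value(term: Node, pattern: Triple, true_stmt: Triple) -> Optional[Node]:
    """Value that true_stmt would force onto `term` via this LHS pattern, or None."""
    value: Optional[Node] = None
    for p, t in zip(pattern, true_stmt):
        if isinstance(p, (Variable, BNode)):
            if p != term:
                continue  # another unbound slot; over-collect and let later verification prune
            if value is None:
                value = t
            elif value != t:
                return None  # e.g. (?x c ?x) against (a b c): conflicting values for ?x
        elif p != t:
            return None  # a constant slot disagrees with the known-true statement
    return value


def candidate_values(term: Node, lhs_graph: Graph, known_true: Graph) -> Set[Node]:
    """All values `term` could take, judged one LHS statement at a time."""
    results: Set[Node] = set()
    for pattern in lhs_graph:
        if term not in pattern:
            continue
        for true_stmt in known_true:
            v = implied_value(term, pattern, true_stmt)
            if v is not None:
                results.add(v)
    return results


if __name__ == '__main__':
    EX = Namespace('http://example.com/')
    lhs = Graph()
    lhs.add((Variable('sensor'), EX['temperature'], Variable('t')))
    known = Graph()
    known.add((EX['kitchen'], EX['temperature'], EX['reading1']))
    # prints {rdflib.term.URIRef('http://example.com/reading1')}
    print(candidate_values(Variable('t'), lhs, known))

In the patch itself this per-term scan runs once per variable inside _product, with boundSoFar narrowing the LHS patterns as earlier bindings succeed; the commit message notes that this rewrite broke the inference tests.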