changeset 1618:48bf62008c82

attempted to rewrite with CandidateTermMatches but it broke
author drewp@bigasterisk.com
date Wed, 08 Sep 2021 18:32:11 -0700
parents e105032b0e3d
children
files service/mqtt_to_rdf/candidate_binding.py service/mqtt_to_rdf/inference.py
diffstat 2 files changed, 96 insertions(+), 29 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/candidate_binding.py	Mon Sep 06 23:26:07 2021 -0700
+++ b/service/mqtt_to_rdf/candidate_binding.py	Wed Sep 08 18:32:11 2021 -0700
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Iterable, Iterator, Union
+from typing import Dict, Iterable, Iterator, Optional, Union
 
 from prometheus_client import Summary
 from rdflib import BNode, Graph
@@ -16,24 +16,28 @@
         b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items()))
         return f'CandidateBinding({b})'
 
-    def apply(self, g: Union[Graph, Iterable[Triple]]) -> Iterator[Triple]:
+    def apply(self, g: Union[Graph, Iterable[Triple]], returnBoundStatementsOnly=True) -> Iterator[Triple]:
         for stmt in g:
             try:
-                bound = (self._applyTerm(stmt[0]), self._applyTerm(stmt[1]), self._applyTerm(stmt[2]))
+                bound = (
+                    self._applyTerm(stmt[0], returnBoundStatementsOnly), 
+                    self._applyTerm(stmt[1], returnBoundStatementsOnly), 
+                    self._applyTerm(stmt[2], returnBoundStatementsOnly))
             except BindingUnknown:
                 continue
             yield bound
 
-    def _applyTerm(self, term: Node):
+    def _applyTerm(self, term: Node, failUnbound=True):
         if isinstance(term, (Variable, BNode)):
             if term in self.binding:
                 return self.binding[term]
             else:
-                raise BindingUnknown()
+                if failUnbound:
+                    raise BindingUnknown()
         return term
 
     def addNewBindings(self, newBindings: 'CandidateBinding'):
         for k, v in newBindings.binding.items():
-            if k in self.binding and self.binding[k] != v:
-                raise ValueError(f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
+            # if k in self.binding and self.binding[k] != v:
+            #     raise ValueError(f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
             self.binding[k] = v
--- a/service/mqtt_to_rdf/inference.py	Mon Sep 06 23:26:07 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Wed Sep 08 18:32:11 2021 -0700
@@ -7,7 +7,7 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Dict, Iterator, List, Set, Tuple, Union, cast
+from typing import Dict, Iterator, List, Literal, Optional, Set, Tuple, Union, cast
 
 from prometheus_client import Summary
 from rdflib import BNode, Graph, Namespace, URIRef
@@ -55,53 +55,64 @@
     def __repr__(self):
         return f"Lhs({graphDump(self.graph)})"
 
-    def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
+    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
         log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}')
         stats['findCandidateBindingsCalls'] += 1
 
-        if not self._allStaticStatementsMatch(workingSet):
+        if not self._allStaticStatementsMatch(knownTrue):
             stats['findCandidateBindingEarlyExits'] += 1
             return
 
-        for binding in self._possibleBindings(workingSet, stats):
+        boundSoFar = CandidateBinding({})
+        for binding in self._possibleBindings(knownTrue, boundSoFar, stats):
             log.debug('')
             log.debug(f'{INDENT*4}*trying {binding.binding}')
 
-            if not binding.verify(workingSet):
+            if not binding.verify(knownTrue):
                 log.debug(f'{INDENT*4} this binding did not verify')
                 stats['permCountFailingVerify'] += 1
                 continue
 
             stats['permCountSucceeding'] += 1
             yield binding
+            boundSoFar.addNewBindings(binding.binding)
 
-    def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool:
+    def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool:
         # bug: see TestSelfFulfillingRule.test3 for a case where this rule's
         # static stmt is matched by a non-static stmt in the rule itself
         for ruleStmt in self.staticRuleStmts:
-            if ruleStmt not in workingSet:
+            if ruleStmt not in knownTrue:
                 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule')
                 return False
         return True
 
-    def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']:
+    def _possibleBindings(self, workingSet, boundSoFar, stats) -> Iterator['BoundLhs']:
         """this yields at least the working bindings, and possibly others"""
-        candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet)
-        for bindRow in self._product(candidateTermMatches):
+        for bindRow in self._product(workingSet, boundSoFar):
             try:
                 yield BoundLhs(self, bindRow)
             except EvaluationFailed:
                 stats['permCountFailingEval'] += 1
 
-    def _product(self, candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Iterator[CandidateBinding]:
-        orderedVars, orderedValueSets = _organize(candidateTermMatches)
+    def _product(self, workingSet, boundSoFar: CandidateBinding) -> Iterator[CandidateBinding]:
+        orderedVars = []
+        for stmt in self.graph:
+            for t in stmt:
+                if isinstance(t, (Variable, BNode)):
+                    orderedVars.append(t)
+        orderedVars = sorted(set(orderedVars))
+
+        orderedValueSets = []
+        for v in orderedVars:
+            orderedValueSets.append(CandidateTermMatches(v, self, workingSet, boundSoFar).results)
         self._logCandidates(orderedVars, orderedValueSets)
         log.debug(f'{INDENT*3} trying all permutations:')
-        if not orderedValueSets:
+        if not orderedVars:
             yield CandidateBinding({})
             return
+
         if not orderedValueSets or not all(orderedValueSets):
             # some var or bnode has no options at all
             return
@@ -111,7 +122,7 @@
         while True:
             for col, curr in enumerate(currentSet):
                 currentSet[col] = next(rings[col])
-                log.debug(repr(currentSet))
+                log.debug(f'{INDENT*4} currentSet: {repr(currentSet)}')
                 yield CandidateBinding(dict(zip(orderedVars, currentSet)))
                 if curr is not starts[col]:
                     break
@@ -124,17 +135,17 @@
         candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
         for lhsStmt in self.graph:
             log.debug(f'{INDENT*4} possibles for this lhs stmt: {lhsStmt}')
-            for i, trueStmt in enumerate(workingSet):
+            for trueStmt in workingSet:
                 # log.debug(f'{INDENT*5} consider this true stmt ({i}): {trueStmt}')
 
                 for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt):
                     candidateTermMatches[v].update(vals)
 
-        for trueStmt in itertools.chain(workingSet, self.graph):
-            for b in self.lhsBnodes:
-                for t in [trueStmt[0], trueStmt[2]]:
-                    if isinstance(t, (URIRef, BNode)):
-                        candidateTermMatches[b].add(t)
+        # for trueStmt in itertools.chain(workingSet, self.graph):
+        #     for b in self.lhsBnodes:
+        #         for t in [trueStmt[0], trueStmt[2]]:
+        #             if isinstance(t, (URIRef, BNode)):
+        #                 candidateTermMatches[b].add(t)
         return candidateTermMatches
 
     def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]:
@@ -166,6 +177,57 @@
 
 
 @dataclass
+class CandidateTermMatches:
+    """lazily find the possible matches for this term"""
+    term: BindableTerm
+    lhs: Lhs
+    workingSet: Graph
+    boundSoFar: CandidateBinding
+
+    def __post_init__(self):
+        self.results: List[Node] = []  # we have to be able to repeat the results
+
+        res: Set[Node] = set()
+        for trueStmt in self.workingSet:  # all bound
+            lStmts = list(self.lhsStmtsContainingTerm())
+            log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}')
+            for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False):
+                log.debug(f'{INDENT*4} {pat=}')
+                implied = self._stmtImplies(pat, trueStmt)
+                if implied is not None:
+                    res.add(implied)
+        self.results = list(res)
+        # self.results.sort()
+
+        log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}')
+
+    def _stmtImplies(self, pat: Triple, trueStmt: Triple) -> Optional[Node]:
+        """what value, if any, do we learn for our term from this LHS pattern statement and this known-true stmt"""
+        r = None
+        for p, t in zip(pat, trueStmt):
+            if isinstance(p, (Variable, BNode)):
+                if p != self.term:
+                    # stmt is unbound in more than just our term
+                    continue  # unsure what to do - err on the side of too many bindings, since they get rechecked later
+                if r is None:
+                    r = t
+                    log.debug(f'{INDENT*4}  implied term value {p=} {t=}')
+                elif r != t:
+                    # (?x c ?x) matched with (a b c) doesn't work
+                    return None
+        return r
+
+    def lhsStmtsContainingTerm(self):
+        # lhs could precompute this
+        for lhsStmt in self.lhs.graph:
+            if self.term in lhsStmt:
+                yield lhsStmt
+
+    def __iter__(self):
+        return iter(self.results)
+
+
+@dataclass
 class BoundLhs:
     lhs: Lhs
     binding: CandidateBinding
@@ -240,10 +302,10 @@
             log.debug(f'{INDENT*3} rule has a working binding:')
 
             for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()):
-                log.debug(f'{INDENT*5} adding {lhsBoundStmt=}')
+                log.debug(f'{INDENT*4} adding {lhsBoundStmt=}')
                 workingSet.add(lhsBoundStmt)
             for newStmt in bound.binding.apply(self.rhsGraph):
-                log.debug(f'{INDENT*5} adding {newStmt=}')
+                log.debug(f'{INDENT*4} adding {newStmt=}')
                 workingSet.add(newStmt)
                 implied.add(newStmt)
 
@@ -279,6 +341,7 @@
         delta = 1
         stats['initWorkingSet'] = cast(int, workingSet.__len__())
         while delta > 0 and bailout_iterations > 0:
+            log.debug('')
             log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
             bailout_iterations -= 1
             delta = -len(implied)