diff service/mqtt_to_rdf/inference.py @ 1631:2c85a4f5dd9c

big rewrite of infer() using statements not variables as the things to iterate over
author drewp@bigasterisk.com
date Sun, 12 Sep 2021 04:32:52 -0700
parents ea559a846714
children bd79a2941cab
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sat Sep 11 23:33:55 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sun Sep 12 04:32:52 2021 -0700
@@ -7,16 +7,16 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, Iterator, List, Set, Tuple, Union, cast
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast
 
-from prometheus_client import Summary, Histogram
-from rdflib import BNode, Graph, Namespace, URIRef
+from prometheus_client import Histogram, Summary
+from rdflib import BNode, Graph, Namespace
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
-from rdflib.term import Node, Variable
+from rdflib.term import Literal, Node, Variable
 
 from candidate_binding import CandidateBinding
 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
-from lhs_evaluation import Evaluation
+from lhs_evaluation import Decimal, Evaluation, numericNode
 
 log = logging.getLogger('infer')
 INDENT = '    '
@@ -29,27 +29,141 @@
 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
 
 
+def stmtTemplate(stmt: Triple) -> Tuple[Optional[Node], Optional[Node], Optional[Node]]:
+    return (
+        None if isinstance(stmt[0], (Variable, BNode)) else stmt[0],
+        None if isinstance(stmt[1], (Variable, BNode)) else stmt[1],
+        None if isinstance(stmt[2], (Variable, BNode)) else stmt[2],
+    )
+
+
+class NoOptions(ValueError):
+    """stmtlooper has no possibilites to add to the binding; the whole rule must therefore not apply"""
+
+
+class Inconsistent(ValueError):
+    """adding this stmt would be inconsistent with an existing binding"""
+
+
+@dataclass
+class StmtLooper:
+    lhsStmt: Triple
+    prev: Optional['StmtLooper']
+    workingSet: ReadOnlyWorkingSet
+
+    def __repr__(self):
+        return f'StmtLooper({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'
+
+    def __post_init__(self):
+        self._myWorkingSetMatches = self._myMatches(self.workingSet)
+
+        self._current = CandidateBinding({})
+        self._pastEnd = False
+        self._seenBindings: List[Dict[BindableTerm, Node]] = []
+        self.restart()
+
+    def _myMatches(self, g: Graph) -> List[Triple]:
+        template = stmtTemplate(self.lhsStmt)
+
+        stmts = sorted(cast(Iterator[Triple], list(g.triples(template))))
+        # plus new lhs possibilties...
+        # log.debug(f'{INDENT*6} {self} find {len(stmts)=} in {len(self.workingSet)=}')
+
+        return stmts
+
+    def _prevBindings(self) -> Dict[BindableTerm, Node]:
+        if not self.prev or self.prev.pastEnd():
+            return {}
+
+        return self.prev.currentBinding().binding
+
+    def advance(self):
+        """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode"""
+        log.debug(f'{INDENT*6} {self} mines {len(self._myWorkingSetMatches)} matching statements')
+        for i, stmt in enumerate(self._myWorkingSetMatches):
+            try:
+                outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
+            except Inconsistent:
+                log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
+                continue
+            log.debug(f'seen {outBinding.binding} in {self._seenBindings}')
+            if outBinding.binding not in self._seenBindings:
+                self._seenBindings.append(outBinding.binding.copy())
+                log.debug(f'no, adding')
+                self._current = outBinding
+                log.debug(f'{INDENT*7} {self} - Looper matches {stmt} which tells us {outBinding}')
+                return
+            log.debug(f'yes we saw')
+
+        log.debug(f'{INDENT*6} {self} mines rules')
+
+        if self.lhsStmt[1] == ROOM['asFarenheit']:
+            pb: Dict[BindableTerm, Node] = self._prevBindings()
+            if self.lhsStmt[0] in pb:
+                operands = [pb[cast(BindableTerm, self.lhsStmt[0])]]
+                f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32))
+                objVar = self.lhsStmt[2]
+                if not isinstance(objVar, Variable):
+                    raise TypeError(f'expected Variable, got {objVar!r}')
+                newBindings = {cast(BindableTerm, objVar): cast(Node, f)}
+                self._current.addNewBindings(CandidateBinding(newBindings))
+                if newBindings not in self._seenBindings:
+                    self._seenBindings.append(newBindings)
+                    self._current = CandidateBinding(newBindings)
+
+        log.debug(f'{INDENT*6} {self} is past end')
+        self._pastEnd = True
+
+    def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
+        outBinding = self._prevBindings().copy()
+        for rt, ct in zip(self.lhsStmt, newStmt):
+            if isinstance(rt, (Variable, BNode)):
+                if rt in outBinding and outBinding[rt] != ct:
+                    raise Inconsistent()
+                outBinding[rt] = ct
+        return CandidateBinding(outBinding)
+
+    def currentBinding(self) -> CandidateBinding:
+        if self.pastEnd():
+            raise NotImplementedError()
+        return self._current
+
+    def newLhsStmts(self) -> List[Triple]:
+        """under the curent bindings, what new stmts beyond workingSet are also true? includes all `prev`"""
+        return []
+
+    def pastEnd(self) -> bool:
+        return self._pastEnd
+
+    def restart(self):
+        self._pastEnd = False
+        self._seenBindings = []
+        self.advance()
+        if self.pastEnd():
+            raise NoOptions()
+
+
 @dataclass
 class Lhs:
     graph: Graph
 
     def __post_init__(self):
         # do precomputation in here that's not specific to the workingSet
-        self.staticRuleStmts = Graph()
-        self.nonStaticRuleStmts = Graph()
+        # self.staticRuleStmts = Graph()
+        # self.nonStaticRuleStmts = Graph()
 
-        self.lhsBindables: Set[BindableTerm] = set()
-        self.lhsBnodes: Set[BNode] = set()
-        for ruleStmt in self.graph:
-            varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))]
-            self.lhsBindables.update(varsAndBnodesInStmt)
-            self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode))
-            if not varsAndBnodesInStmt:
-                self.staticRuleStmts.add(ruleStmt)
-            else:
-                self.nonStaticRuleStmts.add(ruleStmt)
+        # self.lhsBindables: Set[BindableTerm] = set()
+        # self.lhsBnodes: Set[BNode] = set()
+        # for ruleStmt in self.graph:
+        #     varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))]
+        #     self.lhsBindables.update(varsAndBnodesInStmt)
+        #     self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode))
+        #     if not varsAndBnodesInStmt:
+        #         self.staticRuleStmts.add(ruleStmt)
+        #     else:
+        #         self.nonStaticRuleStmts.add(ruleStmt)
 
-        self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts)
+        # self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts)
 
         self.evaluations = list(Evaluation.findEvals(self.graph))
 
@@ -59,24 +173,69 @@
     def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
-        log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}')
-        stats['findCandidateBindingsCalls'] += 1
+        log.debug(f'{INDENT*4} build new StmtLooper stack')
 
-        if not self._allStaticStatementsMatch(knownTrue):
-            stats['findCandidateBindingEarlyExits'] += 1
+        stmtStack: List[StmtLooper] = []
+        try:
+            prev: Optional[StmtLooper] = None
+            for s in sorted(self.graph):  # order of this matters! :(
+                stmtStack.append(StmtLooper(s, prev, knownTrue))
+                prev = stmtStack[-1]
+        except NoOptions:
+            log.debug(f'{INDENT*5} no options; 0 bindings')
             return
 
-        for binding in self._possibleBindings(knownTrue, stats):
-            log.debug('')
-            log.debug(f'{INDENT*4}*trying {binding.binding}')
+        log.debug(f'{INDENT*5} initial odometer:')
+        for l in stmtStack:
+            log.debug(f'{INDENT*6} {l}')
+
+        if any(ring.pastEnd() for ring in stmtStack):
+            log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}')
+
+            raise NoOptions()
+        sl = stmtStack[-1]
+        iterCount = 0
+        while True:
+            iterCount += 1
+            if iterCount > 10:
+                raise ValueError('stuck')
+
+            log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')
+
+            log.debug(f'{INDENT*5} <<<')
+            yield BoundLhs(self, sl.currentBinding())
+            log.debug(f'{INDENT*5} >>>')
+
+            log.debug(f'{INDENT*5} odometer:')
+            for l in stmtStack:
+                log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')
 
-            if not binding.verify(knownTrue):
-                log.debug(f'{INDENT*4} this binding did not verify')
-                stats['permCountFailingVerify'] += 1
-                continue
+            done = self._advanceAll(stmtStack)
+
+            log.debug(f'{INDENT*5} odometer after ({done=}):')
+            for l in stmtStack:
+                log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')
+
+            log.debug(f'{INDENT*4} ^^ findCandBindings iteration done')
+            if done:
+                break
 
-            stats['permCountSucceeding'] += 1
-            yield binding
+    def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
+        carry = True  # 1st elem always must advance
+        for i, ring in enumerate(stmtStack):
+            # unlike normal odometer, advancing any earlier ring could invalidate later ones
+            if carry:
+                log.debug(f'{INDENT*5} advanceAll [{i}] {ring} carry/advance')
+                ring.advance()
+                carry = False
+            if ring.pastEnd():
+                if ring is stmtStack[-1]:
+                    log.debug(f'{INDENT*5} advanceAll [{i}] {ring} says we done')
+                    return True
+                log.debug(f'{INDENT*5} advanceAll [{i}] {ring} restart')
+                ring.restart()
+                carry = True
+        return False
 
     def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool:
         # bug: see TestSelfFulfillingRule.test3 for a case where this rule's
@@ -162,41 +321,6 @@
                 log.debug(f'{INDENT*5}{val!r}')
 
 
-# @dataclass
-# class CandidateTermMatches:
-#     """lazily find the possible matches for this term"""
-#     terms: List[BindableTerm]
-#     lhs: Lhs
-#     knownTrue: Graph
-#     boundSoFar: CandidateBinding
-
-#     def __post_init__(self):
-#         self.results: List[Node] = []  # we have to be able to repeat the results
-
-#         res: Set[Node] = set()
-#         for trueStmt in self.knownTrue:  # all bound
-#             lStmts = list(self.lhsStmtsContainingTerm())
-#             log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}')
-#             for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False):
-#                 log.debug(f'{INDENT*4} {pat=}')
-#                 implied = self._stmtImplies(pat, trueStmt)
-#                 if implied is not None:
-#                     res.add(implied)
-#         self.results = list(res)
-#         # self.results.sort()
-
-#         log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}')
-
-#     def lhsStmtsContainingTerm(self):
-#         # lhs could precompute this
-#         for lhsStmt in self.lhs.graph:
-#             if self.term in lhsStmt:
-#                 yield lhsStmt
-
-#     def __iter__(self):
-#         return iter(self.results)
-
-
 @dataclass
 class BoundLhs:
     lhs: Lhs
@@ -204,7 +328,7 @@
 
     def __post_init__(self):
         self.usedByFuncs = Graph()
-        self._applyFunctions()
+        # self._applyFunctions()
 
     def lhsStmtsWithoutEvals(self):
         for stmt in self.lhs.graph:
@@ -263,19 +387,40 @@
 class Rule:
     lhsGraph: Graph
     rhsGraph: Graph
-
+    
     def __post_init__(self):
         self.lhs = Lhs(self.lhsGraph)
+        # 
+        self.rhsBnodeMap = {}
 
     def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict):
         for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats):
-            log.debug(f'{INDENT*3} rule has a working binding:')
+            log.debug(f'{INDENT*5} +rule has a working binding: {bound}')
+
+            # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do
+            existingRhsBnodes = set()
+            for stmt in self.rhsGraph:
+                for t in stmt:
+                    if isinstance(t, BNode):
+                        existingRhsBnodes.add(t)
+            # if existingRhsBnodes:
+                # log.debug(f'{INDENT*6} mapping rhs bnodes {existingRhsBnodes} to new ones')
+
+            for b in existingRhsBnodes:
 
-            for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()):
-                log.debug(f'{INDENT*4} adding {lhsBoundStmt=}')
-                workingSet.add(lhsBoundStmt)
+                key = tuple(sorted(bound.binding.binding.items())), b
+                self.rhsBnodeMap.setdefault(key, BNode())
+
+
+                bound.binding.addNewBindings(CandidateBinding({b: self.rhsBnodeMap[key]}))
+
+            # for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()):
+            #     log.debug(f'{INDENT*6} adding to workingSet {lhsBoundStmt=}')
+            #     workingSet.add(lhsBoundStmt)
+            # log.debug(f'{INDENT*6} rhsGraph is good: {list(self.rhsGraph)}')
+
             for newStmt in bound.binding.apply(self.rhsGraph):
-                log.debug(f'{INDENT*4} adding {newStmt=}')
+                # log.debug(f'{INDENT*6} adding {newStmt=}')
                 workingSet.add(newStmt)
                 implied.add(newStmt)
 
@@ -350,7 +495,6 @@
 
 def graphDump(g: Union[Graph, List[Triple]]):
     if not isinstance(g, Graph):
-        log.warning(f"it's a {type(g)}")
         g2 = Graph()
         g2 += g
         g = g2