diff service/mqtt_to_rdf/inference.py @ 1593:b0df43d5494c

big rewrite- more classes, smaller methods, more typesafe, all current tests passing
author drewp@bigasterisk.com
date Sat, 04 Sep 2021 23:23:55 -0700
parents d7b66234064b
children e58bcfa66093
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sat Sep 04 23:18:44 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sat Sep 04 23:23:55 2021 -0700
@@ -17,6 +17,7 @@
 from rdflib.term import Node, Variable
 
 log = logging.getLogger('infer')
+INDENT = '    '
 
 Triple = Tuple[Node, Node, Node]
 Rule = Tuple[Graph, Node, Graph]
@@ -35,167 +36,277 @@
     vars: Dict[Variable, Node]
 
 
-inferredFuncs = {
-    ROOM['asFarenheit'],
-    MATH['sum'],
-}
+ReadOnlyWorkingSet = ReadOnlyGraphAggregate
+
 filterFuncs = {
     MATH['greaterThan'],
 }
 
 
-def withBinding(toBind: Graph, bindings: Dict[BindableTerm, Node], includeStaticStmts=True) -> Iterator[Triple]:
-    for stmt in toBind:
-        stmt = list(stmt)
-        static = True
-        for i, term in enumerate(stmt):
-            if isinstance(term, (Variable, BNode)):
-                stmt[i] = bindings[term]
-                static = False
-        else:
-            if includeStaticStmts or not static:
+class CandidateBinding:
+
+    def __init__(self, binding: Dict[BindableTerm, Node]):
+        self.binding = binding  # mutable!
+
+    def __repr__(self):
+        b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items()))
+        return f'CandidateBinding({b})'
+
+    def apply(self, g: Graph) -> Iterator[Triple]:
+        for stmt in g:
+            stmt = list(stmt)
+            for i, term in enumerate(stmt):
+                if isinstance(term, (Variable, BNode)):
+                    if term in self.binding:
+                        stmt[i] = self.binding[term]
+            else:
                 yield cast(Triple, stmt)
 
+    def applyFunctions(self, lhs):
+        """may grow the binding with some results"""
+        usedByFuncs = Graph()
+        while True:
+            before = len(self.binding)
+            delta = 0
+            for ev in Evaluation.findEvals(lhs):
+                log.debug(f'{INDENT*3} found Evaluation')
 
-def verifyBinding(lhs: Graph, binding: Dict[BindableTerm, Node], workingSet: Graph, usedByFuncs: Graph) -> bool:
-    """Can this lhs be true all at once in workingSet? Does it match with these bindings?"""
-    boundLhs = list(withBinding(lhs, binding))
-    boundUsedByFuncs = list(withBinding(usedByFuncs, binding))
-    if log.isEnabledFor(logging.DEBUG):
-        log.debug(f'                verify all bindings against this lhs:')
-        for stmt in boundLhs:
-            log.debug(f'                    {stmt}')
+                newBindings, usedGraph = ev.resultBindings(self.binding)
+                usedByFuncs += usedGraph
+                for k, v in newBindings.items():
+                    if k in self.binding and self.binding[k] != v:
+                        raise ValueError(
+                            f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
+                    self.binding[k] = v
+                delta = len(self.binding) - before
+                log.debug(f'{INDENT*4} rule {graphDump(usedGraph)} made {delta} new bindings')
+            if delta == 0:
+                break
+        return usedByFuncs
 
-        log.debug(f'                and against this workingSet:')
-        for stmt in workingSet:
-            log.debug(f'                    {stmt}')
+    def verify(self, lhs: 'Lhs', workingSet: ReadOnlyWorkingSet, usedByFuncs: Graph) -> bool:
+        """Can this lhs be true all at once in workingSet? Does it match with these bindings?"""
+        boundLhs = list(self.apply(lhs._g))
+        boundUsedByFuncs = list(self.apply(usedByFuncs))
+
+        self.logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs)
 
-        log.debug(f'                ignoring these usedByFuncs:')
-        for stmt in boundUsedByFuncs:
-            log.debug(f'                    {stmt}')
-    # The static stmts in lhs are obviously going
-    # to match- we only need to verify the ones
-    # that needed bindings.
-    for stmt in boundLhs:  #withBinding(lhs, binding, includeStaticStmts=False):
-        log.debug(f'                check for {stmt}')
+        for stmt in boundLhs:
+            log.debug(f'{INDENT*4} check for {stmt}')
 
-        if stmt[1] in filterFuncs:
-            if not mathTest(*stmt):
-                log.debug(f'                    binding was invalid because {stmt}) is not true')
+            if stmt[1] in filterFuncs:
+                if not mathTest(*stmt):
+                    log.debug(f'{INDENT*5} binding was invalid because {stmt}) is not true')
+                    return False
+            elif stmt in boundUsedByFuncs:
+                pass
+            elif stmt in workingSet:
+                pass
+            else:
+                log.debug(f'{INDENT*5} binding was invalid because {stmt}) is not known to be true')
                 return False
-        elif stmt in boundUsedByFuncs:
-            pass
-        elif stmt in workingSet:
-            pass
-        else:
-            log.debug(f'                    binding was invalid because {stmt}) cannot be true')
-            return False
-    return True
+        log.debug(f"{INDENT*5} this rule's lhs can work under this binding")
+        return True
+
+    def logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs):
+        if not log.isEnabledFor(logging.DEBUG):
+            return
+        log.debug(f'{INDENT*4}/ verify all bindings against this lhs:')
+        for stmt in sorted(boundLhs):
+            log.debug(f'{INDENT*4}|{INDENT} {stmt}')
+
+        log.debug(f'{INDENT*4}| and against this workingSet:')
+        for stmt in sorted(workingSet):
+            log.debug(f'{INDENT*4}|{INDENT} {stmt}')
+
+        log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:')
+        for stmt in sorted(boundUsedByFuncs):
+            log.debug(f'{INDENT*4}|{INDENT} {stmt}')
+        log.debug(f'{INDENT*4}\\')
 
 
-def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[BindableTerm, Node]]:
-    """bindings that fit the LHS of a rule, using statements from workingSet and functions
-    from LHS"""
-    varsToBind: Set[BindableTerm] = set()
-    staticRuleStmts = Graph()
-    for ruleStmt in lhs:
-        varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))]
-        varsToBind.update(varsInStmt)
-        if (not varsInStmt  # ok
-                #and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
-           ):
-            staticRuleStmts.add(ruleStmt)
+class Lhs:
+
+    def __init__(self, existingGraph):
+        self._g = existingGraph
+
+    def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]:
+        """bindings that fit the LHS of a rule, using statements from workingSet and functions
+        from LHS"""
+        nodesToBind = self.nodesToBind()
+        log.debug(f'{INDENT*2} nodesToBind: {nodesToBind}')
+
+        if not self.allStaticStatementsMatch(workingSet):
+            return
+
+        candidateTermMatches: Dict[BindableTerm, Set[Node]] = self.allCandidateTermMatches(workingSet)
+
+        # for n in nodesToBind:
+        #     if n not in candidateTermMatches:
+        #         candidateTermMatches[n] = set()
+
+        orderedVars, orderedValueSets = organize(candidateTermMatches)
+
+        self.logCandidates(orderedVars, orderedValueSets)
 
-    log.debug(f'        varsToBind: {sorted(varsToBind)}')
+        log.debug(f'{INDENT*2} trying all permutations:')
+
+        for perm in itertools.product(*orderedValueSets):
+            binding = CandidateBinding(dict(zip(orderedVars, perm)))
+            log.debug('')
+            log.debug(f'{INDENT*3}*trying {binding}')
+
+            usedByFuncs = binding.applyFunctions(self)
 
-    if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
-        log.debug(f'    someStaticStmtDoesntMatch: {graphDump(staticRuleStmts)}')
-        return
+            if not binding.verify(self, workingSet, usedByFuncs):
+                log.debug(f'{INDENT*3} this binding did not verify')
+                continue
+            yield binding
 
-    # the total set of terms each variable could possibly match
-    candidateTermMatches: Dict[BindableTerm, Set[Node]] = findCandidateTermMatches(lhs, workingSet)
+    def nodesToBind(self) -> List[BindableTerm]:
+        nodes: Set[BindableTerm] = set()
+        staticRuleStmts = Graph()
+        for ruleStmt in self._g:
+            varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))]
+            nodes.update(varsInStmt)
+            if (not varsInStmt  # ok
+                    #and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
+               ):
+                staticRuleStmts.add(ruleStmt)
+        return sorted(nodes)
 
-    orderedVars, orderedValueSets = organize(candidateTermMatches)
+    def allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool:
+        staticRuleStmts = Graph()
+        for ruleStmt in self._g:
+            varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))]
+            if (not varsInStmt  # ok
+                    #and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
+               ):
+                staticRuleStmts.add(ruleStmt)
 
-    log.debug(f'        candidate terms:')
-    log.debug(f'            {orderedVars=}')
-    log.debug(f'            {orderedValueSets=}')
+        for ruleStmt in staticRuleStmts:
+            if ruleStmt not in workingSet:
+                log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule')
+                return False
+        return True
+
+    def allCandidateTermMatches(self, workingSet: ReadOnlyWorkingSet) -> Dict[BindableTerm, Set[Node]]:
+        """the total set of terms each variable could possibly match"""
 
-    for i, perm in enumerate(itertools.product(*orderedValueSets)):
-        binding: Dict[BindableTerm, Node] = dict(zip(orderedVars, perm))
-        log.debug('')
-        log.debug(f'            ** trying {binding=}')
-        usedByFuncs = Graph()
-        for v, val, used in inferredFuncBindings(lhs, binding):  # loop this until it's done
-            log.debug(f'            inferredFuncBindings tells us {v}={val}')
-            binding[v] = val
-            usedByFuncs += used
-        if len(binding) != len(varsToBind):
-            log.debug(f'                binding is incomplete, needs {varsToBind}')
+        candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
+        lhsBnodes: Set[BNode] = set()
+        for lhsStmt in self._g:
+            log.debug(f'{INDENT*3} possibles for this lhs stmt: {lhsStmt}')
+            for i, trueStmt in enumerate(sorted(workingSet)):
+                log.debug(f'{INDENT*4} consider this true stmt ({i}): {trueStmt}')
+                bindingsFromStatement: Dict[Variable, Set[Node]] = {}
+                for lhsTerm, trueTerm in zip(lhsStmt, trueStmt):
+                    if isinstance(lhsTerm, BNode):
+                        lhsBnodes.add(lhsTerm)
+                    elif isinstance(lhsTerm, Variable):
+                        bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm)
+                    elif lhsTerm != trueTerm:
+                        break
+                else:
+                    for v, vals in bindingsFromStatement.items():
+                        candidateTermMatches[v].update(vals)
 
-            continue
-        if not verifyBinding(lhs, binding, workingSet, usedByFuncs):  # fix this
-            log.debug(f'            this binding did not verify')
-            continue
-        yield binding
+        for trueStmt in itertools.chain(workingSet, self._g):
+            for b in lhsBnodes:
+                for t in [trueStmt[0], trueStmt[2]]:
+                    if isinstance(t, (URIRef, BNode)):
+                        candidateTermMatches[b].add(t)
+        return candidateTermMatches
+
+    def graphWithoutEvals(self, binding: CandidateBinding) -> Graph:
+        g = Graph()
+        usedByFuncs = binding.applyFunctions(self)
+
+        for stmt in self._g:
+            if stmt not in usedByFuncs:
+                g.add(stmt)
+        return g
+
+    def logCandidates(self, orderedVars, orderedValueSets):
+        if not log.isEnabledFor(logging.DEBUG):
+            return
+        log.debug(f'{INDENT*2} resulting candidate terms:')
+        for v, vals in zip(orderedVars, orderedValueSets):
+            log.debug(f'{INDENT*3} {v} could be:')
+            for val in vals:
+                log.debug(f'{INDENT*4}{val}')
 
 
-def someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
-    for ruleStmt in staticRuleStmts:
-        if ruleStmt not in workingSet:
-            log.debug(f'            {ruleStmt} not in working set- skip rule')
+class Evaluation:
+    """some lhs statements need to be evaluated with a special function 
+    (e.g. math) and then not considered for the rest of the rule-firing 
+    process. It's like they already 'matched' something, so they don't need
+    to match a statement from the known-true working set."""
+
+    @staticmethod
+    def findEvals(lhs: Lhs) -> Iterator['Evaluation']:
+        for stmt in lhs._g.triples((None, MATH['sum'], None)):
+            # shouldn't be redoing this here
+            operands, operandsStmts = parseList(lhs._g, stmt[0])
+            g = Graph()
+            g += operandsStmts
+            yield Evaluation(operands, g, stmt)
+
+        for stmt in lhs._g.triples((None, ROOM['asFarenheit'], None)):
+            g = Graph()
+            g.add(stmt)
+            yield Evaluation([stmt[0]], g, stmt)
+
+    # internal, use findEvals
+    def __init__(self, operands: List[Node], operandsStmts: Graph, stmt: Triple) -> None:
+        self.operands = operands
+        self.operandsStmts = operandsStmts
+        self.stmt = stmt
 
-            return True
-    return False
+    def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]:
+        """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?"""
+        pred = self.stmt[1]
+        objVar = self.stmt[2]
+        boundOperands = []
+        for o in self.operands:
+            if isinstance(o, Variable):
+                try:
+                    o = inputBindings[o]
+                except KeyError:
+                    return {}, self.operandsStmts
+
+            boundOperands.append(o)
+
+        if not isinstance(objVar, Variable):
+            raise TypeError(f'expected Variable, got {objVar!r}')
+
+        if pred == MATH['sum']:
+            log.debug(f'{INDENT*4} sum {list(map(self.numericNode, boundOperands))}')
+            obj = cast(Literal, Literal(sum(map(self.numericNode, boundOperands))))
+            self.operandsStmts.add(self.stmt)
+            return {objVar: obj}, self.operandsStmts
+        elif pred == ROOM['asFarenheit']:
+            if len(boundOperands) != 1:
+                raise ValueError(":asFarenheit takes 1 subject operand")
+            f = Literal(Decimal(self.numericNode(boundOperands[0])) * 9 / 5 + 32)
+            g = Graph()
+            g.add(self.stmt)
+
+            log.debug('made 1 st graph')
+            return {objVar: f}, g
+        else:
+            raise NotImplementedError()
+
+    def numericNode(self, n: Node):
+        if not isinstance(n, Literal):
+            raise TypeError(f'expected Literal, got {n=}')
+        val = n.toPython()
+        if not isinstance(val, (int, float, Decimal)):
+            raise TypeError(f'expected number, got {val=}')
+        return val
 
 
-def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[BindableTerm, Set[Node]]:
-    candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
-    lhsBnodes: Set[BNode] = set()
-    for lhsStmt in lhs:
-        for trueStmt in workingSet:
-            log.debug(f'            lhsStmt={graphDump([lhsStmt])} trueStmt={graphDump([trueStmt])}')
-            bindingsFromStatement: Dict[Variable, Set[Node]] = {}
-            for lhsTerm, trueTerm in zip(lhsStmt, trueStmt):
-                # log.debug(f' test {lhsTerm=} {trueTerm=}')
-                if isinstance(lhsTerm, BNode):
-                    lhsBnodes.add(lhsTerm)
-                elif isinstance(lhsTerm, Variable):
-                    bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm)
-                elif lhsTerm != trueTerm:
-                    break
-            else:
-                for v, vals in bindingsFromStatement.items():
-                    candidateTermMatches[v].update(vals)
-
-    for trueStmt in itertools.chain(workingSet, lhs):
-        for b in lhsBnodes:
-            for t in [trueStmt[0], trueStmt[2]]:
-                if isinstance(t, (URIRef, BNode)):
-                    candidateTermMatches[b].add(t)
-    return candidateTermMatches
-
-
-def inferredFuncObject(subj, pred, graph, bindings) -> Tuple[Literal, Graph]:
-    """return result from like `(1 2) math:sum ?out .` plus a graph of all the
-    statements involved in that function rule (including the bound answer"""
-    used = Graph()
-    if pred == ROOM['asFarenheit']:
-        obj = Literal(Decimal(subj.toPython()) * 9 / 5 + 32)
-    elif pred == MATH['sum']:
-        operands, operandsStmts = parseList(graph, subj)
-        # shouldn't be redoing this here
-        operands = [bindings[o] if isinstance(o, Variable) else o for o in operands]
-        log.debug(f'                sum {[op.toPython() for op in operands]}')
-        used += operandsStmts
-        obj = Literal(sum(op.toPython() for op in operands))
-    else:
-        raise NotImplementedError(pred)
-
-    used.add((subj, pred, obj))
-    return obj, used
-
-
+# merge into evaluation, raising a Invalid for impossible stmts
 def mathTest(subj, pred, obj):
     x = subj.toPython()
     y = obj.toPython()
@@ -217,10 +328,11 @@
         """
         returns new graph of inferred statements.
         """
-        log.info(f'Begin inference of graph len={len(graph)} with rules len={len(self.rules)}:')
+        log.debug(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:')
 
         # everything that is true: the input graph, plus every rule conclusion we can make
-        workingSet = graphCopy(graph)
+        workingSet = Graph()
+        workingSet += graph
 
         # just the statements that came from rule RHS's.
         implied = ConjunctiveGraph()
@@ -228,52 +340,47 @@
         bailout_iterations = 100
         delta = 1
         while delta > 0 and bailout_iterations > 0:
-            log.debug(f'  * iteration ({bailout_iterations} left)')
+            log.debug(f'{INDENT*1}*iteration ({bailout_iterations} left)')
             bailout_iterations -= 1
             delta = -len(implied)
             self._iterateAllRules(workingSet, implied)
             delta += len(implied)
-            log.info(f'  this inference round added {delta} more implied stmts')
-        log.info(f'    {len(implied)} stmts implied:')
+            log.info(f'{INDENT*1} this inference round added {delta} more implied stmts')
+        log.info(f'{INDENT*0} {len(implied)} stmts implied:')
         for st in implied:
-            log.info(f'        {st}')
+            log.info(f'{INDENT*2} {st}')
         return implied
 
-    def _iterateAllRules(self, workingSet, implied):
+    def _iterateAllRules(self, workingSet: Graph, implied: Graph):
         for i, r in enumerate(self.rules):
-            log.debug(f'      workingSet: {graphDump(workingSet)}')
-            log.debug(f'      - applying rule {i}')
-            log.debug(f'        lhs: {graphDump(r[0])}')
-            log.debug(f'        rhs: {graphDump(r[2])}')
+            log.debug('')
+            log.debug(f'{INDENT*2} workingSet:')
+            for i, stmt in enumerate(sorted(workingSet)):
+                log.debug(f'{INDENT*3} ({i}) {stmt}')
+
+            log.debug('')
+            log.debug(f'{INDENT*2}-applying rule {i}')
+            log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}')
+            log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}')
             if r[1] == LOG['implies']:
-                applyRule(r[0], r[2], workingSet, implied)
+                applyRule(Lhs(r[0]), r[2], workingSet, implied)
             else:
-                log.info(f'   {r} not a rule?')
+                log.info(f'{INDENT*2} {r} not a rule?')
 
 
-def graphCopy(src: Graph) -> Graph:
-    if isinstance(src, ConjunctiveGraph):
-        out = ConjunctiveGraph()
-        out.addN(src.quads())
-        return out
-    else:
-        out = Graph()
-        for triple in src:
-            out.add(triple)
-        return out
-
-
-def applyRule(lhs: Graph, rhs: Graph, workingSet: Graph, implied: Graph):
-    for bindings in findCandidateBindings(lhs, workingSet):
-        log.debug(f'        rule gave {bindings=}')
-        for lhsBoundStmt in withBinding(lhs, bindings):
+def applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph):
+    for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])):
+        # log.debug(f'        rule gave {binding=}')
+        for lhsBoundStmt in binding.apply(lhs.graphWithoutEvals(binding)):
             workingSet.add(lhsBoundStmt)
-        for newStmt in withBinding(rhs, bindings):
+        for newStmt in binding.apply(rhs):
             workingSet.add(newStmt)
             implied.add(newStmt)
 
 
 def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
+    """"Do like Collection(g, subj) but also return all the 
+    triples that are involved in the list"""
     out = []
     used = set()
     cur = subj
@@ -284,6 +391,7 @@
 
         next = graph.value(cur, RDF.rest)
         used.add((cur, RDF.rest, next))
+
         cur = next
         if cur == RDF.nil:
             break
@@ -317,23 +425,6 @@
     return orderedVars, orderedValueSets
 
 
-def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node, Graph]]:
-    for stmt in lhs:
-        if stmt[1] not in inferredFuncs:
-            continue
-        var = stmt[2]
-        if not isinstance(var, Variable):
-            continue
-
-        x = stmt[0]
-        if isinstance(x, Variable):
-            x = bindingsBefore[x]
-
-        resultObject, usedByFunc = inferredFuncObject(x, stmt[1], lhs, bindingsBefore)
-
-        yield var, resultObject, usedByFunc
-
-
 def isStatic(spo: Triple):
     for t in spo:
         if isinstance(t, (Variable, BNode)):