changeset 1590:327202020892

WIP inference- getting into more degenerate test cases
author drewp@bigasterisk.com
date Thu, 02 Sep 2021 23:20:55 -0700
parents 5c1055be3c36
children 668958454ae2
files service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py
diffstat 2 files changed, 204 insertions(+), 57 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Thu Sep 02 13:39:27 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Thu Sep 02 23:20:55 2021 -0700
@@ -2,6 +2,7 @@
 copied from reasoning 2021-08-29. probably same api. should
 be able to lib/ this out
 """
+from collections import defaultdict
 import itertools
 import logging
 from dataclasses import dataclass
@@ -10,7 +11,7 @@
 from urllib.request import OpenerDirector
 
 from prometheus_client import Summary
-from rdflib import BNode, Graph, Literal, Namespace
+from rdflib import BNode, Graph, Literal, Namespace, URIRef, RDF
 from rdflib.collection import Collection
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
 from rdflib.term import Node, Variable
@@ -92,7 +93,12 @@
         return out
 
 
-def graphDump(g: Graph):
+def graphDump(g: Union[Graph, List[Triple]]):
+    if not isinstance(g, Graph):
+        g2 = Graph()
+        for stmt in g:
+            g2.add(stmt)
+        g = g2
     g.bind('', ROOM)
     g.bind('ex', Namespace('http://example.com/'))
     lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines()
@@ -110,27 +116,27 @@
             implied.add(newStmt)
 
 
-def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]:
+def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[BindableTerm, Node]]:
     """bindings that fit the LHS of a rule, using statements from workingSet and functions
     from LHS"""
-    varsToBind: Set[Variable] = set()
+    varsToBind: Set[BindableTerm] = set()
     staticRuleStmts = Graph()
     for ruleStmt in lhs:
-        varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)]
+        varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))]
         varsToBind.update(varsInStmt)
         if (not varsInStmt  # ok
                 #and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
            ):
             staticRuleStmts.add(ruleStmt)
 
-    log.debug(f'        {varsToBind=}')
+    log.debug(f'        varsToBind: {sorted(varsToBind)}')
 
     if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
         log.debug(f'    someStaticStmtDoesntMatch: {graphDump(staticRuleStmts)}')
         return
 
     # the total set of terms each variable could possibly match
-    candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet)
+    candidateTermMatches: Dict[BindableTerm, Set[Node]] = findCandidateTermMatches(lhs, workingSet)
 
     orderedVars, orderedValueSets = organize(candidateTermMatches)
 
@@ -138,76 +144,113 @@
     log.debug(f'            {orderedVars=}')
     log.debug(f'            {orderedValueSets=}')
 
-    for perm in itertools.product(*orderedValueSets):
-        binding: Dict[Variable, Node] = dict(zip(orderedVars, perm))
-        log.debug(f'            {binding=} but lets look for funcs')
-        for v, val in inferredFuncBindings(lhs, binding):  # loop this until it's done
-            log.debug(f'        ifb tells us {v}={val}')
+    for i, perm in enumerate(itertools.product(*orderedValueSets)):
+        binding: Dict[BindableTerm, Node] = dict(zip(orderedVars, perm))
+        log.debug('')
+        log.debug(f'            ** trying {binding=}')
+        usedByFuncs = Graph()
+        for v, val, used in inferredFuncBindings(lhs, binding):  # loop this until it's done
+            log.debug(f'            inferredFuncBindings tells us {v}={val}')
             binding[v] = val
-        if not verifyBinding(lhs, binding, workingSet):  # fix this
-            log.debug(f'        verify culls')
+            usedByFuncs += used
+        if len(binding) != len(varsToBind):
+            log.debug(f'                binding is incomplete, needs {varsToBind}')
+
+            continue
+        if not verifyBinding(lhs, binding, workingSet, usedByFuncs):  # fix this
+            log.debug(f'            this binding did not verify')
             continue
         yield binding
 
 
-def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]:
+def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node, Graph]]:
     for stmt in lhs:
         if stmt[1] not in inferredFuncs:
             continue
-        if not isinstance(stmt[2], Variable):
+        var = stmt[2]
+        if not isinstance(var, Variable):
             continue
 
         x = stmt[0]
         if isinstance(x, Variable):
             x = bindingsBefore[x]
-        yield stmt[2], inferredFuncObject(x, stmt[1], lhs, bindingsBefore)
+
+        resultObject, usedByFunc = inferredFuncObject(x, stmt[1], lhs, bindingsBefore)
+
+        yield var, resultObject, usedByFunc
 
 
-def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]:
-    candidateTermMatches: Dict[Variable, Set[Node]] = {}
-
+def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[BindableTerm, Set[Node]]:
+    candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
+    lhsBnodes: Set[BNode] = set()
     for lhsStmt in lhs:
         for trueStmt in workingSet:
-            log.debug(f'{lhsStmt=} {trueStmt=}')
+            log.debug(f'            lhsStmt={graphDump([lhsStmt])} trueStmt={graphDump([trueStmt])}')
             bindingsFromStatement: Dict[Variable, Set[Node]] = {}
             for lhsTerm, trueTerm in zip(lhsStmt, trueStmt):
-                log.debug(f' test {lhsTerm=} {trueTerm=}')
-                if isinstance(lhsTerm, Variable):
+                # log.debug(f' test {lhsTerm=} {trueTerm=}')
+                if isinstance(lhsTerm, BNode):
+                    lhsBnodes.add(lhsTerm)
+                elif isinstance(lhsTerm, Variable):
                     bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm)
                 elif lhsTerm != trueTerm:
                     break
             else:
                 for v, vals in bindingsFromStatement.items():
-                    candidateTermMatches.setdefault(v, set()).update(vals)
+                    candidateTermMatches[v].update(vals)
+
+    for trueStmt in itertools.chain(workingSet, lhs):
+        for b in lhsBnodes:
+            for t in [trueStmt[0], trueStmt[2]]:
+                if isinstance(t, (URIRef, BNode)):
+                    candidateTermMatches[b].add(t)
     return candidateTermMatches
 
 
-def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]:
-    for stmt in rhs:
+def withBinding(toBind: Graph, bindings: Dict[BindableTerm, Node], includeStaticStmts=True) -> Iterator[Triple]:
+    for stmt in toBind:
         stmt = list(stmt)
-        for i, t in enumerate(stmt):
-            if isinstance(t, Variable):
-                try:
-                    stmt[i] = bindings[t]
-                except KeyError:
-                    # stmt is from another rule that we're not applying right now
-                    break
+        static = True
+        for i, term in enumerate(stmt):
+            if isinstance(term, (Variable, BNode)):
+                stmt[i] = bindings[term]
+                static = False
         else:
-            yield cast(Triple, stmt)
+            if includeStaticStmts or not static:
+                yield cast(Triple, stmt)
 
 
-def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool:
-    """can this lhs be true all at once?"""
-    for stmt in withBinding(lhs, binding):
-        log.debug(f'    lhs verify {stmt}')
+def verifyBinding(lhs: Graph, binding: Dict[BindableTerm, Node], workingSet: Graph, usedByFuncs: Graph) -> bool:
+    """Can this lhs be true all at once in workingSet? Does it match with these bindings?"""
+    log.debug(f'                verify all bindings against this lhs:')
+    boundLhs = list(withBinding(lhs, binding))
+    for stmt in boundLhs:
+        log.debug(f'                    {stmt}')
+
+    log.debug(f'                and against this workingSet:')
+    for stmt in workingSet:
+        log.debug(f'                    {stmt}')
+
+    log.debug(f'                ignoring these usedByFuncs:')
+    boundUsedByFuncs = list(withBinding(usedByFuncs, binding))
+    for stmt in boundUsedByFuncs:
+        log.debug(f'                    {stmt}')
+    # The static stmts in lhs should already match (presumably
+    # checked by someStaticStmtDoesntMatch), so only bound stmts need
+    # verifying; for now the loop below still checks all of boundLhs.
+    for stmt in boundLhs:  #withBinding(lhs, binding, includeStaticStmts=False):
+        log.debug(f'                check for {stmt}')
+
         if stmt[1] in filterFuncs:
             if not mathTest(*stmt):
+                log.debug(f'                    binding was invalid because {stmt}) is not true')
                 return False
-        elif (stmt not in workingSet  # not previously true
-              and stmt not in lhs  # not from the bindings in this rule
-              and stmt[1] not in inferredFuncs  # not a function stmt (maybe this is wrong)
-             ):
-            log.debug(f'    ver culls here')
+        elif stmt in boundUsedByFuncs:
+            pass
+        elif stmt in workingSet:
+            pass
+        else:
+            log.debug(f'                    binding was invalid because {stmt}) cannot be true')
             return False
     return True
 
@@ -221,19 +264,49 @@
 }
 
 
-def inferredFuncObject(subj, pred, graph, bindings):
+def isStatic(spo: Triple):
+    """True when no term of spo is a Variable or BNode."""
+    bindable = (Variable, BNode)
+    hasBindable = any(isinstance(term, bindable) for term in spo)
+    return not hasBindable
+
+
+def inferredFuncObject(subj, pred, graph, bindings) -> Tuple[Literal, Graph]:
+    """Return the result of a function statement like `(1 2) math:sum ?out .`, plus a graph of
+    all the statements involved in that function rule (including the bound answer)."""
+    used = Graph()
     if pred == ROOM['asFarenheit']:
-        return Literal(Decimal(subj.toPython()) * 9 / 5 + 32)
+        obj = Literal(Decimal(subj.toPython()) * 9 / 5 + 32)
     elif pred == MATH['sum']:
-        operands = Collection(graph, subj)
+        operands, operandsStmts = parseList(graph, subj)
         # shouldn't be redoing this here
         operands = [bindings[o] if isinstance(o, Variable) else o for o in operands]
-        log.debug(f' sum {list(operands)}')
-        return Literal(sum(op.toPython() for op in operands))
-
+        log.debug(f'                sum {[op.toPython() for op in operands]}')
+        used += operandsStmts
+        obj = Literal(sum(op.toPython() for op in operands))
     else:
         raise NotImplementedError(pred)
 
+    used.add((subj, pred, obj))
+    return obj, used
+
+
+def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
+    """Walk the rdf:List at subj; return (items, first/rest triples visited)."""
+    out = []
+    used = set()
+    cur = subj
+    # rdf:nil is the empty list, so `()` yields ([], set()) instead of a None item.
+    # A missing rdf:rest (graph.value gives None) also ends the walk.
+    while cur is not None and cur != RDF.nil:
+        item = graph.value(cur, RDF.first)
+        out.append(item)
+        used.add((cur, RDF.first, item))
+        rest = graph.value(cur, RDF.rest)
+        used.add((cur, RDF.rest, rest))
+        cur = rest
+    return out, used
+
 
 def mathTest(subj, pred, obj):
     x = subj.toPython()
@@ -244,10 +317,10 @@
         raise NotImplementedError(pred)
 
 
-def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]:
+def organize(candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Tuple[List[BindableTerm], List[List[Node]]]:
     items = list(candidateTermMatches.items())
     items.sort()
-    orderedVars: List[Variable] = []
+    orderedVars: List[BindableTerm] = []
     orderedValueSets: List[List[Node]] = []
     for v, vals in items:
         orderedVars.append(v)
--- a/service/mqtt_to_rdf/inference_test.py	Thu Sep 02 13:39:27 2021 -0700
+++ b/service/mqtt_to_rdf/inference_test.py	Thu Sep 02 23:20:55 2021 -0700
@@ -2,12 +2,64 @@
 also see https://github.com/w3c/N3/tree/master/tests/N3Tests
 """
 import unittest
-
-from rdflib import ConjunctiveGraph, Namespace, Graph
+import itertools
+from rdflib import ConjunctiveGraph, Namespace, Graph, BNode
 from rdflib.parser import StringInputSource
 
 from inference import Inference
 
+
+def patchSlimReprs():  # monkeypatch: shorten repr() of rdflib URIRef/BNode/Variable terms
+    import rdflib.term
+    # Each replacement prints a class tag plus the repr found above that class in the MRO.
+    def ur(self):  # exact URIRef instances are tagged "U"; subclasses keep their own name
+        clsName = "U" if self.__class__ is rdflib.term.URIRef else self.__class__.__name__
+        return """%s(%s)""" % (clsName, super(rdflib.term.URIRef, self).__repr__())
+
+    rdflib.term.URIRef.__repr__ = ur
+
+    def br(self):  # exact BNode instances are tagged "BNode"
+        clsName = "BNode" if self.__class__ is rdflib.term.BNode else self.__class__.__name__
+        return """%s(%s)""" % (clsName, super(rdflib.term.BNode, self).__repr__())
+
+    rdflib.term.BNode.__repr__ = br
+
+    def vr(self):  # exact Variable instances are tagged "V"
+        clsName = "V" if self.__class__ is rdflib.term.Variable else self.__class__.__name__
+        return """%s(%s)""" % (clsName, super(rdflib.term.Variable, self).__repr__())
+
+    rdflib.term.Variable.__repr__ = vr
+
+
+patchSlimReprs()
+
+
+def patchBnodeCounter():  # monkeypatch rdflib bnode creation to use counter-based ids (presumably for reproducible test output)
+    import rdflib.term
+    serial = itertools.count()
+    # n replaces BNode.__new__: an omitted value becomes 'N-<serial>'; _sn_gen/_prefix are accepted but ignored.
+    def n(cls, value=None, _sn_gen='', _prefix='') -> BNode:
+        if value is None:
+            value = 'N-%s' % next(serial)
+        return rdflib.term.Identifier.__new__(cls, value)
+
+    rdflib.term.BNode.__new__ = n
+
+    import rdflib.plugins.parsers.notation3
+    # newBlankNode replaces the n3 Formula method: ids are 'f-<number>-<counter>', or the uri's last '#' fragment with '_' -> 'b'.
+    def newBlankNode(self, uri=None, why=None):
+        if uri is None:
+            self.counter += 1
+            bn = BNode('f-%s-%s' % (self.number, self.counter))
+        else:
+            bn = BNode(uri.split('#').pop().replace('_', 'b'))
+        return bn
+
+    rdflib.plugins.parsers.notation3.Formula.newBlankNode = newBlankNode
+
+
+patchBnodeCounter()
+
 ROOM = Namespace('http://projects.bigasterisk.com/room/')
 
 
@@ -117,11 +169,29 @@
 
 
 class TestBnodeMatching(WithGraphEqual):
-    def test1(self):
+
+    def testRuleBnodeBindsToInputBnode(self):
         inf = makeInferenceWithRules("{ [ :a :b ] . } => { :new :stmt :here } .")
         implied = inf.infer(N3("[ :a :b ] ."))
         self.assertGraphEqual(implied, N3(":new :stmt :here ."))
 
+    def testRuleVarBindsToInputBNode(self):
+        inf = makeInferenceWithRules("{ ?z :a :b  . } => { :new :stmt :here } .")
+        implied = inf.infer(N3("[] :a :b ."))
+        self.assertGraphEqual(implied, N3(":new :stmt :here ."))
+
+
+class TestSelfFulfillingRule(WithGraphEqual):  # rules whose LHS needs no statements from the input graph
+
+    def test1(self):  # empty LHS: the RHS is expected for an empty input and for an unrelated input
+        inf = makeInferenceWithRules("{ } => { :new :stmt :x } .")
+        self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt :x ."))
+        self.assertGraphEqual(inf.infer(N3(":any :any :any .")), N3(":new :stmt :x ."))
+
+    def test2(self):  # LHS is only math:sum over the literal list (2); ?x is expected to be 2 with empty input
+        inf = makeInferenceWithRules("{ (2) math:sum ?x } => { :new :stmt ?x } .")
+        self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt 2 ."))
+
 
 class TestInferenceWithMathFunctions(WithGraphEqual):
 
@@ -131,9 +201,13 @@
         self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3(""))
         self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 ."))
 
-    # def testStatementGeneratingRule(self):
-    #     inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
-    #     self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 ."))
+    def testStatementGeneratingRule(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 ."))
+
+    def test3Operands(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . (2 ?x 2) math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 6 ."))
 
 
 class TestInferenceWithCustomFunctions(WithGraphEqual):