changeset 1588:0757fafbfdab

WIP inferencer - partial var and function support
author drewp@bigasterisk.com
date Thu, 02 Sep 2021 01:58:31 -0700
parents 9a3a18c494f9
children 5c1055be3c36
files service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py
diffstat 2 files changed, 214 insertions(+), 24 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sun Aug 29 23:59:09 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Thu Sep 02 01:58:31 2021 -0700
@@ -2,12 +2,16 @@
 copied from reasoning 2021-08-29. probably same api. should
 be able to lib/ this out
 """
-
+import itertools
 import logging
-from typing import Dict, Tuple
 from dataclasses import dataclass
+from decimal import Decimal
+from typing import Dict, Iterator, List, Set, Tuple, cast
+from urllib.request import OpenerDirector
+
 from prometheus_client import Summary
-from rdflib import Graph, Namespace
+from rdflib import BNode, Graph, Literal, Namespace
+from rdflib.collection import Collection
 from rdflib.graph import ConjunctiveGraph
 from rdflib.term import Node, Variable
 
@@ -57,7 +61,7 @@
         while delta > 0 and bailout_iterations > 0:
             bailout_iterations -= 1
             delta = -len(implied)
-            self._iterateRules(workingSet, implied)
+            self._iterateAllRules(workingSet, implied)
             delta += len(implied)
             log.info(f'  this inference round added {delta} more implied stmts')
         log.info(f'{len(implied)} stmts implied:')
@@ -65,19 +69,161 @@
             log.info(f'  {st}')
         return implied
 
-    def _iterateRules(self, workingSet, implied):
+    def _iterateAllRules(self, workingSet, implied):
+        """Make one pass over self.rules, applying each rule once.
+
+        A statement in self.rules counts as a rule when its predicate is
+        LOG['implies']; its subject and object are handed to applyRule as the
+        lhs and rhs. Any other statement is only logged.
+        """
         for r in self.rules:
             if r[1] == LOG['implies']:
-                self._applyRule(r[0], r[2], workingSet, implied)
+                applyRule(r[0], r[2], workingSet, implied)
             else:
                 log.info(f'  {r} not a rule?')
 
-    def _applyRule(self, lhs, rhs, workingSet, implied):
-        containsSetup = self._containsSetup(lhs, workingSet)
-        if containsSetup:
-            for st in rhs:
-                workingSet.add(st)
-                implied.add(st)
+
+def applyRule(lhs: Graph, rhs: Graph, workingSet, implied):
+    """Fire one rule against workingSet.
+
+    For every binding that findCandidateBindings yields for lhs, each rhs
+    statement produced by withBinding is added to workingSet and then to
+    implied.
+    """
+    for bindings in findCandidateBindings(lhs, workingSet):
+        log.debug(f' - rule gave {bindings=}')
+        for derived in withBinding(rhs, bindings):
+            for target in (workingSet, implied):
+                target.add(derived)
+
+
+def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]:
+    """Yield each variable binding under which the rule's lhs holds.
+
+    Steps:
+      1. Bail out early if an lhs statement with no variables and no BNodes
+         is missing from workingSet.
+      2. Gather, per variable, the workingSet terms it could match and try
+         every combination of them.
+      3. Extend each combination with the outputs of inferred functions.
+      4. Drop combinations that leave an lhs variable unbound or that fail
+         verifyBinding.
+
+    Yields nothing when the rule cannot apply. For an lhs without variables
+    that passes the checks, a single empty binding is yielded.
+    """
+    varsToBind: Set[Variable] = set()
+    staticRuleStmts = []
+    for ruleStmt in lhs:
+        varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)]
+        varsToBind.update(varsInStmt)
+        if (not varsInStmt  # ok
+                and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
+           ):
+            staticRuleStmts.append(ruleStmt)
+
+    if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
+        log.debug('static shortcircuit')
+        return
+
+    # the total set of terms each variable could possibly match
+    candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet)
+
+    orderedVars, orderedValueSets = organize(candidateTermMatches)
+
+    log.debug(f'     {orderedVars=}')
+    log.debug(f'{orderedValueSets=}')
+
+    for perm in itertools.product(*orderedValueSets):
+        binding: Dict[Variable, Node] = dict(zip(orderedVars, perm))
+        log.debug(f'{binding=} but lets look for funcs')
+        for v, val in inferredFuncBindings(lhs, binding):  # loop this until it's done
+            log.debug(f'ifb tells us {v}={val}')
+            binding[v] = val
+
+        # withBinding skips any statement that still contains an unbound
+        # variable, so verifyBinding alone would accept a binding in which
+        # some lhs statement never matched anything. Require every lhs
+        # variable to be bound before verifying.
+        unbound = varsToBind.difference(binding)
+        if unbound:
+            log.debug(f'unbound vars {unbound}, skipping')
+            continue
+
+        if not verifyBinding(lhs, binding, workingSet):  # fix this
+            log.debug(f'verify culls')
+            continue
+        yield binding
+
+
+def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]:
+    """Yield (variable, value) pairs computed by inferred-function statements.
+
+    Only lhs statements whose predicate is in inferredFuncs and whose object
+    is a Variable are considered. A Variable subject is first replaced by its
+    value in bindingsBefore (KeyError if it has none); the object's value
+    comes from inferredFuncObject.
+    """
+    for subj, pred, obj in lhs:
+        if pred in inferredFuncs and isinstance(obj, Variable):
+            if isinstance(subj, Variable):
+                subj = bindingsBefore[subj]
+            yield obj, inferredFuncObject(subj, pred, lhs, bindingsBefore)
+
+
+def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]:
+    """Collect, per lhs variable, every workingSet term it could stand for.
+
+    Each lhs statement is compared position by position with each workingSet
+    statement. When all of the non-variable terms are equal, the terms found
+    at the variable positions are recorded as candidates for those variables.
+    A variable that never lines up with any statement gets no entry.
+    """
+    matchesByVar: Dict[Variable, Set[Node]] = {}
+
+    for ruleStmt in lhs:
+        for workingStmt in workingSet:
+            fromThisPair: Dict[Variable, Set[Node]] = {}
+            fixedTermsAgree = True
+            for ruleTerm, workingTerm in zip(ruleStmt, workingStmt):
+                if isinstance(ruleTerm, Variable):
+                    fromThisPair.setdefault(ruleTerm, set()).add(workingTerm)
+                elif ruleTerm != workingTerm:
+                    fixedTermsAgree = False
+                    break
+            if not fixedTermsAgree:
+                continue
+            for var, terms in fromThisPair.items():
+                matchesByVar.setdefault(var, set()).update(terms)
+    return matchesByVar
+
 
-    def _containsSetup(self, lhs, workingSet):
-        return all(st in workingSet for st in lhs)
+def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]:
+    """Yield the statements of rhs with their variables replaced from bindings.
+
+    A statement that uses a variable missing from bindings is skipped
+    entirely. Each result is a 3-tuple, so it is hashable and matches the
+    annotated return type (a list was yielded previously).
+    """
+    for stmt in rhs:
+        try:
+            bound = tuple(bindings[t] if isinstance(t, Variable) else t for t in stmt)
+        except KeyError:
+            # stmt is from another rule that we're not applying right now
+            continue
+        yield cast(Triple, bound)
+
+
+def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool:
+    """Return True when every bound lhs statement is satisfied.
+
+    A statement whose predicate is in filterFuncs must pass mathTest. Any
+    other statement must either be present in workingSet or have a predicate
+    in inferredFuncs. Only the statements that withBinding yields are checked.
+    """
+    for boundStmt in withBinding(lhs, binding):
+        log.debug(f'lhs verify {boundStmt}')
+        pred = boundStmt[1]
+        if pred in filterFuncs:
+            if mathTest(*boundStmt):
+                continue
+            return False
+        if boundStmt in workingSet or pred in inferredFuncs:
+            continue
+        log.debug(f'  ver culls here')
+        return False
+    return True
+
+
+# Predicates whose object value is computed by inferredFuncObject.
+inferredFuncs = {ROOM['asFarenheit'], MATH['sum']}
+
+# Predicates that verifyBinding evaluates as a boolean test via mathTest.
+filterFuncs = {MATH['greaterThan']}
+
+
+def inferredFuncObject(subj, pred, graph, bindings):
+    """Compute the object Literal for an inferred-function statement.
+
+    ROOM['asFarenheit']: subj's Python value, as a Decimal, * 9 / 5 + 32.
+    MATH['sum']: subj is read as the head of an RDF collection in graph;
+    variable members are looked up in bindings and the Python values of all
+    members are summed.
+
+    Raises NotImplementedError for any other predicate.
+    """
+    if pred == ROOM['asFarenheit']:
+        inputValue = Decimal(subj.toPython())
+        return Literal(inputValue * 9 / 5 + 32)
+
+    if pred == MATH['sum']:
+        # shouldn't be redoing this here
+        operands = [bindings[o] if isinstance(o, Variable) else o for o in Collection(graph, subj)]
+        log.debug(f' sum {list(operands)}')
+        return Literal(sum(op.toPython() for op in operands))
+
+    raise NotImplementedError(pred)
+
+
+def mathTest(subj, pred, obj):
+    """Evaluate a filter predicate on the Python values of subj and obj.
+
+    Only MATH['greaterThan'] is supported (subj > obj); any other predicate
+    raises NotImplementedError. Both values are converted before the
+    predicate is examined.
+    """
+    left, right = subj.toPython(), obj.toPython()
+    if pred != MATH['greaterThan']:
+        raise NotImplementedError(pred)
+    return left > right
+
+
+def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]:
+    """Turn the candidate map into two parallel, deterministically ordered lists.
+
+    Returns (variables sorted ascending, their candidate values); each list
+    of values is sorted by its members' str() form.
+    """
+    orderedVars: List[Variable] = sorted(candidateTermMatches)
+    orderedValueSets: List[List[Node]] = [
+        sorted(candidateTermMatches[var], key=str) for var in orderedVars
+    ]
+    return orderedVars, orderedValueSets
+
+
+def someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
+    """Return True if at least one of staticRuleStmts is absent from workingSet."""
+    return any(ruleStmt not in workingSet for ruleStmt in staticRuleStmts)
--- a/service/mqtt_to_rdf/inference_test.py	Sun Aug 29 23:59:09 2021 -0700
+++ b/service/mqtt_to_rdf/inference_test.py	Thu Sep 02 01:58:31 2021 -0700
@@ -1,3 +1,6 @@
+"""
+also see https://github.com/w3c/N3/tree/master/tests/N3Tests
+"""
 import unittest
 
 from rdflib import ConjunctiveGraph, Namespace, Graph
@@ -10,7 +13,11 @@
 
 def N3(txt: str):
+    """Parse txt as N3 into a new ConjunctiveGraph and return it.
+
+    The ':', 'room:' and 'math:' prefixes are declared ahead of txt so test
+    snippets can use them directly.
+    """
     g = ConjunctiveGraph()
-    prefix = "@prefix : <http://example.com/> .\n"
+    prefix = """
+@prefix : <http://example.com/> .
+@prefix room: <http://projects.bigasterisk.com/room/> .
+@prefix math: <http://www.w3.org/2000/10/swap/math#> .
+"""
     g.parse(StringInputSource((prefix + txt).encode('utf8')), format='n3')
     return g
 
@@ -68,25 +75,62 @@
         implied = inf.infer(N3(":a :b :c, :d ."))
         self.assertGraphEqual(implied, N3(":new :stmt :c, :d ."))
 
-    def testTwoRulesWithVars(self):
+    def testTwoRulesApplyIndependently(self):
+        """Each rule fires only when its own lhs statement is present."""
+        rules = """
+            { :a :b ?x . } => { :new :stmt ?x . } .
+            { :d :e ?y . } => { :new :stmt2 ?y . } .
+            """
+        inf = makeInferenceWithRules(rules)
+
+        firstOnly = inf.infer(N3(":a :b :c ."))
+        self.assertGraphEqual(firstOnly, N3("""
+            :new :stmt :c .
+            """))
+
+        bothRules = inf.infer(N3(":a :b :c . :d :e :f ."))
+        self.assertGraphEqual(bothRules, N3("""
+            :new :stmt :c .
+            :new :stmt2 :f .
+            """))
+
+    def testOneRuleActivatesAnother(self):
+        # The first rule's output (:new :stmt :c) matches the second rule's
+        # lhs pattern, so :new :stmt2 :new is expected in the result too.
         inf = makeInferenceWithRules("""
-        { :a :b ?x . } => { :new :stmt ?x } .
-        { ?y :stmt ?z . } => { :new :stmt2 ?z }
-        """)
+            { :a :b ?x . } => { :new :stmt ?x . } .
+            { ?y :stmt ?z . } => { :new :stmt2 ?y . } .
+            """)
         implied = inf.infer(N3(":a :b :c ."))
-        self.assertGraphEqual(implied, N3(":new :stmt :c; :stmt2 :new ."))
+        self.assertGraphEqual(implied, N3("""
+            :new :stmt :c .
+            :new :stmt2 :new .
+            """))
+
+    def testVarLinksTwoStatements(self):
+        """?x must take the same value in both lhs statements for the rule to fire."""
+        inf = makeInferenceWithRules("{ :a :b ?x . :d :e ?x } => { :new :stmt ?x } .")
+        cases = [
+            (":a :b :c  .", ""),
+            (":a :b :c . :d :e :f .", ""),
+            (":a :b :c . :d :e :c .", ":new :stmt :c ."),
+        ]
+        for given, expected in cases:
+            implied = inf.infer(N3(given))
+            self.assertGraphEqual(implied, N3(expected))
+
+    def testRuleMatchesStaticStatement(self):
+        """An lhs may mix a variable pattern with a fully fixed statement."""
+        rule = "{ :a :b ?x . :a :b :c . } => { :new :stmt ?x } ."
+        given = ":a :b :c  ."
+        expected = ":new :stmt :c ."
+        inf = makeInferenceWithRules(rule)
+        implied = inf.infer(N3(given))
+        self.assertGraphEqual(implied, N3(expected))
 
 
 class TestInferenceWithMathFunctions(WithGraphEqual):
+    # Rules whose lhs uses math: predicates (greaterThan, sum).
 
-    def test1(self):
+    def testBoolFilter(self):
+        # greaterThan is expected to be strict: 3 and 5 imply nothing, 6 does.
         inf = makeInferenceWithRules("{ :a :b ?x . ?x math:greaterThan 5 } => { :new :stmt ?x } .")
         self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(""))
         self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3(""))
-        self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt :a ."))
+        self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 ."))
+
+    def testStatementGeneratingRule(self):
+        # ?y is not matched against the input; it is expected to be computed
+        # as the sum of the list (?x 1), i.e. 3 + 1 = 4.
+        inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 ."))
+
 
 class TestInferenceWithCustomFunctions(WithGraphEqual):
+    # Rules whose lhs uses a project-specific (room:) function predicate.
 
     def testAsFarenheit(self):
-        inf = makeInferenceWithRules("{ :a :b ?x . ?x :asFarenheit ?f } => { :new :stmt ?f } .")
-        self.assertGraphEqual(inf.infer(N3(":a :b 0 .")), N3(":new :stmt -32 ."))
+        # expected value: 12 * 9 / 5 + 32 = 53.6
+        inf = makeInferenceWithRules("{ :a :b ?x . ?x room:asFarenheit ?f } => { :new :stmt ?f } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 12 .")), N3(":new :stmt 53.6 ."))