# HG changeset patch
# User drewp@bigasterisk.com
# Date 1630573111 25200
# Node ID 0757fafbfdabea6dd0ec36312fcd0a2952579dbd
# Parent 9a3a18c494f995ff6a0a33ee6145382343ccf23e
WIP inferencer - partial var and function support
diff -r 9a3a18c494f9 -r 0757fafbfdab service/mqtt_to_rdf/inference.py
--- a/service/mqtt_to_rdf/inference.py Sun Aug 29 23:59:09 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py Thu Sep 02 01:58:31 2021 -0700
@@ -2,12 +2,16 @@
copied from reasoning 2021-08-29. probably same api. should
be able to lib/ this out
"""
-
+import itertools
import logging
-from typing import Dict, Tuple
from dataclasses import dataclass
+from decimal import Decimal
+from typing import Dict, Iterator, List, Set, Tuple, cast
+from urllib.request import OpenerDirector
+
from prometheus_client import Summary
-from rdflib import Graph, Namespace
+from rdflib import BNode, Graph, Literal, Namespace
+from rdflib.collection import Collection
from rdflib.graph import ConjunctiveGraph
from rdflib.term import Node, Variable
@@ -57,7 +61,7 @@
while delta > 0 and bailout_iterations > 0:
bailout_iterations -= 1
delta = -len(implied)
- self._iterateRules(workingSet, implied)
+ self._iterateAllRules(workingSet, implied)
delta += len(implied)
log.info(f' this inference round added {delta} more implied stmts')
log.info(f'{len(implied)} stmts implied:')
@@ -65,19 +69,161 @@
log.info(f' {st}')
return implied
- def _iterateRules(self, workingSet, implied):
+ def _iterateAllRules(self, workingSet, implied):
for r in self.rules:
if r[1] == LOG['implies']:
- self._applyRule(r[0], r[2], workingSet, implied)
+ applyRule(r[0], r[2], workingSet, implied)
else:
log.info(f' {r} not a rule?')
- def _applyRule(self, lhs, rhs, workingSet, implied):
- containsSetup = self._containsSetup(lhs, workingSet)
- if containsSetup:
- for st in rhs:
- workingSet.add(st)
- implied.add(st)
+
+def applyRule(lhs: Graph, rhs: Graph, workingSet, implied):
+    """Apply one `lhs => rhs` rule against the working set.
+
+    For every binding that findCandidateBindings reports for `lhs`, the
+    statements of `rhs` are instantiated with that binding and added to both
+    `workingSet` and `implied`. Nothing is returned; both graphs are mutated.
+    """
+    for bindings in findCandidateBindings(lhs, workingSet):
+        log.debug(f' - rule gave {bindings=}')
+        for produced in withBinding(rhs, bindings):
+            # The working set grows too, so the statement is visible to
+            # whatever is matched against it afterwards.
+            workingSet.add(produced)
+            implied.add(produced)
+
+
+def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]:
+    """Yield every variable binding under which `lhs` holds in `workingSet`.
+
+    Steps:
+      1. Rule statements containing neither variables nor bnodes must be
+         present verbatim in `workingSet`; if one is missing, nothing is
+         yielded.
+      2. Each variable gets the set of terms it could match
+         (findCandidateTermMatches), and the cross product of those sets is
+         enumerated in the deterministic order chosen by organize().
+      3. Each candidate is extended with values computed by inferred
+         functions, then kept only if verifyBinding accepts it.
+
+    Each yielded dict maps Variable -> bound term, including any
+    function-computed variables.
+    """
+    # Statements free of variables can be checked up front. Excluding every
+    # statement that mentions a BNode is only an approximation.
+    staticRuleStmts = [
+        ruleStmt for ruleStmt in lhs
+        if not any(isinstance(term, (Variable, BNode)) for term in ruleStmt)
+    ]
+    if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
+        log.debug('static shortcircuit')
+        return
+
+    # the total set of terms each variable could possibly match
+    candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet)
+
+    orderedVars, orderedValueSets = organize(candidateTermMatches)
+
+    log.debug(f' {orderedVars=}')
+    log.debug(f'{orderedValueSets=}')
+
+    for perm in itertools.product(*orderedValueSets):
+        binding: Dict[Variable, Node] = dict(zip(orderedVars, perm))
+        log.debug(f'{binding=} but lets look for funcs')
+        # TODO: loop this until it's done (a single pass cannot resolve a
+        # function whose input is itself produced by a later function).
+        for v, val in inferredFuncBindings(lhs, binding):
+            log.debug(f'ifb tells us {v}={val}')
+            binding[v] = val
+        # TODO: fix this (verification step is marked unfinished upstream)
+        if verifyBinding(lhs, binding, workingSet):
+            yield binding
+        else:
+            log.debug('verify culls')
+
+
+def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]:
+    """Yield (variable, value) pairs computed by inferred-function statements.
+
+    Only `lhs` statements whose predicate is in `inferredFuncs` and whose
+    object is a Variable are considered. The value comes from
+    inferredFuncObject, called with the statement's subject (substituted from
+    `bindingsBefore` when the subject is a Variable).
+    """
+    for stmt in lhs:
+        subj, pred, obj = stmt[0], stmt[1], stmt[2]
+        if pred in inferredFuncs and isinstance(obj, Variable):
+            # NOTE: a Variable subject missing from bindingsBefore raises
+            # KeyError here.
+            funcInput = bindingsBefore[subj] if isinstance(subj, Variable) else subj
+            yield obj, inferredFuncObject(funcInput, pred, lhs, bindingsBefore)
+
+
+def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]:
+    """Collect, per variable, every term it could stand for.
+
+    Each rule statement in `lhs` is compared position-by-position with each
+    statement in `workingSet`. When all non-variable positions are equal, the
+    working-set terms sitting at the variable positions are recorded as
+    candidates for those variables. Variables that never line up with a
+    working-set statement get no entry.
+    """
+    matches: Dict[Variable, Set[Node]] = {}
+
+    for rulePattern in lhs:
+        for fact in workingSet:
+            pairs = []
+            compatible = True
+            for ruleTerm, factTerm in zip(rulePattern, fact):
+                if isinstance(ruleTerm, Variable):
+                    pairs.append((ruleTerm, factTerm))
+                elif ruleTerm != factTerm:
+                    # a fixed term disagrees; this fact can't match the pattern
+                    compatible = False
+                    break
+            if not compatible:
+                continue
+            for var, term in pairs:
+                matches.setdefault(var, set()).add(term)
+    return matches
+
- def _containsSetup(self, lhs, workingSet):
- return all(st in workingSet for st in lhs)
+def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]:
+    """Yield each statement of `rhs` with its variables replaced per `bindings`.
+
+    A statement that mentions any variable missing from `bindings` is skipped
+    entirely. Statements are yielded as immutable tuples (previously a list
+    was yielded behind a `cast`), so they are hashable and match the declared
+    return type.
+    """
+    for stmt in rhs:
+        try:
+            bound = tuple(bindings[t] if isinstance(t, Variable) else t for t in stmt)
+        except KeyError:
+            # stmt is from another rule that we're not applying right now
+            continue
+        yield cast(Triple, bound)
+
+
+def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool:
+    """Return True if `lhs`, instantiated with `binding`, holds.
+
+    Each instantiated statement is checked by predicate:
+      - predicates in `filterFuncs` are evaluated with mathTest;
+      - predicates in `inferredFuncs` are accepted without lookup;
+      - anything else must be present in `workingSet`.
+    Statements that withBinding cannot fully instantiate are not checked.
+    """
+    for stmt in withBinding(lhs, binding):
+        log.debug(f'lhs verify {stmt}')
+        pred = stmt[1]
+        if pred in filterFuncs:
+            if mathTest(*stmt):
+                continue
+            return False
+        if pred in inferredFuncs:
+            # computed statements never need to exist in the working set
+            continue
+        if stmt not in workingSet:
+            log.debug(' ver culls here')
+            return False
+    return True
+
+
+# Predicates whose object value is computed from the subject
+# (see inferredFuncObject) instead of being matched in the working set.
+inferredFuncs = {ROOM['asFarenheit'], MATH['sum']}
+
+# Predicates evaluated as a boolean test on subject and object (see mathTest).
+filterFuncs = {MATH['greaterThan']}
+
+
+def inferredFuncObject(subj, pred, graph, bindings):
+    """Compute the object value of an inferred-function statement.
+
+    - ROOM['asFarenheit']: returns Literal(subj * 9 / 5 + 32), computed in
+      Decimal arithmetic.
+    - MATH['sum']: treats `subj` as the head of an RDF collection in `graph`,
+      substitutes any Variable members from `bindings`, and returns a Literal
+      of the sum of the members' Python values.
+
+    Raises NotImplementedError for any other predicate.
+    """
+    if pred == ROOM['asFarenheit']:
+        return Literal(Decimal(subj.toPython()) * 9 / 5 + 32)
+
+    if pred == MATH['sum']:
+        # shouldn't be redoing this here
+        operands = []
+        for member in Collection(graph, subj):
+            operands.append(bindings[member] if isinstance(member, Variable) else member)
+        log.debug(f' sum {list(operands)}')
+        return Literal(sum(op.toPython() for op in operands))
+
+    raise NotImplementedError(pred)
+
+
+def mathTest(subj, pred, obj):
+    """Evaluate a filter predicate on the Python values of `subj` and `obj`.
+
+    Only MATH['greaterThan'] is supported (returns subj > obj); any other
+    predicate raises NotImplementedError.
+    """
+    # Both operands are converted before the predicate is inspected.
+    x, y = subj.toPython(), obj.toPython()
+    if pred != MATH['greaterThan']:
+        raise NotImplementedError(pred)
+    return x > y
+
+
+def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]:
+    """Turn the candidate map into two parallel, deterministically ordered lists.
+
+    Returns (variables, valueLists): the variables in sorted order and, at the
+    same index, that variable's candidate terms sorted by their string form.
+    """
+    orderedVars: List[Variable] = sorted(candidateTermMatches)
+    orderedValueSets: List[List[Node]] = [
+        sorted(candidateTermMatches[v], key=str) for v in orderedVars
+    ]
+    return orderedVars, orderedValueSets
+
+
+def someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
+    """Return True if at least one of `staticRuleStmts` is absent from `workingSet`."""
+    return any(ruleStmt not in workingSet for ruleStmt in staticRuleStmts)
diff -r 9a3a18c494f9 -r 0757fafbfdab service/mqtt_to_rdf/inference_test.py
--- a/service/mqtt_to_rdf/inference_test.py Sun Aug 29 23:59:09 2021 -0700
+++ b/service/mqtt_to_rdf/inference_test.py Thu Sep 02 01:58:31 2021 -0700
@@ -1,3 +1,6 @@
+"""
+also see https://github.com/w3c/N3/tree/master/tests/N3Tests
+"""
import unittest
from rdflib import ConjunctiveGraph, Namespace, Graph
@@ -10,7 +13,11 @@
def N3(txt: str):
g = ConjunctiveGraph()
- prefix = "@prefix : .\n"
+ prefix = """
+@prefix : .
+@prefix room: .
+@prefix math: .
+"""
g.parse(StringInputSource((prefix + txt).encode('utf8')), format='n3')
return g
@@ -68,25 +75,62 @@
implied = inf.infer(N3(":a :b :c, :d ."))
self.assertGraphEqual(implied, N3(":new :stmt :c, :d ."))
- def testTwoRulesWithVars(self):
+ def testTwoRulesApplyIndependently(self):
+ inf = makeInferenceWithRules("""
+ { :a :b ?x . } => { :new :stmt ?x . } .
+ { :d :e ?y . } => { :new :stmt2 ?y . } .
+ """)
+ implied = inf.infer(N3(":a :b :c ."))
+ self.assertGraphEqual(implied, N3("""
+ :new :stmt :c .
+ """))
+ implied = inf.infer(N3(":a :b :c . :d :e :f ."))
+ self.assertGraphEqual(implied, N3("""
+ :new :stmt :c .
+ :new :stmt2 :f .
+ """))
+
+ def testOneRuleActivatesAnother(self):
inf = makeInferenceWithRules("""
- { :a :b ?x . } => { :new :stmt ?x } .
- { ?y :stmt ?z . } => { :new :stmt2 ?z }
- """)
+ { :a :b ?x . } => { :new :stmt ?x . } .
+ { ?y :stmt ?z . } => { :new :stmt2 ?y . } .
+ """)
implied = inf.infer(N3(":a :b :c ."))
- self.assertGraphEqual(implied, N3(":new :stmt :c; :stmt2 :new ."))
+ self.assertGraphEqual(implied, N3("""
+ :new :stmt :c .
+ :new :stmt2 :new .
+ """))
+
+ def testVarLinksTwoStatements(self):
+ inf = makeInferenceWithRules("{ :a :b ?x . :d :e ?x } => { :new :stmt ?x } .")
+ implied = inf.infer(N3(":a :b :c ."))
+ self.assertGraphEqual(implied, N3(""))
+ implied = inf.infer(N3(":a :b :c . :d :e :f ."))
+ self.assertGraphEqual(implied, N3(""))
+ implied = inf.infer(N3(":a :b :c . :d :e :c ."))
+ self.assertGraphEqual(implied, N3(":new :stmt :c ."))
+
+ def testRuleMatchesStaticStatement(self):
+ inf = makeInferenceWithRules("{ :a :b ?x . :a :b :c . } => { :new :stmt ?x } .")
+ implied = inf.infer(N3(":a :b :c ."))
+ self.assertGraphEqual(implied, N3(":new :stmt :c ."))
class TestInferenceWithMathFunctions(WithGraphEqual):
- def test1(self):
+ def testBoolFilter(self):
inf = makeInferenceWithRules("{ :a :b ?x . ?x math:greaterThan 5 } => { :new :stmt ?x } .")
self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(""))
self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3(""))
- self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt :a ."))
+ self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 ."))
+
+ def testStatementGeneratingRule(self):
+ inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
+ self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 ."))
+
class TestInferenceWithCustomFunctions(WithGraphEqual):
def testAsFarenheit(self):
- inf = makeInferenceWithRules("{ :a :b ?x . ?x :asFarenheit ?f } => { :new :stmt ?f } .")
- self.assertGraphEqual(inf.infer(N3(":a :b 0 .")), N3(":new :stmt -32 ."))
+ inf = makeInferenceWithRules("{ :a :b ?x . ?x room:asFarenheit ?f } => { :new :stmt ?f } .")
+ self.assertGraphEqual(inf.infer(N3(":a :b 12 .")), N3(":new :stmt 53.6 ."))