# HG changeset patch # User drewp@bigasterisk.com # Date 1630573111 25200 # Node ID 0757fafbfdabea6dd0ec36312fcd0a2952579dbd # Parent 9a3a18c494f995ff6a0a33ee6145382343ccf23e WIP inferencer - partial var and function support diff -r 9a3a18c494f9 -r 0757fafbfdab service/mqtt_to_rdf/inference.py --- a/service/mqtt_to_rdf/inference.py Sun Aug 29 23:59:09 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Thu Sep 02 01:58:31 2021 -0700 @@ -2,12 +2,16 @@ copied from reasoning 2021-08-29. probably same api. should be able to lib/ this out """ - +import itertools import logging -from typing import Dict, Tuple from dataclasses import dataclass +from decimal import Decimal +from typing import Dict, Iterator, List, Set, Tuple, cast +from urllib.request import OpenerDirector + from prometheus_client import Summary -from rdflib import Graph, Namespace +from rdflib import BNode, Graph, Literal, Namespace +from rdflib.collection import Collection from rdflib.graph import ConjunctiveGraph from rdflib.term import Node, Variable @@ -57,7 +61,7 @@ while delta > 0 and bailout_iterations > 0: bailout_iterations -= 1 delta = -len(implied) - self._iterateRules(workingSet, implied) + self._iterateAllRules(workingSet, implied) delta += len(implied) log.info(f' this inference round added {delta} more implied stmts') log.info(f'{len(implied)} stmts implied:') @@ -65,19 +69,161 @@ log.info(f' {st}') return implied - def _iterateRules(self, workingSet, implied): + def _iterateAllRules(self, workingSet, implied): for r in self.rules: if r[1] == LOG['implies']: - self._applyRule(r[0], r[2], workingSet, implied) + applyRule(r[0], r[2], workingSet, implied) else: log.info(f' {r} not a rule?') - def _applyRule(self, lhs, rhs, workingSet, implied): - containsSetup = self._containsSetup(lhs, workingSet) - if containsSetup: - for st in rhs: - workingSet.add(st) - implied.add(st) + +def applyRule(lhs: Graph, rhs: Graph, workingSet, implied): + for bindings in findCandidateBindings(lhs, workingSet): + log.debug(f' - rule gave {bindings=}') + for newStmt in withBinding(rhs, bindings): + workingSet.add(newStmt) + implied.add(newStmt) + + +def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]: + varsToBind: Set[Variable] = set() + staticRuleStmts = [] + for ruleStmt in lhs: + varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)] + varsToBind.update(varsInStmt) + if (not varsInStmt # ok + and not any(isinstance(t, BNode) for t in ruleStmt) # approx + ): + staticRuleStmts.append(ruleStmt) + + if someStaticStmtDoesntMatch(staticRuleStmts, workingSet): + log.debug('static shortcircuit') + return + + # the total set of terms each variable could possibly match + candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet) + + orderedVars, orderedValueSets = organize(candidateTermMatches) + + log.debug(f' {orderedVars=}') + log.debug(f'{orderedValueSets=}') + + for perm in itertools.product(*orderedValueSets): + binding: Dict[Variable, Node] = dict(zip(orderedVars, perm)) + log.debug(f'{binding=} but lets look for funcs') + for v, val in inferredFuncBindings(lhs, binding): # loop this until it's done + log.debug(f'ifb tells us {v}={val}') + binding[v] = val + if not verifyBinding(lhs, binding, workingSet): # fix this + log.debug(f'verify culls') + continue + yield binding + + +def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]: + for stmt in lhs: + if stmt[1] not in inferredFuncs: + continue + if not isinstance(stmt[2], Variable): + continue + + x = stmt[0] + if isinstance(x, Variable): + x = bindingsBefore[x] + yield stmt[2], inferredFuncObject(x, stmt[1], lhs, bindingsBefore) + + +def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]: + candidateTermMatches: Dict[Variable, Set[Node]] = {} + + for r in lhs: + for w in workingSet: + bindingsFromStatement: Dict[Variable, Set[Node]] = {} + for rterm, wterm in zip(r, w): + if isinstance(rterm, Variable): + bindingsFromStatement.setdefault(rterm, set()).add(wterm) + elif rterm != wterm: + break + else: + for v, vals in bindingsFromStatement.items(): + candidateTermMatches.setdefault(v, set()).update(vals) + return candidateTermMatches + - def _containsSetup(self, lhs, workingSet): - return all(st in workingSet for st in lhs) +def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]: + for stmt in rhs: + stmt = list(stmt) + for i, t in enumerate(stmt): + if isinstance(t, Variable): + try: + stmt[i] = bindings[t] + except KeyError: + # stmt is from another rule that we're not applying right now + break + else: + yield cast(Triple, stmt) + + +def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool: + for stmt in withBinding(lhs, binding): + log.debug(f'lhs verify {stmt}') + if stmt[1] in filterFuncs: + if not mathTest(*stmt): + return False + elif stmt not in workingSet and stmt[1] not in inferredFuncs: + log.debug(f' ver culls here') + return False + return True + + +inferredFuncs = { + ROOM['asFarenheit'], + MATH['sum'], +} +filterFuncs = { + MATH['greaterThan'], +} + + +def inferredFuncObject(subj, pred, graph, bindings): + if pred == ROOM['asFarenheit']: + return Literal(Decimal(subj.toPython()) * 9 / 5 + 32) + elif pred == MATH['sum']: + operands = Collection(graph, subj) + # shouldn't be redoing this here + operands = [bindings[o] if isinstance(o, Variable) else o for o in operands] + log.debug(f' sum {list(operands)}') + return Literal(sum(op.toPython() for op in operands)) + + else: + raise NotImplementedError(pred) + + +def mathTest(subj, pred, obj): + x = subj.toPython() + y = obj.toPython() + if pred == MATH['greaterThan']: + return x > y + else: + raise NotImplementedError(pred) + + +def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]: + items = list(candidateTermMatches.items()) + items.sort() + orderedVars: List[Variable] = [] + orderedValueSets: List[List[Node]] = [] + for v, vals in items: + orderedVars.append(v) + orderedValues: List[Node] = list(vals) + orderedValues.sort(key=str) + orderedValueSets.append(orderedValues) + + return orderedVars, orderedValueSets + + +def someStaticStmtDoesntMatch(staticRuleStmts, workingSet): + for ruleStmt in staticRuleStmts: + if ruleStmt not in workingSet: + return True + return False diff -r 9a3a18c494f9 -r 0757fafbfdab service/mqtt_to_rdf/inference_test.py --- a/service/mqtt_to_rdf/inference_test.py Sun Aug 29 23:59:09 2021 -0700 +++ b/service/mqtt_to_rdf/inference_test.py Thu Sep 02 01:58:31 2021 -0700 @@ -1,3 +1,6 @@ +""" +also see https://github.com/w3c/N3/tree/master/tests/N3Tests +""" import unittest from rdflib import ConjunctiveGraph, Namespace, Graph @@ -10,7 +13,11 @@ def N3(txt: str): g = ConjunctiveGraph() - prefix = "@prefix : .\n" + prefix = """ +@prefix : . +@prefix room: . +@prefix math: . +""" g.parse(StringInputSource((prefix + txt).encode('utf8')), format='n3') return g @@ -68,25 +75,62 @@ implied = inf.infer(N3(":a :b :c, :d .")) self.assertGraphEqual(implied, N3(":new :stmt :c, :d .")) - def testTwoRulesWithVars(self): + def testTwoRulesApplyIndependently(self): + inf = makeInferenceWithRules(""" + { :a :b ?x . } => { :new :stmt ?x . } . + { :d :e ?y . } => { :new :stmt2 ?y . } . + """) + implied = inf.infer(N3(":a :b :c .")) + self.assertGraphEqual(implied, N3(""" + :new :stmt :c . + """)) + implied = inf.infer(N3(":a :b :c . :d :e :f .")) + self.assertGraphEqual(implied, N3(""" + :new :stmt :c . + :new :stmt2 :f . + """)) + + def testOneRuleActivatesAnother(self): inf = makeInferenceWithRules(""" - { :a :b ?x . } => { :new :stmt ?x } . - { ?y :stmt ?z . } => { :new :stmt2 ?z } - """) + { :a :b ?x . } => { :new :stmt ?x . } . + { ?y :stmt ?z . } => { :new :stmt2 ?y . } . + """) implied = inf.infer(N3(":a :b :c .")) - self.assertGraphEqual(implied, N3(":new :stmt :c; :stmt2 :new .")) + self.assertGraphEqual(implied, N3(""" + :new :stmt :c . + :new :stmt2 :new . + """)) + + def testVarLinksTwoStatements(self): + inf = makeInferenceWithRules("{ :a :b ?x . :d :e ?x } => { :new :stmt ?x } .") + implied = inf.infer(N3(":a :b :c .")) + self.assertGraphEqual(implied, N3("")) + implied = inf.infer(N3(":a :b :c . :d :e :f .")) + self.assertGraphEqual(implied, N3("")) + implied = inf.infer(N3(":a :b :c . :d :e :c .")) + self.assertGraphEqual(implied, N3(":new :stmt :c .")) + + def testRuleMatchesStaticStatement(self): + inf = makeInferenceWithRules("{ :a :b ?x . :a :b :c . } => { :new :stmt ?x } .") + implied = inf.infer(N3(":a :b :c .")) + self.assertGraphEqual(implied, N3(":new :stmt :c .")) class TestInferenceWithMathFunctions(WithGraphEqual): - def test1(self): + def testBoolFilter(self): inf = makeInferenceWithRules("{ :a :b ?x . ?x math:greaterThan 5 } => { :new :stmt ?x } .") self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3("")) self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3("")) - self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt :a .")) + self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 .")) + + def testStatementGeneratingRule(self): + inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 .")) + class TestInferenceWithCustomFunctions(WithGraphEqual): def testAsFarenheit(self): - inf = makeInferenceWithRules("{ :a :b ?x . ?x :asFarenheit ?f } => { :new :stmt ?f } .") - self.assertGraphEqual(inf.infer(N3(":a :b 0 .")), N3(":new :stmt -32 .")) + inf = makeInferenceWithRules("{ :a :b ?x . ?x room:asFarenheit ?f } => { :new :stmt ?f } .") + self.assertGraphEqual(inf.infer(N3(":a :b 12 .")), N3(":new :stmt 53.6 ."))