diff service/mqtt_to_rdf/inference.py @ 1588:0757fafbfdab

WIP inferencer - partial var and function support
author drewp@bigasterisk.com
date Thu, 02 Sep 2021 01:58:31 -0700
parents 9a3a18c494f9
children 5c1055be3c36
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sun Aug 29 23:59:09 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Thu Sep 02 01:58:31 2021 -0700
@@ -2,12 +2,16 @@
 copied from reasoning 2021-08-29. probably same api. should
 be able to lib/ this out
 """
-
+import itertools
 import logging
-from typing import Dict, Tuple
 from dataclasses import dataclass
+from decimal import Decimal
+from typing import Dict, Iterator, List, Set, Tuple, cast
+from urllib.request import OpenerDirector
+
 from prometheus_client import Summary
-from rdflib import Graph, Namespace
+from rdflib import BNode, Graph, Literal, Namespace
+from rdflib.collection import Collection
 from rdflib.graph import ConjunctiveGraph
 from rdflib.term import Node, Variable
 
@@ -57,7 +61,7 @@
         while delta > 0 and bailout_iterations > 0:
             bailout_iterations -= 1
             delta = -len(implied)
-            self._iterateRules(workingSet, implied)
+            self._iterateAllRules(workingSet, implied)
             delta += len(implied)
             log.info(f'  this inference round added {delta} more implied stmts')
         log.info(f'{len(implied)} stmts implied:')
@@ -65,19 +69,161 @@
             log.info(f'  {st}')
         return implied
 
-    def _iterateRules(self, workingSet, implied):
+    def _iterateAllRules(self, workingSet, implied):  # one pass over self.rules
         for r in self.rules:
             if r[1] == LOG['implies']:
-                self._applyRule(r[0], r[2], workingSet, implied)
+                applyRule(r[0], r[2], workingSet, implied)  # r[0] is passed as the rule's lhs, r[2] as its rhs
             else:
                 log.info(f'  {r} not a rule?')
 
-    def _applyRule(self, lhs, rhs, workingSet, implied):
-        containsSetup = self._containsSetup(lhs, workingSet)
-        if containsSetup:
-            for st in rhs:
-                workingSet.add(st)
-                implied.add(st)
+
+def applyRule(lhs: Graph, rhs: Graph, workingSet, implied):  # mutates workingSet and implied; returns None
+    for bindings in findCandidateBindings(lhs, workingSet):  # one dict per binding that passed verifyBinding
+        log.debug(f' - rule gave {bindings=}')
+        for newStmt in withBinding(rhs, bindings):  # rhs statements with their variables substituted
+            workingSet.add(newStmt)  # NOTE(review): added while the findCandidateBindings generator is still running
+            implied.add(newStmt)
+
+
+def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]:
+    """Yield each Variable->Node binding for lhs that passes verifyBinding against workingSet."""
+    staticRuleStmts = []  # lhs statements that contain no Variable and no BNode
+    for ruleStmt in lhs:
+        varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)]
+        # a statement with no variables or bnodes can be looked up in workingSet exactly as written
+        if (not varsInStmt  # ok
+                and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
+           ):
+            staticRuleStmts.append(ruleStmt)
+
+    if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):  # a fixed statement is absent: yield nothing
+        log.debug('static shortcircuit')
+        return
+
+    # the total set of terms each variable could possibly match
+    candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet)
+
+    orderedVars, orderedValueSets = organize(candidateTermMatches)
+
+    log.debug(f'     {orderedVars=}')
+    log.debug(f'{orderedValueSets=}')
+
+    for perm in itertools.product(*orderedValueSets):  # with no variables this yields one empty tuple
+        binding: Dict[Variable, Node] = dict(zip(orderedVars, perm))
+        log.debug(f'{binding=} but lets look for funcs')
+        for v, val in inferredFuncBindings(lhs, binding):  # loop this until it's done
+            log.debug(f'ifb tells us {v}={val}')
+            binding[v] = val
+        if not verifyBinding(lhs, binding, workingSet):  # fix this
+            log.debug('verify culls')
+            continue
+        yield binding
+
+
+def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]:  # yields (variable, value)
+    for stmt in lhs:
+        if stmt[1] not in inferredFuncs:  # only predicates that inferredFuncObject computes
+            continue
+        if not isinstance(stmt[2], Variable):  # only when the object is a variable to fill in
+            continue
+
+        x = stmt[0]
+        if isinstance(x, Variable):
+            x = bindingsBefore[x]  # NOTE(review): KeyError if the subject variable is not in bindingsBefore
+        yield stmt[2], inferredFuncObject(x, stmt[1], lhs, bindingsBefore)
+
+
+def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]:
+    candidateTermMatches: Dict[Variable, Set[Node]] = {}  # variable -> terms found at its position
+
+    for r in lhs:
+        for w in workingSet:  # every lhs statement is compared with every workingSet statement
+            bindingsFromStatement: Dict[Variable, Set[Node]] = {}
+            for rterm, wterm in zip(r, w):
+                if isinstance(rterm, Variable):
+                    bindingsFromStatement.setdefault(rterm, set()).add(wterm)
+                elif rterm != wterm:  # a non-variable term differs, so w does not match r
+                    break
+            else:  # no break: every non-variable term of r equals the term at that position in w
+                for v, vals in bindingsFromStatement.items():
+                    candidateTermMatches.setdefault(v, set()).update(vals)
+    return candidateTermMatches
+
 
-    def _containsSetup(self, lhs, workingSet):
-        return all(st in workingSet for st in lhs)
+def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]:
+    for stmt in rhs:  # yields each statement of rhs with its variables replaced from bindings
+        stmt = list(stmt)  # mutable scratch copy so terms can be replaced by index
+        for i, t in enumerate(stmt):
+            if isinstance(t, Variable):
+                try:
+                    stmt[i] = bindings[t]
+                except KeyError:
+                    # stmt is from another rule that we're not applying right now
+                    break
+        else:  # no break: every variable in this statement had a binding
+            yield cast(Triple, tuple(stmt))  # tuple, not the scratch list: cast() converts nothing at runtime
+
+
+def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool:
+    for stmt in withBinding(lhs, binding):  # only lhs statements whose variables are all in binding
+        log.debug(f'lhs verify {stmt}')
+        if stmt[1] in filterFuncs:  # comparison predicate: evaluated by mathTest, not looked up
+            if not mathTest(*stmt):
+                return False
+        elif stmt not in workingSet and stmt[1] not in inferredFuncs:  # inferredFuncs predicates skip the lookup
+            log.debug('  ver culls here')
+            return False
+    return True  # no statement was rejected (also the result when nothing was checked)
+
+
+inferredFuncs = {  # predicates whose object is computed by inferredFuncObject
+    ROOM['asFarenheit'],
+    MATH['sum'],
+}
+filterFuncs = {  # predicates that verifyBinding evaluates with mathTest
+    MATH['greaterThan'],
+}
+
+
+def inferredFuncObject(subj, pred, graph, bindings):  # returns the Literal computed for pred applied to subj
+    if pred == ROOM['asFarenheit']:
+        return Literal(Decimal(subj.toPython()) * 9 / 5 + 32)  # presumably subj holds a Celsius value; confirm
+    elif pred == MATH['sum']:
+        operands = Collection(graph, subj)  # reads subj as the head of an RDF list stored in graph
+        # shouldn't be redoing this here
+        operands = [bindings[o] if isinstance(o, Variable) else o for o in operands]
+        log.debug(f' sum {list(operands)}')
+        return Literal(sum(op.toPython() for op in operands))
+
+    else:
+        raise NotImplementedError(pred)  # any predicate other than the two handled above
+
+
+def mathTest(subj, pred, obj):  # evaluates a comparison predicate on the Python values of subj and obj
+    x = subj.toPython()
+    y = obj.toPython()
+    if pred == MATH['greaterThan']:
+        return x > y
+    else:
+        raise NotImplementedError(pred)  # only MATH['greaterThan'] is implemented
+
+
+def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]:
+    items = list(candidateTermMatches.items())  # (variable, candidate set) pairs
+    items.sort()  # orders by variable; dict keys are distinct, so the sets are never compared
+    orderedVars: List[Variable] = []
+    orderedValueSets: List[List[Node]] = []
+    for v, vals in items:
+        orderedVars.append(v)
+        orderedValues: List[Node] = list(vals)
+        orderedValues.sort(key=str)  # ordered by the str() of each term
+        orderedValueSets.append(orderedValues)
+
+    return orderedVars, orderedValueSets  # parallel lists: orderedValueSets[i] belongs to orderedVars[i]
+
+
+def someStaticStmtDoesntMatch(staticRuleStmts, workingSet):  # True if any given statement is not in workingSet
+    for ruleStmt in staticRuleStmts:
+        if ruleStmt not in workingSet:
+            return True
+    return False  # also the result when staticRuleStmts is empty