diff service/mqtt_to_rdf/inference.py @ 1634:ba59cfc3c747

Hack math:sum support into inference.py. The test suite passes, except for some slow-performing tests.
author drewp@bigasterisk.com
date Sun, 12 Sep 2021 23:48:43 -0700
parents 6107603ed455
children 22d481f0a924
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sun Sep 12 21:48:36 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sun Sep 12 23:48:43 2021 -0700
@@ -7,16 +7,16 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast
+from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast)
 
 from prometheus_client import Histogram, Summary
-from rdflib import BNode, Graph, Namespace
+from rdflib import RDF, BNode, Graph, Namespace
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
 from rdflib.term import Literal, Node, Variable
 
 from candidate_binding import CandidateBinding
-from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
-from lhs_evaluation import Decimal, Evaluation, numericNode
+from inference_types import (BindableTerm, BindingUnknown, EvaluationFailed, ReadOnlyWorkingSet, Triple)
+from lhs_evaluation import Decimal, Evaluation, numericNode, parseList
 
 log = logging.getLogger('infer')
 INDENT = '    '
@@ -58,6 +58,7 @@
     lhsStmt: Triple
     prev: Optional['StmtLooper']
     workingSet: ReadOnlyWorkingSet
+    parent: 'Lhs'  # just for lhs.graph, really
 
     def __repr__(self):
         return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'
@@ -98,9 +99,29 @@
             augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches,
                                                                         returnBoundStatementsOnly=False))
 
-        log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}')
+        log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}')
+
+        if self._advanceWithPlainMatches(augmentedWorkingSet):
+            return
+
+        if self._advanceWithBoolRules():
+            return
+
+        curBind = self.prev.currentBinding() if self.prev else CandidateBinding({})
+        [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False)
 
-        log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
+        fullWorkingSet = self.workingSet + self.parent.graph
+        boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False))
+        log.debug(f'{fullWorkingSet.__len__()=} {len(boundFullWorkingSet)=}')
+
+        if self._advanceWithFunctions(augmentedWorkingSet, boundFullWorkingSet, lhsStmtBound):
+            return
+
+        log.debug(f'{INDENT*6} {self} is past end')
+        self._pastEnd = True
+
+    def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool:
+        log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
         for s in augmentedWorkingSet:
             log.debug(f'{INDENT*7} {s}')
 
@@ -111,19 +132,38 @@
                 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
                 continue
 
-            log.debug(f'{INDENT*6} {outBinding=} {self._seenBindings=}')
+            log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
             if outBinding.binding not in self._seenBindings:
                 self._seenBindings.append(outBinding.binding.copy())
                 self._current = outBinding
                 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}')
-                return
-            log.debug(f'yes we saw')
+                return True
+        return False
 
-        log.debug(f'{INDENT*6} {self} mines rules')
+    def _advanceWithBoolRules(self) -> bool:  # handles math:greaterThan; True means self._current was set to a new binding
+        log.debug(f'{INDENT*7} {self} mines bool rules')
+        if self.lhsStmt[1] == MATH['greaterThan']:  # only this predicate is handled here
+            operands = [self.lhsStmt[0], self.lhsStmt[2]]  # subject and object of the lhs statement
+            try:
+                boundOperands = self._boundOperands(operands)
+            except BindingUnknown:  # an operand is a Variable/BNode with no previous binding
+                return False
+            if numericNode(boundOperands[0]) > numericNode(boundOperands[1]):  # subject > object, compared via numericNode
+                bindingDict: Dict[BindableTerm,
+                                  Node] = self._prevBindings().copy()  # no new values; just allow matching to keep going
+                if bindingDict not in self._seenBindings:  # produce each distinct binding at most once
+                    self._seenBindings.append(bindingDict)
+                    self._current = CandidateBinding(bindingDict)
+                    log.debug(f'{INDENT*7} new binding from {self} -> {bindingDict}')
+                    return True
+        return False  # other predicate, comparison was false, or this binding was already produced
+
+    def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool:
+        log.debug(f'{INDENT*7} {self} mines rules')
 
         if self.lhsStmt[1] == ROOM['asFarenheit']:
             pb: Dict[BindableTerm, Node] = self._prevBindings()
-            log.debug(f'{INDENT*6} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')
+            log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')
 
             if self.lhsStmt[0] in pb:
                 operands = [pb[cast(BindableTerm, self.lhsStmt[0])]]
@@ -136,17 +176,59 @@
                 if newBindings not in self._seenBindings:
                     self._seenBindings.append(newBindings)
                     self._current = CandidateBinding(newBindings)
-                    return
+                    return True
+        elif self.lhsStmt[1] == MATH['sum']:
+
+            g = Graph()
+            for s in boundFullWorkingSet:
+                g.add(s)
+                log.debug(f' boundWorkingSet graph: {s}')
+            log.debug(f'_parseList subj = {lhsStmtBound[0]}')
+            operands, _ = parseList(g, lhsStmtBound[0])
+            log.debug(f'********* {INDENT*7} {self} found list {operands=}')
+            try:
+                obj = Literal(sum(map(numericNode, operands)))
+            except TypeError:
+                log.debug('typeerr in operands')
+                pass
+            else:
+                objVar = lhsStmtBound[2]
+                log.debug(f'{objVar=}')
 
-        log.debug(f'{INDENT*6} {self} is past end')
-        self._pastEnd = True
+                if not isinstance(objVar, Variable):
+                    raise TypeError(f'expected Variable, got {objVar!r}')
+                newBindings: Dict[BindableTerm, Node] = {objVar: obj}
+                log.debug(f'{newBindings=}')
+
+                self._current.addNewBindings(CandidateBinding(newBindings))
+                log.debug(f'{self._seenBindings=}')
+                if newBindings not in self._seenBindings:
+                    self._seenBindings.append(newBindings)
+                    self._current = CandidateBinding(newBindings)
+                    return True
+
+        return False
+
+    def _boundOperands(self, operands) -> List[Node]:  # substitute previous bindings into operands; raises BindingUnknown
+        pb: Dict[BindableTerm, Node] = self._prevBindings()  # bindings looked up for each bindable operand below
+
+        boundOperands: List[Node] = []  # same order and length as operands
+        for op in operands:
+            if isinstance(op, (Variable, BNode)):  # bindable term: must already have a value in pb
+                if op in pb:
+                    boundOperands.append(pb[op])
+                else:
+                    raise BindingUnknown()  # caller (_advanceWithBoolRules) catches this and returns False
+            else:
+                boundOperands.append(op)  # any other term is passed through unchanged
+        return boundOperands
 
     def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
         outBinding = self._prevBindings().copy()
         for rt, ct in zip(self.lhsStmt, newStmt):
             if isinstance(rt, (Variable, BNode)):
                 if rt in outBinding and outBinding[rt] != ct:
-                    raise Inconsistent()
+                    raise Inconsistent(f'{rt=} {ct=} {outBinding=}')
                 outBinding[rt] = ct
         return CandidateBinding(outBinding)
 
@@ -245,7 +327,21 @@
         """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
         start out valid (or else raise NoOptions)"""
 
-        stmtsToAdd = list(self.graph)
+        usedByFuncs: Set[Triple] = set()  # don't worry about matching these
+        stmtsToResolve = list(self.graph)
+        for i, s in enumerate(stmtsToResolve):
+            if s[1] == MATH['sum']:
+                _, used = parseList(self.graph, s[0])
+                usedByFuncs.update(used)
+
+        stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs]
+
+        # sort them by variable dependencies; don't just try all perms!
+        def lightSortKey(stmt):  # Not a variable-dependency sort, though it helps performance on the big rdf list cases.
+            (s, p, o) = stmt
+            return p == MATH['sum'], p, s, o  # False sorts before True, so math:sum statements come last
+
+        stmtsToAdd.sort(key=lightSortKey)
 
         for perm in itertools.permutations(stmtsToAdd):
             stmtStack: List[StmtLooper] = []
@@ -254,7 +350,7 @@
 
             for s in perm:
                 try:
-                    elem = StmtLooper(s, prev, knownTrue)
+                    elem = StmtLooper(s, prev, knownTrue, parent=self)
                 except NoOptions:
                     log.debug(f'{INDENT*6} permutation didnt work, try another')
                     break
@@ -540,7 +636,7 @@
         log.debug('')
         log.debug(f'{INDENT*2}-applying rule {i}')
         log.debug(f'{INDENT*3} rule def lhs:')
-        for stmt in r.lhsGraph:
+        for stmt in sorted(r.lhsGraph, reverse=True):
             log.debug(f'{INDENT*4} {stmt}')
         log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')