diff service/mqtt_to_rdf/inference.py @ 1633:6107603ed455

fix farenheit rule case, fix some others that depend on rings order, but this breaks some performance because of itertools.perm
author drewp@bigasterisk.com
date Sun, 12 Sep 2021 21:48:36 -0700
parents bd79a2941cab
children ba59cfc3c747
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sun Sep 12 21:46:39 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sun Sep 12 21:48:36 2021 -0700
@@ -7,7 +7,7 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast
+from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast
 
 from prometheus_client import Histogram, Summary
 from rdflib import BNode, Graph, Namespace
@@ -88,8 +88,23 @@
 
     def advance(self):
         """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode"""
-        log.debug(f'{INDENT*6} {self} mines {len(self._myWorkingSetMatches)} matching statements')
-        for i, stmt in enumerate(self._myWorkingSetMatches):
+        if self._pastEnd:
+            raise NotImplementedError('need restart')
+        log.debug('')
+        augmentedWorkingSet: Sequence[Triple] = []
+        if self.prev is None:
+            augmentedWorkingSet = self._myWorkingSetMatches
+        else:
+            augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches,
+                                                                        returnBoundStatementsOnly=False))
+
+        log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}')
+
+        log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
+        for s in augmentedWorkingSet:
+            log.debug(f'{INDENT*7} {s}')
+
+        for i, stmt in enumerate(augmentedWorkingSet):
             try:
                 outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
             except Inconsistent:
@@ -121,6 +136,7 @@
                 if newBindings not in self._seenBindings:
                     self._seenBindings.append(newBindings)
                     self._current = CandidateBinding(newBindings)
+                    return
 
         log.debug(f'{INDENT*6} {self} is past end')
         self._pastEnd = True
@@ -184,25 +200,22 @@
     def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
+        if self.graph.__len__() == 0:
+            # special case- no LHS!
+            yield BoundLhs(self, CandidateBinding({}))
+            return
+
         log.debug(f'{INDENT*4} build new StmtLooper stack')
 
-        stmtStack: List[StmtLooper] = []
         try:
-            prev: Optional[StmtLooper] = None
-            for s in sorted(self.graph):  # order of this matters! :(
-                stmtStack.append(StmtLooper(s, prev, knownTrue))
-                prev = stmtStack[-1]
+            stmtStack = self._assembleRings(knownTrue)
         except NoOptions:
             log.debug(f'{INDENT*5} start up with no options; 0 bindings')
             return
         self._debugStmtStack('initial odometer', stmtStack)
-
+        self._assertAllRingsAreValid(stmtStack)
 
-        if any(ring.pastEnd() for ring in stmtStack):
-            log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}')
-
-            raise NoOptions()
-        sl = stmtStack[-1]
+        lastRing = stmtStack[-1]
         iterCount = 0
         while True:
             iterCount += 1
@@ -211,7 +224,7 @@
 
             log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')
 
-            yield BoundLhs(self, sl.currentBinding())
+            yield BoundLhs(self, lastRing.currentBinding())
 
             self._debugStmtStack('odometer', stmtStack)
 
@@ -228,6 +241,31 @@
         for l in stmtStack:
             log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')
 
+    def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]:
+        """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
+        start out valid (or else raise NoOptions)"""
+
+        stmtsToAdd = list(self.graph)
+
+        for perm in itertools.permutations(stmtsToAdd):
+            stmtStack: List[StmtLooper] = []
+            prev: Optional[StmtLooper] = None
+            log.debug(f'{INDENT*5} try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
+
+            for s in perm:
+                try:
+                    elem = StmtLooper(s, prev, knownTrue)
+                except NoOptions:
+                    log.debug(f'{INDENT*6} permutation didnt work, try another')
+                    break
+                stmtStack.append(elem)
+                prev = stmtStack[-1]
+            else:
+                return stmtStack
+        log.debug(f'{INDENT*6} no perms worked- rule cannot match anything')
+
+        raise NoOptions()
+
     def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
         carry = True  # 1st elem always must advance
         for i, ring in enumerate(stmtStack):
@@ -245,6 +283,11 @@
                 carry = True
         return False
 
+    def _assertAllRingsAreValid(self, stmtStack):
+        if any(ring.pastEnd() for ring in stmtStack):  # this is an unexpected debug assertion
+            log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}')
+            raise NoOptions()
+
     def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool:
         # bug: see TestSelfFulfillingRule.test3 for a case where this rule's
         # static stmt is matched by a non-static stmt in the rule itself