diff service/mqtt_to_rdf/inference.py @ 1651:20474ad4968e

WIP - functions are broken as I move most layers to work in Chunks, not Triples. A Chunk is a Triple plus any rdf lists.
author drewp@bigasterisk.com
date Sat, 18 Sep 2021 23:57:20 -0700
parents 2061df259224
children dddfa09ea0b9
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sat Sep 18 23:53:59 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sat Sep 18 23:57:20 2021 -0700
@@ -7,17 +7,18 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast)
+from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast
 
 from prometheus_client import Histogram, Summary
-from rdflib import RDF, BNode, Graph, Literal, Namespace
-from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
+from rdflib import RDF, BNode, Graph, Namespace
+from rdflib.graph import ConjunctiveGraph
 from rdflib.term import Node, URIRef, Variable
 
 from candidate_binding import BindingConflict, CandidateBinding
-from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple
-from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs, rulePredicates
+from inference_types import BindingUnknown, Inconsistent, Triple
+from lhs_evaluation import functionsFor
 from rdf_debug import graphDump
+from stmt_chunk import Chunk, ChunkedGraph, applyChunky
 
 log = logging.getLogger('infer')
 INDENT = '    '
@@ -30,61 +31,40 @@
 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
 
 
-def stmtTemplate(stmt: Triple) -> Tuple[Optional[Node], Optional[Node], Optional[Node]]:
-    return (
-        None if isinstance(stmt[0], (Variable, BNode)) else stmt[0],
-        None if isinstance(stmt[1], (Variable, BNode)) else stmt[1],
-        None if isinstance(stmt[2], (Variable, BNode)) else stmt[2],
-    )
+class NoOptions(ValueError):
+    """ChunkLooper has no possibilites to add to the binding; the whole rule must therefore not apply"""
 
 
-class NoOptions(ValueError):
-    """stmtlooper has no possibilites to add to the binding; the whole rule must therefore not apply"""
-
-
-class Inconsistent(ValueError):
-    """adding this stmt would be inconsistent with an existing binding"""
-
-
-_stmtLooperShortId = itertools.count()
+_chunkLooperShortId = itertools.count()
 
 
 @dataclass
-class StmtLooper:
-    """given one LHS stmt, iterate through the possible matches for it,
+class ChunkLooper:
+    """given one LHS Chunk, iterate through the possible matches for it,
     returning what bindings they would imply. Only distinct bindings are
-    returned. The bindings build on any `prev` StmtLooper's results.
+    returned. The bindings build on any `prev` ChunkLooper's results.
 
     This iterator is restartable."""
-    lhsStmt: Triple
-    prev: Optional['StmtLooper']
-    workingSet: ReadOnlyWorkingSet
+    lhsChunk: Chunk
+    prev: Optional['ChunkLooper']
+    workingSet: 'ChunkedGraph'
     parent: 'Lhs'  # just for lhs.graph, really
 
     def __repr__(self):
-        return f'StmtLooper{self._shortId}{"<pastEnd>" if self.pastEnd() else ""})'
+        return f'{self.__class__.__name__}{self._shortId}{"<pastEnd>" if self.pastEnd() else ""}'
 
     def __post_init__(self):
-        self._shortId = next(_stmtLooperShortId)
-        self._myWorkingSetMatches = self._myMatches(self.workingSet)
+        self._shortId = next(_chunkLooperShortId)
+        self._myWorkingSetMatches = self.lhsChunk.myMatches(self.workingSet)
 
         self._current = CandidateBinding({})
         self._pastEnd = False
         self._seenBindings: List[CandidateBinding] = []
 
-        log.debug(f'introducing {self!r}({graphDump([self.lhsStmt])})')
+        log.debug(f'{INDENT*6} introducing {self!r}({self.lhsChunk}, {self._myWorkingSetMatches=})')
 
         self.restart()
 
-    def _myMatches(self, g: Graph) -> List[Triple]:
-        template = stmtTemplate(self.lhsStmt)
-
-        stmts = sorted(cast(Iterator[Triple], list(g.triples(template))))
-        # plus new lhs possibilties...
-        # log.debug(f'{INDENT*6} {self} find {len(stmts)=} in {len(self.workingSet)=}')
-
-        return stmts
-
     def _prevBindings(self) -> CandidateBinding:
         if not self.prev or self.prev.pastEnd():
             return CandidateBinding({})
@@ -96,12 +76,12 @@
         if self._pastEnd:
             raise NotImplementedError('need restart')
         log.debug('')
-        augmentedWorkingSet: Sequence[Triple] = []
+        augmentedWorkingSet: Sequence[Chunk] = []
         if self.prev is None:
             augmentedWorkingSet = self._myWorkingSetMatches
         else:
-            augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches,
-                                                                        returnBoundStatementsOnly=False))
+            augmentedWorkingSet = list(
+                applyChunky(self.prev.currentBinding(), self._myWorkingSetMatches, returnBoundStatementsOnly=False))
 
         log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}')
 
@@ -114,16 +94,16 @@
         log.debug(f'{INDENT*6} {self} is past end')
         self._pastEnd = True
 
-    def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool:
+    def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Chunk]) -> bool:
         log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
         for s in augmentedWorkingSet:
             log.debug(f'{INDENT*7} {s}')
 
-        for stmt in augmentedWorkingSet:
+        for chunk in augmentedWorkingSet:
             try:
-                outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
+                outBinding = self.lhsChunk.totalBindingIfThisStmtWereTrue(self._prevBindings(), chunk)
             except Inconsistent:
-                log.debug(f'{INDENT*7} StmtLooper{self._shortId} - {stmt} would be inconsistent with prev bindings')
+                log.debug(f'{INDENT*7} ChunkLooper{self._shortId} - {chunk} would be inconsistent with prev bindings')
                 continue
 
             log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
@@ -135,12 +115,12 @@
         return False
 
     def _advanceWithFunctions(self) -> bool:
-        pred: Node = self.lhsStmt[1]
+        pred: Node = self.lhsChunk.predicate
         if not isinstance(pred, URIRef):
             raise NotImplementedError
 
         for functionType in functionsFor(pred):
-            fn = functionType(self.lhsStmt, self.parent.graph)
+            fn = functionType(self.lhsChunk, self.parent.graph)
             try:
                 out = fn.bind(self._prevBindings())
             except BindingUnknown:
@@ -168,16 +148,6 @@
                 boundOperands.append(op)
         return boundOperands
 
-    def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
-        outBinding = self._prevBindings().copy()
-        for rt, ct in zip(self.lhsStmt, newStmt):
-            if isinstance(rt, (Variable, BNode)):
-                if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct:
-                    msg = f'{rt=} {ct=} {outBinding=}' if log.isEnabledFor(logging.DEBUG) else ''
-                    raise Inconsistent(msg)
-                outBinding.addNewBindings(CandidateBinding({rt: ct}))
-        return outBinding
-
     def currentBinding(self) -> CandidateBinding:
         if self.pastEnd():
             raise NotImplementedError()
@@ -196,40 +166,19 @@
 
 @dataclass
 class Lhs:
-    graph: Graph  # our full LHS graph, as input. See below for the statements partitioned into groups.
+    graph: ChunkedGraph  # our full LHS graph, as input. See below for the statements partitioned into groups.
 
     def __post_init__(self):
 
-        usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
-
-        stmtsToMatch = list(self.graph - usedByFuncs)
-        self.staticStmts = []
-        self.patternStmts = []
-        for st in stmtsToMatch:
-            if all(isinstance(term, (URIRef, Literal)) for term in st):
-                self.staticStmts.append(st)
-            else:
-                self.patternStmts.append(st)
-
-        # sort them by variable dependencies; don't just try all perms!
-        def lightSortKey(stmt):  # Not this.
-            (s, p, o) = stmt
-            return p in rulePredicates(), p, s, o
-
-        self.patternStmts.sort(key=lightSortKey)
-
-        self.myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
-        self.myPreds -= rulePredicates()
-        self.myPreds -= {RDF.first, RDF.rest}
-        self.myPreds = set(self.myPreds)
+        self.myPreds = self.graph.allPredicatesExceptFunctions()
 
     def __repr__(self):
-        return f"Lhs({graphDump(self.graph)})"
+        return f"Lhs({self.graph!r})"
 
-    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
+    def findCandidateBindings(self, knownTrue: ChunkedGraph, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
-        if self.graph.__len__() == 0:
+        if not self.graph:
             # special case- no LHS!
             yield BoundLhs(self, CandidateBinding({}))
             return
@@ -238,26 +187,26 @@
             stats['_checkPredicateCountsCulls'] += 1
             return
 
-        if not all(st in knownTrue for st in self.staticStmts):
+        if not all(ch in knownTrue for ch in self.graph.staticChunks):
             stats['staticStmtCulls'] += 1
             return
 
-        if len(self.patternStmts) == 0:
+        if not self.graph.patternChunks:
             # static only
             yield BoundLhs(self, CandidateBinding({}))
             return
 
-        log.debug(f'{INDENT*4} build new StmtLooper stack')
+        log.debug(f'{INDENT*4} build new ChunkLooper stack')
 
         try:
-            stmtStack = self._assembleRings(knownTrue, stats)
+            chunkStack = self._assembleRings(knownTrue, stats)
         except NoOptions:
             log.debug(f'{INDENT*5} start up with no options; 0 bindings')
             return
-        self._debugStmtStack('initial odometer', stmtStack)
-        self._assertAllRingsAreValid(stmtStack)
+        self._debugChunkStack('initial odometer', chunkStack)
+        self._assertAllRingsAreValid(chunkStack)
 
-        lastRing = stmtStack[-1]
+        lastRing = chunkStack[-1]
         iterCount = 0
         while True:
             iterCount += 1
@@ -268,44 +217,45 @@
 
             yield BoundLhs(self, lastRing.currentBinding())
 
-            self._debugStmtStack('odometer', stmtStack)
+            self._debugChunkStack('odometer', chunkStack)
 
-            done = self._advanceAll(stmtStack)
+            done = self._advanceAll(chunkStack)
 
-            self._debugStmtStack('odometer after ({done=})', stmtStack)
+            self._debugChunkStack(f'odometer after ({done=})', chunkStack)
 
             log.debug(f'{INDENT*4} ^^ findCandBindings iteration done')
             if done:
                 break
 
-    def _debugStmtStack(self, label, stmtStack):
+    def _debugChunkStack(self, label: str, chunkStack: List[ChunkLooper]):
         log.debug(f'{INDENT*5} {label}:')
-        for l in stmtStack:
+        for l in chunkStack:
             log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')
 
     def _checkPredicateCounts(self, knownTrue):
         """raise NoOptions quickly in some cases"""
 
-        if any((None, p, None) not in knownTrue for p in self.myPreds):
+        if self.graph.noPredicatesAppear(self.myPreds):
+            log.info(f'{INDENT*2} checkPredicateCounts does cull because not all {self.myPreds=} are in knownTrue')
             return True
         log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue')
         return False
 
-    def _assembleRings(self, knownTrue: ReadOnlyWorkingSet, stats) -> List[StmtLooper]:
-        """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
+    def _assembleRings(self, knownTrue: ChunkedGraph, stats) -> List[ChunkLooper]:
+        """make ChunkLooper for each stmt in our LHS graph, but do it in a way that they all
         start out valid (or else raise NoOptions)"""
 
         log.info(f'{INDENT*2} stats={dict(stats)}')
-        log.info(f'{INDENT*2} taking permutations of {len(self.patternStmts)=}')
-        for i, perm in enumerate(itertools.permutations(self.patternStmts)):
-            stmtStack: List[StmtLooper] = []
-            prev: Optional[StmtLooper] = None
+        log.info(f'{INDENT*2} taking permutations of {len(self.graph.patternChunks)=}')
+        for i, perm in enumerate(itertools.permutations(self.graph.patternChunks)):
+            stmtStack: List[ChunkLooper] = []
+            prev: Optional[ChunkLooper] = None
             if log.isEnabledFor(logging.DEBUG):
-                log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
+                log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(repr(p) for p in perm)}')
 
             for s in perm:
                 try:
-                    elem = StmtLooper(s, prev, knownTrue, parent=self)
+                    elem = ChunkLooper(s, prev, knownTrue, parent=self)
                 except NoOptions:
                     log.debug(f'{INDENT*6} permutation didnt work, try another')
                     break
@@ -314,12 +264,12 @@
             else:
                 return stmtStack
             if i > 5000:
-                raise NotImplementedError(f'trying too many permutations {len(self.patternStmts)=}')
+                raise NotImplementedError(f'trying too many permutations {len(self.graph.patternChunks)=}')
 
         log.debug(f'{INDENT*6} no perms worked- rule cannot match anything')
         raise NoOptions()
 
-    def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
+    def _advanceAll(self, stmtStack: List[ChunkLooper]) -> bool:
         carry = True  # 1st elem always must advance
         for i, ring in enumerate(stmtStack):
             # unlike normal odometer, advancing any earlier ring could invalidate later ones
@@ -354,12 +304,16 @@
     rhsGraph: Graph
 
     def __post_init__(self):
-        self.lhs = Lhs(self.lhsGraph)
+        self.lhs = Lhs(ChunkedGraph(self.lhsGraph, functionsFor))
         #
         self.rhsBnodeMap = {}
 
     def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit):
-        for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats, ruleStatementsIterationLimit):
+        # this does not change for the current applyRule call. The rule will be
+        # tried again in an outer loop, in case it can produce more.
+        workingSetChunked = ChunkedGraph(workingSet, functionsFor)
+
+        for bound in self.lhs.findCandidateBindings(workingSetChunked, stats, ruleStatementsIterationLimit):
             log.debug(f'{INDENT*5} +rule has a working binding: {bound}')
 
             # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do