changeset 1607:b21885181e35

more modules, types. Maybe less repeated computation on BoundLhs
author drewp@bigasterisk.com
date Mon, 06 Sep 2021 15:38:48 -0700
parents 6cf39d43fd40
children f928eb06a4f6
files service/mqtt_to_rdf/candidate_binding.py service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_types.py service/mqtt_to_rdf/lhs_evaluation.py
diffstat 4 files changed, 204 insertions(+), 176 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/service/mqtt_to_rdf/candidate_binding.py	Mon Sep 06 15:38:48 2021 -0700
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from typing import Dict, Iterator
+
+from prometheus_client import Summary
+from rdflib import BNode, Graph
+from rdflib.term import Node, Variable
+
+from inference_types import BindableTerm, BindingUnknown, Triple
+
+
+@dataclass
+class CandidateBinding:  # a mapping from bindable terms to values, applicable to a graph
+    binding: Dict[BindableTerm, Node]  # bindable term (Variable or BNode) -> the node bound to it
+
+    def __repr__(self):  # space-separated k=v pairs, in sorted item order
+        b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items()))
+        return f'CandidateBinding({b})'
+
+    def apply(self, g: Graph) -> Iterator[Triple]:  # yield stmts of g with bindable terms replaced
+        for stmt in g:
+            try:
+                bound = (self._applyTerm(stmt[0]), self._applyTerm(stmt[1]), self._applyTerm(stmt[2]))
+            except BindingUnknown:
+                continue  # a stmt mentioning any unbound term is left out of the output
+            yield bound
+
+    def _applyTerm(self, term: Node):  # bound value for a Variable/BNode; other terms pass through
+        if isinstance(term, (Variable, BNode)):
+            if term in self.binding:
+                return self.binding[term]
+            else:
+                raise BindingUnknown()  # caught in apply(), which then skips the whole stmt
+        return term
+
+    def addNewBindings(self, newBindings: 'CandidateBinding'):  # merge into self.binding in place
+        for k, v in newBindings.binding.items():
+            if k in self.binding and self.binding[k] != v:
+                raise ValueError(  # re-stating an equal binding is fine; changing one is not
+                    f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
+            self.binding[k] = v
\ No newline at end of file
--- a/service/mqtt_to_rdf/inference.py	Mon Sep 06 01:15:14 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Mon Sep 06 15:38:48 2021 -0700
@@ -7,24 +7,20 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
-from decimal import Decimal
-from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union, cast
+from typing import Dict, Iterator, List, Set, Tuple, Union, cast
 
 from prometheus_client import Summary
-from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef
+from rdflib import BNode, Graph, Namespace, URIRef
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
 from rdflib.term import Node, Variable
 
-from lhs_evaluation import EvaluationFailed, Evaluation
+from candidate_binding import CandidateBinding
+from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
+from lhs_evaluation import Evaluation
 
 log = logging.getLogger('infer')
 INDENT = '    '
 
-Triple = Tuple[Node, Node, Node]
-Rule = Tuple[Graph, Node, Graph]
-BindableTerm = Union[Variable, BNode]
-ReadOnlyWorkingSet = ReadOnlyGraphAggregate
-
 INFER_CALLS = Summary('read_rules_calls', 'calls')
 
 ROOM = Namespace("http://projects.bigasterisk.com/room/")
@@ -36,110 +32,9 @@
 GRAPH_ID = URIRef('dont/care')
 
 
-class BindingUnknown(ValueError):
-    """e.g. we were asked to make the bound version 
-    of (A B ?c) and we don't have a binding for ?c
-    """
-
-
-@dataclass
-class CandidateBinding:
-    binding: Dict[BindableTerm, Node]
-
-    def __repr__(self):
-        b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items()))
-        return f'CandidateBinding({b})'
-
-    def apply(self, g: Graph) -> Iterator[Triple]:
-        for stmt in g:
-            try:
-                bound = (self._applyTerm(stmt[0]), self._applyTerm(stmt[1]), self._applyTerm(stmt[2]))
-            except BindingUnknown:
-                continue
-            yield bound
-
-    def _applyTerm(self, term: Node):
-        if isinstance(term, (Variable, BNode)):
-            if term in self.binding:
-                return self.binding[term]
-            else:
-                raise BindingUnknown()
-        return term
-
-    def applyFunctions(self, lhs) -> Graph:
-        """may grow the binding with some results"""
-        usedByFuncs = Graph(identifier=GRAPH_ID)
-        while True:
-            delta = self._applyFunctionsIteration(lhs, usedByFuncs)
-            if delta == 0:
-                break
-        return usedByFuncs
-
-    def _applyFunctionsIteration(self, lhs, usedByFuncs: Graph):
-        before = len(self.binding)
-        delta = 0
-        for ev in lhs.evaluations:
-            log.debug(f'{INDENT*3} found Evaluation')
-
-            newBindings, usedGraph = ev.resultBindings(self.binding)
-            usedByFuncs += usedGraph
-            self._addNewBindings(newBindings)
-            delta = len(self.binding) - before
-            if log.isEnabledFor(logging.DEBUG):
-                dump = "(...)"
-                if cast(int, usedGraph.__len__()) < 20:
-                    dump = graphDump(usedGraph)
-                log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings')
-        return delta
-
-    def _addNewBindings(self, newBindings):
-        for k, v in newBindings.items():
-            if k in self.binding and self.binding[k] != v:
-                raise ValueError(f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
-            self.binding[k] = v
-
-    def verify(self, lhs: 'Lhs', workingSet: ReadOnlyWorkingSet, usedByFuncs: Graph) -> bool:
-        """Can this lhs be true all at once in workingSet? Does it match with these bindings?"""
-        boundLhs = list(self.apply(lhs.graph))
-        boundUsedByFuncs = list(self.apply(usedByFuncs))
-
-        self._logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs)
-
-        for stmt in boundLhs:
-            log.debug(f'{INDENT*4} check for {stmt}')
-
-            if stmt in boundUsedByFuncs:
-                pass
-            elif stmt in workingSet:
-                pass
-            else:
-                log.debug(f'{INDENT*5} stmt not known to be true')
-                return False
-        return True
-
-    def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs):
-        if not log.isEnabledFor(logging.DEBUG):
-            return
-        log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:')
-        for stmt in sorted(boundLhs):
-            log.debug(f'{INDENT*4}|{INDENT} {stmt}')
-
-        # log.debug(f'{INDENT*4}| and against this workingSet:')
-        # for stmt in sorted(workingSet):
-        #     log.debug(f'{INDENT*4}|{INDENT} {stmt}')
-
-        stmts = sorted(boundUsedByFuncs)
-        if stmts:
-            log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:')
-            for stmt in stmts:
-                log.debug(f'{INDENT*4}|{INDENT} {stmt}')
-        log.debug(f'{INDENT*4}\\')
-
-
 @dataclass
 class Lhs:
     graph: Graph
-    stats: Dict
 
     staticRuleStmts: Graph = field(default_factory=Graph)
     lhsBindables: Set[BindableTerm] = field(default_factory=set)
@@ -155,42 +50,41 @@
 
         self.evaluations = list(Evaluation.findEvals(self.graph))
 
-    def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]:
+    def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
         log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}')
-        self.stats['findCandidateBindingsCalls'] += 1
+        stats['findCandidateBindingsCalls'] += 1
 
         if not self._allStaticStatementsMatch(workingSet):
-            self.stats['findCandidateBindingEarlyExits'] += 1
+            stats['findCandidateBindingEarlyExits'] += 1
             return
 
+        for binding in self._possibleBindings(workingSet, stats):
+            log.debug('')
+            log.debug(f'{INDENT*4}*trying {binding.binding}')
+
+            if not binding.verify(workingSet):
+                log.debug(f'{INDENT*4} this binding did not verify')
+                stats['permCountFailingVerify'] += 1
+                continue
+
+            stats['permCountSucceeding'] += 1
+            yield binding
+
+    def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']:
+        """this yields at least the working bindings, and possibly others"""
         candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet)
 
         orderedVars, orderedValueSets = _organize(candidateTermMatches)
-
         self._logCandidates(orderedVars, orderedValueSets)
 
         log.debug(f'{INDENT*3} trying all permutations:')
-
         for perm in itertools.product(*orderedValueSets):
-            binding = CandidateBinding(dict(zip(orderedVars, perm)))
-            log.debug('')
-            log.debug(f'{INDENT*4}*trying {binding}')
-
             try:
-                usedByFuncs = binding.applyFunctions(self)
+                yield BoundLhs(self, CandidateBinding(dict(zip(orderedVars, perm))))
             except EvaluationFailed:
-                self.stats['permCountFailingEval'] += 1
-                continue
-
-            if not binding.verify(self, workingSet, usedByFuncs):
-                log.debug(f'{INDENT*4} this binding did not verify')
-                self.stats['permCountFailingVerify'] += 1
-                continue
-
-            self.stats['permCountSucceeding'] += 1
-            yield binding
+                stats['permCountFailingEval'] += 1
 
     def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool:
         for ruleStmt in self.staticRuleStmts:
@@ -236,15 +130,6 @@
                 log.debug(f'{INDENT*5} {v=} {vals=}')
                 yield v, vals
 
-    def graphWithoutEvals(self, binding: CandidateBinding) -> Graph:
-        g = Graph(identifier=GRAPH_ID)
-        usedByFuncs = binding.applyFunctions(self)
-
-        for stmt in self.graph:
-            if stmt not in usedByFuncs:
-                g.add(stmt)
-        return g
-
     def _logCandidates(self, orderedVars, orderedValueSets):
         if not log.isEnabledFor(logging.DEBUG):
             return
@@ -255,16 +140,106 @@
                 log.debug(f'{INDENT*5}{val!r}')
 
 
+@dataclass
+class BoundLhs:  # an Lhs paired with one CandidateBinding, plus the results of its evaluations
+    lhs: Lhs
+    binding: CandidateBinding  # may gain entries while the evaluations run
+
+    def __post_init__(self):  # runs the lhs evaluations now, so construction can raise what they raise
+        self.usedByFuncs = Graph(identifier=GRAPH_ID)  # stmts the evaluations reported as used
+        self.graphWithoutEvals = self._graphWithoutEvals()  # computed once and kept on the instance
+
+    def _graphWithoutEvals(self) -> Graph:  # the lhs stmts that are not in usedByFuncs
+        g = Graph(identifier=GRAPH_ID)
+        self._applyFunctions()  # fills self.usedByFuncs and may add to self.binding
+
+        for stmt in self.lhs.graph:
+            if stmt not in self.usedByFuncs:
+                g.add(stmt)
+        return g
+
+    def _applyFunctions(self):
+        """may grow the binding with some results; repeats passes until a pass adds no bindings"""
+        while True:
+            delta = self._applyFunctionsIteration()
+            if delta == 0:
+                break
+
+    def _applyFunctionsIteration(self):  # one pass over every evaluation; returns bindings added
+        before = len(self.binding.binding)
+        delta = 0  # stays 0 when the lhs has no evaluations
+        for ev in self.lhs.evaluations:
+            log.debug(f'{INDENT*3} found Evaluation')
+
+            newBindings, usedGraph = ev.resultBindings(self.binding)
+            self.usedByFuncs += usedGraph
+            self.binding.addNewBindings(newBindings)  # raises ValueError on a conflicting value
+            delta = len(self.binding.binding) - before  # running total for this pass so far
+            if log.isEnabledFor(logging.DEBUG):
+                dump = "(...)"  # placeholder used when the graph is too big to dump
+                if cast(int, usedGraph.__len__()) < 20:
+                    dump = graphDump(usedGraph)
+                log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings')
+        return delta
+
+
+    def verify(self, workingSet: ReadOnlyWorkingSet) -> bool:
+        """Can this bound lhs be true all at once in workingSet?"""
+        boundLhs = list(self.binding.apply(self.lhs.graph))  # apply() drops stmts with unbound terms
+        boundUsedByFuncs = list(self.binding.apply(self.usedByFuncs))
+
+        self._logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs)
+
+        for stmt in boundLhs:  # only stmts that survived apply() are checked here
+            log.debug(f'{INDENT*4} check for {stmt}')
+
+            if stmt in boundUsedByFuncs:
+                pass  # stmts used by the evaluations need not be present in workingSet
+            elif stmt in workingSet:
+                pass
+            else:
+                log.debug(f'{INDENT*5} stmt not known to be true')
+                return False
+        return True
+
+    def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs):
+        if not log.isEnabledFor(logging.DEBUG):
+            return  # skip the sorting below unless debug logging is on
+        log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:')
+        for stmt in sorted(boundLhs):
+            log.debug(f'{INDENT*4}|{INDENT} {stmt}')
+
+        # log.debug(f'{INDENT*4}| and against this workingSet:')
+        # for stmt in sorted(workingSet):
+        #     log.debug(f'{INDENT*4}|{INDENT} {stmt}')
+
+        stmts = sorted(boundUsedByFuncs)
+        if stmts:
+            log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:')
+            for stmt in stmts:
+                log.debug(f'{INDENT*4}|{INDENT} {stmt}')
+        log.debug(f'{INDENT*4}\\')
+
+
+@dataclass
+class Rule:  # one rule: the two graphs on either side of a log:implies stmt (see setRules)
+    lhsGraph: Graph
+    rhsGraph: Graph
+
+    def __post_init__(self):
+        self.lhs = Lhs(self.lhsGraph)  # built once, when the Rule is created
+
+
 class Inference:
 
     def __init__(self) -> None:
-        self.rules = ConjunctiveGraph()
+        self.rules = []
 
     def setRules(self, g: ConjunctiveGraph):
-        self.rules = ConjunctiveGraph()
+        self.rules: List[Rule] = []
         for stmt in g:
             if stmt[1] == LOG['implies']:
-                self.rules.add(stmt)
+                self.rules.append(Rule(stmt[0], stmt[2]))
             # others should go to a default working set?
 
     @INFER_CALLS.time()
@@ -274,7 +249,7 @@
         """
         log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:')
         startTime = time.time()
-        self.stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0)
+        stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0)
         # everything that is true: the input graph, plus every rule conclusion we can make
         workingSet = Graph()
         workingSet += graph
@@ -284,28 +259,28 @@
 
         bailout_iterations = 100
         delta = 1
-        self.stats['initWorkingSet'] = cast(int, workingSet.__len__())
+        stats['initWorkingSet'] = cast(int, workingSet.__len__())
         while delta > 0 and bailout_iterations > 0:
             log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
             bailout_iterations -= 1
             delta = -len(implied)
-            self._iterateAllRules(workingSet, implied)
+            self._iterateAllRules(workingSet, implied, stats)
             delta += len(implied)
-            self.stats['iterations'] += 1
+            stats['iterations'] += 1
             log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
-        self.stats['timeSpent'] = round(time.time() - startTime, 3)
-        self.stats['impliedStmts'] = len(implied)
-        log.info(f'{INDENT*0} Inference done {dict(self.stats)}. Implied:')
+        stats['timeSpent'] = round(time.time() - startTime, 3)
+        stats['impliedStmts'] = len(implied)
+        log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:')
         for st in implied:
             log.info(f'{INDENT*1} {st}')
         return implied
 
-    def _iterateAllRules(self, workingSet: Graph, implied: Graph):
+    def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats):
         for i, r in enumerate(self.rules):
             self._logRuleApplicationHeader(workingSet, i, r)
-            _applyRule(Lhs(r[0], self.stats), r[2], workingSet, implied, self.stats)
+            _applyRule(r.lhs, r.rhsGraph, workingSet, implied, stats)
 
-    def _logRuleApplicationHeader(self, workingSet, i, r):
+    def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
         if not log.isEnabledFor(logging.DEBUG):
             return
 
@@ -316,18 +291,18 @@
 
         log.debug('')
         log.debug(f'{INDENT*2}-applying rule {i}')
-        log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}')
-        log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}')
+        log.debug(f'{INDENT*3} rule def lhs: {graphDump(r.lhsGraph)}')
+        log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
 
 
 def _applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph, stats: Dict):
-    for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])):
+    for bound in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats):
         log.debug(f'{INDENT*3} rule has a working binding:')
 
-        for lhsBoundStmt in binding.apply(lhs.graphWithoutEvals(binding)):
+        for lhsBoundStmt in bound.binding.apply(bound.graphWithoutEvals):
             log.debug(f'{INDENT*5} adding {lhsBoundStmt=}')
             workingSet.add(lhsBoundStmt)
-        for newStmt in binding.apply(rhs):
+        for newStmt in bound.binding.apply(rhs):
             log.debug(f'{INDENT*5} adding {newStmt=}')
             workingSet.add(newStmt)
             implied.add(newStmt)
@@ -335,6 +310,7 @@
 
 def graphDump(g: Union[Graph, List[Triple]]):
     if not isinstance(g, Graph):
+        log.warning(f"it's a {type(g)}")
         g2 = Graph()
         g2 += g
         g = g2
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/service/mqtt_to_rdf/inference_types.py	Mon Sep 06 15:38:48 2021 -0700
@@ -0,0 +1,18 @@
+from typing import Tuple, Union
+from rdflib import Graph
+from rdflib.term import Node, BNode, Variable
+from rdflib.graph import ReadOnlyGraphAggregate
+
+BindableTerm = Union[Variable, BNode]
+ReadOnlyWorkingSet = ReadOnlyGraphAggregate
+Triple = Tuple[Node, Node, Node]
+
+
+class EvaluationFailed(ValueError):
+    """an lhs evaluation's test was false, e.g. we were given (5 math:greaterThan 6)"""
+
+
+class BindingUnknown(ValueError):
+    """e.g. we were asked to make the bound version
+    of (A B ?c) and we don't have a binding for ?c
+    """
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 06 01:15:14 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 06 15:38:48 2021 -0700
@@ -1,22 +1,19 @@
 import logging
-from dataclasses import dataclass, field
 from decimal import Decimal
-from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union, cast
+from typing import Dict, Iterable, Iterator, List, Set, Tuple
 
 from prometheus_client import Summary
-from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef
-from rdflib.graph import ReadOnlyGraphAggregate
+from rdflib import RDF, Graph, Literal, Namespace, URIRef
 from rdflib.term import Node, Variable
 
+from candidate_binding import CandidateBinding
+from inference import CandidateBinding  # NOTE(review): repeats the import above and makes an import cycle (inference.py imports Evaluation from this module)
+from inference_types import BindableTerm, EvaluationFailed, Triple
+
 log = logging.getLogger('infer')
 
 INDENT = '    '
 
-Triple = Tuple[Node, Node, Node]
-Rule = Tuple[Graph, Node, Graph]
-BindableTerm = Union[Variable, BNode]
-ReadOnlyWorkingSet = ReadOnlyGraphAggregate
-
 ROOM = Namespace("http://projects.bigasterisk.com/room/")
 LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
@@ -26,10 +23,7 @@
 GRAPH_ID = URIRef('dont/care')
 
 
-class EvaluationFailed(ValueError):
-    """e.g. we were given (5 math:greaterThan 6)"""
-
-
+# alternate name LhsComponent
 class Evaluation:
     """some lhs statements need to be evaluated with a special function 
     (e.g. math) and then not considered for the rest of the rule-firing 
@@ -59,7 +53,7 @@
         self.operandsStmts.add(mainStmt)
         self.stmt = mainStmt
 
-    def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]:
+    def resultBindings(self, inputBindings: CandidateBinding) -> Tuple[CandidateBinding, Graph]:
         """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?"""
         pred = self.stmt[1]
         objVar: Node = self.stmt[2]
@@ -67,9 +61,9 @@
         for op in self.operands:
             if isinstance(op, Variable):
                 try:
-                    op = inputBindings[op]
+                    op = inputBindings.binding[op]
                 except KeyError:
-                    return {}, self.operandsStmts
+                    return CandidateBinding(binding={}), self.operandsStmts
 
             boundOperands.append(op)
 
@@ -77,18 +71,18 @@
             obj = Literal(sum(map(numericNode, boundOperands)))
             if not isinstance(objVar, Variable):
                 raise TypeError(f'expected Variable, got {objVar!r}')
-            res: Dict[BindableTerm, Node] = {objVar: obj}
+            res = CandidateBinding({objVar: obj})
         elif pred == ROOM['asFarenheit']:
             if len(boundOperands) != 1:
                 raise ValueError(":asFarenheit takes 1 subject operand")
             f = Literal(Decimal(numericNode(boundOperands[0])) * 9 / 5 + 32)
             if not isinstance(objVar, Variable):
                 raise TypeError(f'expected Variable, got {objVar!r}')
-            res: Dict[BindableTerm, Node] = {objVar: f}
+            res = CandidateBinding({objVar: f})
         elif pred == MATH['greaterThan']:
             if not (numericNode(boundOperands[0]) > numericNode(boundOperands[1])):
                 raise EvaluationFailed()
-            res: Dict[BindableTerm, Node] = {}
+            res= CandidateBinding({})
         else:
             raise NotImplementedError(repr(pred))