diff service/mqtt_to_rdf/inference.py @ 1601:30463df12d89

infer() dumps stats
author drewp@bigasterisk.com
date Sun, 05 Sep 2021 23:27:49 -0700
parents 89a50242cb5e
children e3c44ac6d3c5
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Sun Sep 05 22:50:15 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sun Sep 05 23:27:49 2021 -0700
@@ -4,6 +4,7 @@
 """
 import itertools
 import logging
+import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from decimal import Decimal
@@ -22,12 +23,15 @@
 BindableTerm = Union[Variable, BNode]
 ReadOnlyWorkingSet = ReadOnlyGraphAggregate
 
-READ_RULES_CALLS = Summary('read_rules_calls', 'calls')
+INFER_CALLS = Summary('read_rules_calls', 'calls')
 
 ROOM = Namespace("http://projects.bigasterisk.com/room/")
 LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
 
+# Graph() makes a BNode if you don't pass
+# identifier, which can be a bottleneck.
+GRAPH_ID = URIRef('dont/care')
 
 class EvaluationFailed(ValueError):
     """e.g. we were given (5 math:greaterThan 6)"""
@@ -65,7 +69,7 @@
 
     def applyFunctions(self, lhs) -> Graph:
         """may grow the binding with some results"""
-        usedByFuncs = Graph()
+        usedByFuncs = Graph(identifier=GRAPH_ID)
         while True:
             delta = self._applyFunctionsIteration(lhs, usedByFuncs)
             if delta == 0:
@@ -135,6 +139,7 @@
 @dataclass
 class Lhs:
     graph: Graph
+    stats: Dict
 
     staticRuleStmts: Graph = field(default_factory=Graph)
     lhsBindables: Set[BindableTerm] = field(default_factory=set)
@@ -164,6 +169,7 @@
 
         log.debug(f'{INDENT*3} trying all permutations:')
 
+
         for perm in itertools.product(*orderedValueSets):
             binding = CandidateBinding(dict(zip(orderedVars, perm)))
             log.debug('')
@@ -172,11 +178,15 @@
             try:
                 usedByFuncs = binding.applyFunctions(self)
             except EvaluationFailed:
+                self.stats['permCountFailingEval'] += 1
                 continue
 
             if not binding.verify(self, workingSet, usedByFuncs):
                 log.debug(f'{INDENT*4} this binding did not verify')
+                self.stats['permCountFailingVerify'] += 1
                 continue
+
+            self.stats['permCountSucceeding'] += 1
             yield binding
 
     def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool:
@@ -224,7 +234,7 @@
                 yield v, vals
 
     def graphWithoutEvals(self, binding: CandidateBinding) -> Graph:
-        g = Graph()
+        g = Graph(identifier=GRAPH_ID)
         usedByFuncs = binding.applyFunctions(self)
 
         for stmt in self.graph:
@@ -266,7 +276,7 @@
     # internal, use findEvals
     def __init__(self, operands: List[Node], mainStmt: Triple, otherStmts: Iterable[Triple]) -> None:
         self.operands = operands
-        self.operandsStmts = Graph()
+        self.operandsStmts = Graph(identifier=GRAPH_ID)
         self.operandsStmts += otherStmts  # may grow
         self.operandsStmts.add(mainStmt)
         self.stmt = mainStmt
@@ -328,12 +338,14 @@
                 self.rules.add(stmt)
             # others should go to a default working set?
 
+    @INFER_CALLS.time()
     def infer(self, graph: Graph):
         """
         returns new graph of inferred statements.
         """
         log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:')
-
+        startTime = time.time()
+        self.stats: Dict[str, Union[int,float]] = defaultdict(lambda: 0)
         # everything that is true: the input graph, plus every rule conclusion we can make
         workingSet = Graph()
         workingSet += graph
@@ -343,14 +355,18 @@
 
         bailout_iterations = 100
         delta = 1
+        self.stats['initWorkingSet'] = cast(int, workingSet.__len__())
         while delta > 0 and bailout_iterations > 0:
             log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
             bailout_iterations -= 1
             delta = -len(implied)
             self._iterateAllRules(workingSet, implied)
             delta += len(implied)
+            self.stats['iterations'] += 1
             log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
-        log.info(f'{INDENT*0} Inference done; {len(implied)} stmts implied:')
+        self.stats['timeSpent'] = round(time.time() - startTime, 3)
+        self.stats['impliedStmts'] = len(implied)
+        log.info(f'{INDENT*0} Inference done {dict(self.stats)}. Implied:')
         for st in implied:
             log.info(f'{INDENT*1} {st}')
         return implied
@@ -358,7 +374,7 @@
     def _iterateAllRules(self, workingSet: Graph, implied: Graph):
         for i, r in enumerate(self.rules):
             self._logRuleApplicationHeader(workingSet, i, r)
-            _applyRule(Lhs(r[0]), r[2], workingSet, implied)
+            _applyRule(Lhs(r[0], self.stats), r[2], workingSet, implied, self.stats)
 
     def _logRuleApplicationHeader(self, workingSet, i, r):
         if not log.isEnabledFor(logging.DEBUG):
@@ -375,7 +391,7 @@
         log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}')
 
 
-def _applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph):
+def _applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph, stats: Dict):
     for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])):
         log.debug(f'{INDENT*3} rule has a working binding:')