changeset 1648:3059f31b2dfa

more performance work
author drewp@bigasterisk.com
date Fri, 17 Sep 2021 11:10:18 -0700
parents 34eb87f68ab8
children bb5d2b5370ac
files service/mqtt_to_rdf/infer_perf_test.py service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/lhs_evaluation.py
diffstat 3 files changed, 140 insertions(+), 48 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/service/mqtt_to_rdf/infer_perf_test.py	Fri Sep 17 11:10:18 2021 -0700
@@ -0,0 +1,50 @@
+import logging
+import unittest
+
+from rdflib.graph import ConjunctiveGraph
+
+from inference import Inference
+from inference_test import N3
+
+logging.basicConfig(level=logging.INFO)
+
+# ~/.venvs/mqtt_to_rdf/bin/nosetests --with-watcher --logging-level=INFO --with-timer -s --nologcapture infer_perf_test
+
+
+class TestPerf(unittest.TestCase):
+
+    def test(self):
+        config = ConjunctiveGraph()
+        config.parse('conf/rules.n3', format='n3')
+
+        inference = Inference()
+        inference.setRules(config)
+        expandedConfig = inference.infer(config)
+        expandedConfig += inference.nonRuleStatements()
+
+        for loop in range(10):
+            # g = N3('''
+            # <urn:uuid:2f5bbe1e-177f-11ec-9f97-8a12f6515350> a :MqttMessage ;
+            #     :body "online" ;   
+            #     :onlineTerm :Online ;
+            #     :topic ( "frontdoorlock" "status") .
+            # ''')
+            # derived = inference.infer(g)
+
+            # g = N3('''
+            # <urn:uuid:2f5bbe1e-177f-11ec-9f97-8a12f6515350> a :MqttMessage ;
+            #     :body "zz" ;   
+            #     :bodyFloat 12.2;
+            #     :onlineTerm :Online ;
+            #     :topic ( "air_quality_outdoor" "sensor" "bme280_temperature" "state") .
+            # ''')
+            # derived = inference.infer(g)
+            g = N3('''
+            <urn:uuid:a4778502-1784-11ec-a323-464f081581c1> a :MqttMessage ;
+                :body "65021" ;
+                :bodyFloat 6.5021e+04 ;
+                :topic ( "air_quality_indoor" "sensor" "ccs811_total_volatile_organic_compound" "state" ) .
+            ''')
+            derived = inference.infer(g)
+
+        # self.fail()
--- a/service/mqtt_to_rdf/inference.py	Fri Sep 17 11:07:21 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Fri Sep 17 11:10:18 2021 -0700
@@ -61,7 +61,7 @@
     parent: 'Lhs'  # just for lhs.graph, really
 
     def __repr__(self):
-        return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'
+        return f'StmtLooper{self._shortId}({"<pastEnd>" if self.pastEnd() else ""})'
 
     def __post_init__(self):
         self._shortId = next(_stmtLooperShortId)
@@ -70,6 +70,9 @@
         self._current = CandidateBinding({})
         self._pastEnd = False
         self._seenBindings: List[CandidateBinding] = []
+
+        log.debug(f'introducing {self!r}({graphDump([self.lhsStmt])})')
+
         self.restart()
 
     def _myMatches(self, g: Graph) -> List[Triple]:
@@ -119,7 +122,7 @@
             try:
                 outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
             except Inconsistent:
-                log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
+                log.debug(f'{INDENT*7} StmtLooper{self._shortId} - {stmt} would be inconsistent with prev bindings')
                 continue
 
             log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
@@ -169,7 +172,8 @@
         for rt, ct in zip(self.lhsStmt, newStmt):
             if isinstance(rt, (Variable, BNode)):
                 if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct:
-                    raise Inconsistent(f'{rt=} {ct=} {outBinding=}')
+                    msg = f'{rt=} {ct=} {outBinding=}' if log.isEnabledFor(logging.DEBUG) else ''
+                    raise Inconsistent(msg)
                 outBinding.addNewBindings(CandidateBinding({rt: ct}))
         return outBinding
 
@@ -191,15 +195,37 @@
 
 @dataclass
 class Lhs:
-    graph: Graph
+    graph: Graph  # our full LHS graph, as input. See below for the statements partitioned into groups.
 
     def __post_init__(self):
-        pass
+
+        usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
+
+        stmtsToMatch = list(self.graph - usedByFuncs)
+        self.staticStmts = []
+        self.patternStmts = []
+        for st in stmtsToMatch:
+            if all(isinstance(term, (URIRef, Literal)) for term in st):
+                self.staticStmts.append(st)
+            else:
+                self.patternStmts.append(st)
+
+        # sort them by variable dependencies; don't just try all perms!
+        def lightSortKey(stmt):  # Not a real dependency sort, but it helps on the big rdf list cases.
+            (s, p, o) = stmt
+            return p in rulePredicates(), p, s, o
+
+        self.patternStmts.sort(key=lightSortKey)
+
+        self.myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
+        self.myPreds -= rulePredicates()
+        self.myPreds -= {RDF.first, RDF.rest}
+        self.myPreds = set(self.myPreds)
 
     def __repr__(self):
         return f"Lhs({graphDump(self.graph)})"
 
-    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
+    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
         if self.graph.__len__() == 0:
@@ -211,10 +237,19 @@
             stats['_checkPredicateCountsCulls'] += 1
             return
 
+        if not all(st in knownTrue for st in self.staticStmts):
+            stats['staticStmtCulls'] += 1
+            return
+
+        if len(self.patternStmts) == 0:
+            # static only
+            yield BoundLhs(self, CandidateBinding({}))
+            return
+
         log.debug(f'{INDENT*4} build new StmtLooper stack')
 
         try:
-            stmtStack = self._assembleRings(knownTrue)
+            stmtStack = self._assembleRings(knownTrue, stats)
         except NoOptions:
             log.debug(f'{INDENT*5} start up with no options; 0 bindings')
             return
@@ -225,8 +260,8 @@
         iterCount = 0
         while True:
             iterCount += 1
-            if iterCount > 10:
-                raise ValueError('stuck')
+            if iterCount > ruleStatementsIterationLimit:
+                raise ValueError('rule too complex')
 
             log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')
 
@@ -249,32 +284,23 @@
 
     def _checkPredicateCounts(self, knownTrue):
         """raise NoOptions quickly in some cases"""
-        myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
-        myPreds -= rulePredicates()
-        myPreds -= {RDF.first, RDF.rest}
-        if any((None, p, None) not in knownTrue for p in set(myPreds)):
+
+        if any((None, p, None) not in knownTrue for p in self.myPreds):
             return True
+        log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue')
         return False
 
-    def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]:
+    def _assembleRings(self, knownTrue: ReadOnlyWorkingSet, stats) -> List[StmtLooper]:
         """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
         start out valid (or else raise NoOptions)"""
 
-        usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
-
-        stmtsToAdd = list(self.graph - usedByFuncs)
-
-        # sort them by variable dependencies; don't just try all perms!
-        def lightSortKey(stmt):  # Not this. Though it helps performance on the big rdf list cases.
-            (s, p, o) = stmt
-            return p == MATH['sum'], p, s, o
-
-        stmtsToAdd.sort(key=lightSortKey)
-
-        for i, perm in enumerate(itertools.permutations(stmtsToAdd)):
+        log.info(f'{INDENT*2} stats={dict(stats)}')
+        log.info(f'{INDENT*2} taking permutations of {len(self.patternStmts)=}')
+        for i, perm in enumerate(itertools.permutations(self.patternStmts)):
             stmtStack: List[StmtLooper] = []
             prev: Optional[StmtLooper] = None
-            log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
+            if log.isEnabledFor(logging.DEBUG):
+                log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
 
             for s in perm:
                 try:
@@ -286,8 +312,10 @@
                 prev = stmtStack[-1]
             else:
                 return stmtStack
+            if i > 5000:
+                raise NotImplementedError(f'trying too many permutations {len(self.patternStmts)=}')
+
         log.debug(f'{INDENT*6} no perms worked- rule cannot match anything')
-
         raise NoOptions()
 
     def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
@@ -329,8 +357,8 @@
         #
         self.rhsBnodeMap = {}
 
-    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict):
-        for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats):
+    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit):
+        for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats, ruleStatementsIterationLimit):
             log.debug(f'{INDENT*5} +rule has a working binding: {bound}')
 
             # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do
@@ -362,13 +390,18 @@
                 implied.add(newStmt)
 
 
+@dataclass
 class Inference:
+    rulesIterationLimit = 3
+    ruleStatementsIterationLimit = 3
 
     def __init__(self) -> None:
-        self.rules = []
+        self.rules: List[Rule] = []
+        self._nonRuleStmts: List[Triple] = []
 
     def setRules(self, g: ConjunctiveGraph):
-        self.rules: List[Rule] = []
+        self.rules = []
+        self._nonRuleStmts = []
         for stmt in g:
             if stmt[1] == LOG['implies']:
                 self.rules.append(Rule(stmt[0], stmt[2]))
@@ -391,28 +424,30 @@
         # just the statements that came from RHS's of rules that fired.
         implied = ConjunctiveGraph()
 
-        bailout_iterations = 100
+        rulesIterations = 0
         delta = 1
         stats['initWorkingSet'] = cast(int, workingSet.__len__())
-        while delta > 0 and bailout_iterations > 0:
+        while delta > 0 and rulesIterations <= self.rulesIterationLimit:
             log.debug('')
-            log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
-            bailout_iterations -= 1
+            log.info(f'{INDENT*1}*iteration {rulesIterations}')
+
             delta = -len(implied)
             self._iterateAllRules(workingSet, implied, stats)
             delta += len(implied)
-            stats['iterations'] += 1
+            rulesIterations += 1
             log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
+        stats['iterations'] = rulesIterations
         stats['timeSpent'] = round(time.time() - startTime, 3)
         stats['impliedStmts'] = len(implied)
-        log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:')
-        log.info(graphDump(implied))
+        log.info(f'{INDENT*0} Inference done {dict(stats)}.')
+        log.debug('Implied:')
+        log.debug(graphDump(implied))
         return implied
 
     def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats):
         for i, rule in enumerate(self.rules):
             self._logRuleApplicationHeader(workingSet, i, rule)
-            rule.applyRule(workingSet, implied, stats)
+            rule.applyRule(workingSet, implied, stats, self.ruleStatementsIterationLimit)
 
     def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
         if not log.isEnabledFor(logging.DEBUG):
@@ -420,8 +455,9 @@
 
         log.debug('')
         log.debug(f'{INDENT*2} workingSet:')
-        for j, stmt in enumerate(sorted(workingSet)):
-            log.debug(f'{INDENT*3} ({j}) {stmt}')
+        # for j, stmt in enumerate(sorted(workingSet)):
+        #     log.debug(f'{INDENT*3} ({j}) {stmt}')
+        log.debug(f'{INDENT*3} {graphDump(workingSet, oneLine=False)}')
 
         log.debug('')
         log.debug(f'{INDENT*2}-applying rule {i}')
@@ -431,7 +467,7 @@
         log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
 
 
-def graphDump(g: Union[Graph, List[Triple]]):
+def graphDump(g: Union[Graph, List[Triple]], oneLine=True):
     # this is very slow- debug only!
     if not log.isEnabledFor(logging.DEBUG):
         return "(skipped dump)"
@@ -442,5 +478,7 @@
     g.bind('', ROOM)
     g.bind('ex', Namespace('http://example.com/'))
     lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines()
-    lines = [line.strip() for line in lines if not line.startswith('@prefix')]
+    lines = [line for line in lines if not line.startswith('@prefix')]
+    if oneLine:
+        lines = [line.strip() for line in lines]
     return ' '.join(lines)
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Fri Sep 17 11:07:21 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Fri Sep 17 11:10:18 2021 -0700
@@ -7,7 +7,7 @@
 from prometheus_client import Summary
 from rdflib import RDF, Literal, Namespace, URIRef
 from rdflib.graph import Graph
-from rdflib.term import Node, Variable
+from rdflib.term import BNode, Node, Variable
 
 from inference_types import BindableTerm, Triple
 
@@ -89,6 +89,11 @@
             raise TypeError(f'expected Variable, got {objVar!r}')
         return CandidateBinding({cast(BindableTerm, objVar): value})
 
+    def usedStatements(self) -> Set[Triple]:
+        '''stmts in self.graph (not including self.stmt, oddly) that are part of
+        this function setup and aren't to be matched literally'''
+        return set()
+    
 
 class SubjectFunction(Function):
     """function that depends only on the subject term"""
@@ -144,7 +149,7 @@
         f = Literal(sum(self.getNumericOperands(existingBinding)))
         return self.valueInObjectTerm(f)
 
-### registeration is done
+### registration is done
 
 _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes)
 def functionsFor(pred: URIRef) -> Iterator[Type[Function]]:
@@ -158,8 +163,7 @@
     usedByFuncs: Set[Triple] = set()  # don't worry about matching these
     for s in graph:
         for cls in functionsFor(pred=s[1]):
-            if issubclass(cls, ListFunction):
-                usedByFuncs.update(cls(s, graph).usedStatements())
+            usedByFuncs.update(cls(s, graph).usedStatements())
     return usedByFuncs