diff service/mqtt_to_rdf/inference.py @ 1648:3059f31b2dfa

more performance work
author drewp@bigasterisk.com
date Fri, 17 Sep 2021 11:10:18 -0700
parents 5403c6343fa4
children bb5d2b5370ac
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Fri Sep 17 11:07:21 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Fri Sep 17 11:10:18 2021 -0700
@@ -61,7 +61,7 @@
     parent: 'Lhs'  # just for lhs.graph, really
 
     def __repr__(self):
-        return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'
+        return f'StmtLooper{self._shortId}{"<pastEnd>" if self.pastEnd() else ""})'
 
     def __post_init__(self):
         self._shortId = next(_stmtLooperShortId)
@@ -70,6 +70,9 @@
         self._current = CandidateBinding({})
         self._pastEnd = False
         self._seenBindings: List[CandidateBinding] = []
+
+        log.debug(f'introducing {self!r}({graphDump([self.lhsStmt])})')
+
         self.restart()
 
     def _myMatches(self, g: Graph) -> List[Triple]:
@@ -119,7 +122,7 @@
             try:
                 outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
             except Inconsistent:
-                log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
+                log.debug(f'{INDENT*7} StmtLooper{self._shortId} - {stmt} would be inconsistent with prev bindings')
                 continue
 
             log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
@@ -169,7 +172,8 @@
         for rt, ct in zip(self.lhsStmt, newStmt):
             if isinstance(rt, (Variable, BNode)):
                 if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct:
-                    raise Inconsistent(f'{rt=} {ct=} {outBinding=}')
+                    msg = f'{rt=} {ct=} {outBinding=}' if log.isEnabledFor(logging.DEBUG) else ''
+                    raise Inconsistent(msg)
                 outBinding.addNewBindings(CandidateBinding({rt: ct}))
         return outBinding
 
@@ -191,15 +195,37 @@
 
 @dataclass
 class Lhs:
-    graph: Graph
+    graph: Graph  # our full LHS graph, as input. See below for the statements partitioned into groups.
 
     def __post_init__(self):
-        pass
+
+        usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
+
+        stmtsToMatch = list(self.graph - usedByFuncs)
+        self.staticStmts = []
+        self.patternStmts = []
+        for st in stmtsToMatch:
+            if all(isinstance(term, (URIRef, Literal)) for term in st):
+                self.staticStmts.append(st)
+            else:
+                self.patternStmts.append(st)
+
+        # sort them by variable dependencies; don't just try all perms!
+        def lightSortKey(stmt):  # Not this.
+            (s, p, o) = stmt
+            return p in rulePredicates(), p, s, o
+
+        self.patternStmts.sort(key=lightSortKey)
+
+        self.myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
+        self.myPreds -= rulePredicates()
+        self.myPreds -= {RDF.first, RDF.rest}
+        self.myPreds = set(self.myPreds)
 
     def __repr__(self):
         return f"Lhs({graphDump(self.graph)})"
 
-    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
+    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
         if self.graph.__len__() == 0:
@@ -211,10 +237,19 @@
             stats['_checkPredicateCountsCulls'] += 1
             return
 
+        if not all(st in knownTrue for st in self.staticStmts):
+            stats['staticStmtCulls'] += 1
+            return
+
+        if len(self.patternStmts) == 0:
+            # static only
+            yield BoundLhs(self, CandidateBinding({}))
+            return
+
         log.debug(f'{INDENT*4} build new StmtLooper stack')
 
         try:
-            stmtStack = self._assembleRings(knownTrue)
+            stmtStack = self._assembleRings(knownTrue, stats)
         except NoOptions:
             log.debug(f'{INDENT*5} start up with no options; 0 bindings')
             return
@@ -225,8 +260,8 @@
         iterCount = 0
         while True:
             iterCount += 1
-            if iterCount > 10:
-                raise ValueError('stuck')
+            if iterCount > ruleStatementsIterationLimit:
+                raise ValueError('rule too complex')
 
             log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')
 
@@ -249,32 +284,23 @@
 
     def _checkPredicateCounts(self, knownTrue):
         """raise NoOptions quickly in some cases"""
-        myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
-        myPreds -= rulePredicates()
-        myPreds -= {RDF.first, RDF.rest}
-        if any((None, p, None) not in knownTrue for p in set(myPreds)):
+
+        if any((None, p, None) not in knownTrue for p in self.myPreds):
             return True
+        log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue')
         return False
 
-    def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]:
+    def _assembleRings(self, knownTrue: ReadOnlyWorkingSet, stats) -> List[StmtLooper]:
         """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
         start out valid (or else raise NoOptions)"""
 
-        usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
-
-        stmtsToAdd = list(self.graph - usedByFuncs)
-
-        # sort them by variable dependencies; don't just try all perms!
-        def lightSortKey(stmt):  # Not this. Though it helps performance on the big rdf list cases.
-            (s, p, o) = stmt
-            return p == MATH['sum'], p, s, o
-
-        stmtsToAdd.sort(key=lightSortKey)
-
-        for i, perm in enumerate(itertools.permutations(stmtsToAdd)):
+        log.info(f'{INDENT*2} stats={dict(stats)}')
+        log.info(f'{INDENT*2} taking permutations of {len(self.patternStmts)=}')
+        for i, perm in enumerate(itertools.permutations(self.patternStmts)):
             stmtStack: List[StmtLooper] = []
             prev: Optional[StmtLooper] = None
-            log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
+            if log.isEnabledFor(logging.DEBUG):
+                log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
 
             for s in perm:
                 try:
@@ -286,8 +312,10 @@
                 prev = stmtStack[-1]
             else:
                 return stmtStack
+            if i > 5000:
+                raise NotImplementedError(f'trying too many permutations {len(self.patternStmts)=}')
+
         log.debug(f'{INDENT*6} no perms worked- rule cannot match anything')
-
         raise NoOptions()
 
     def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
@@ -329,8 +357,8 @@
         #
         self.rhsBnodeMap = {}
 
-    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict):
-        for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats):
+    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit):
+        for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats, ruleStatementsIterationLimit):
             log.debug(f'{INDENT*5} +rule has a working binding: {bound}')
 
             # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do
@@ -362,13 +390,18 @@
                 implied.add(newStmt)
 
 
+@dataclass
 class Inference:
+    rulesIterationLimit = 3
+    ruleStatementsIterationLimit = 3
 
     def __init__(self) -> None:
-        self.rules = []
+        self.rules: List[Rule] = []
+        self._nonRuleStmts: List[Triple] = []
 
     def setRules(self, g: ConjunctiveGraph):
-        self.rules: List[Rule] = []
+        self.rules = []
+        self._nonRuleStmts = []
         for stmt in g:
             if stmt[1] == LOG['implies']:
                 self.rules.append(Rule(stmt[0], stmt[2]))
@@ -391,28 +424,30 @@
         # just the statements that came from RHS's of rules that fired.
         implied = ConjunctiveGraph()
 
-        bailout_iterations = 100
+        rulesIterations = 0
         delta = 1
         stats['initWorkingSet'] = cast(int, workingSet.__len__())
-        while delta > 0 and bailout_iterations > 0:
+        while delta > 0 and rulesIterations <= self.rulesIterationLimit:
             log.debug('')
-            log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
-            bailout_iterations -= 1
+            log.info(f'{INDENT*1}*iteration {rulesIterations}')
+
             delta = -len(implied)
             self._iterateAllRules(workingSet, implied, stats)
             delta += len(implied)
-            stats['iterations'] += 1
+            rulesIterations += 1
             log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
+        stats['iterations'] = rulesIterations
         stats['timeSpent'] = round(time.time() - startTime, 3)
         stats['impliedStmts'] = len(implied)
-        log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:')
-        log.info(graphDump(implied))
+        log.info(f'{INDENT*0} Inference done {dict(stats)}.')
+        log.debug('Implied:')
+        log.debug(graphDump(implied))
         return implied
 
     def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats):
         for i, rule in enumerate(self.rules):
             self._logRuleApplicationHeader(workingSet, i, rule)
-            rule.applyRule(workingSet, implied, stats)
+            rule.applyRule(workingSet, implied, stats, self.ruleStatementsIterationLimit)
 
     def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
         if not log.isEnabledFor(logging.DEBUG):
@@ -420,8 +455,9 @@
 
         log.debug('')
         log.debug(f'{INDENT*2} workingSet:')
-        for j, stmt in enumerate(sorted(workingSet)):
-            log.debug(f'{INDENT*3} ({j}) {stmt}')
+        # for j, stmt in enumerate(sorted(workingSet)):
+        #     log.debug(f'{INDENT*3} ({j}) {stmt}')
+        log.debug(f'{INDENT*3} {graphDump(workingSet, oneLine=False)}')
 
         log.debug('')
         log.debug(f'{INDENT*2}-applying rule {i}')
@@ -431,7 +467,7 @@
         log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
 
 
-def graphDump(g: Union[Graph, List[Triple]]):
+def graphDump(g: Union[Graph, List[Triple]], oneLine=True):
     # this is very slow- debug only!
     if not log.isEnabledFor(logging.DEBUG):
         return "(skipped dump)"
@@ -442,5 +478,7 @@
     g.bind('', ROOM)
     g.bind('ex', Namespace('http://example.com/'))
     lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines()
-    lines = [line.strip() for line in lines if not line.startswith('@prefix')]
+    lines = [line for line in lines if not line.startswith('@prefix')]
+    if oneLine:
+        lines = [line.strip() for line in lines]
     return ' '.join(lines)