Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1633:6107603ed455
fix Fahrenheit rule case; fix some others that depend on ring order, but this hurts performance because of itertools.permutations
author | drewp@bigasterisk.com |
---|---|
date | Sun, 12 Sep 2021 21:48:36 -0700 |
parents | bd79a2941cab |
children | ba59cfc3c747 |
comparison
equal
deleted
inserted
replaced
1632:bd79a2941cab | 1633:6107603ed455 |
---|---|
5 import itertools | 5 import itertools |
6 import logging | 6 import logging |
7 import time | 7 import time |
8 from collections import defaultdict | 8 from collections import defaultdict |
9 from dataclasses import dataclass | 9 from dataclasses import dataclass |
10 from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast | 10 from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast |
11 | 11 |
12 from prometheus_client import Histogram, Summary | 12 from prometheus_client import Histogram, Summary |
13 from rdflib import BNode, Graph, Namespace | 13 from rdflib import BNode, Graph, Namespace |
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate | 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate |
15 from rdflib.term import Literal, Node, Variable | 15 from rdflib.term import Literal, Node, Variable |
86 | 86 |
87 return self.prev.currentBinding().binding | 87 return self.prev.currentBinding().binding |
88 | 88 |
89 def advance(self): | 89 def advance(self): |
90 """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode""" | 90 """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode""" |
91 log.debug(f'{INDENT*6} {self} mines {len(self._myWorkingSetMatches)} matching statements') | 91 if self._pastEnd: |
92 for i, stmt in enumerate(self._myWorkingSetMatches): | 92 raise NotImplementedError('need restart') |
93 log.debug('') | |
94 augmentedWorkingSet: Sequence[Triple] = [] | |
95 if self.prev is None: | |
96 augmentedWorkingSet = self._myWorkingSetMatches | |
97 else: | |
98 augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches, | |
99 returnBoundStatementsOnly=False)) | |
100 | |
101 log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}') | |
102 | |
103 log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented statements') | |
104 for s in augmentedWorkingSet: | |
105 log.debug(f'{INDENT*7} {s}') | |
106 | |
107 for i, stmt in enumerate(augmentedWorkingSet): | |
93 try: | 108 try: |
94 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) | 109 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) |
95 except Inconsistent: | 110 except Inconsistent: |
96 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') | 111 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') |
97 continue | 112 continue |
119 newBindings = {cast(BindableTerm, objVar): cast(Node, f)} | 134 newBindings = {cast(BindableTerm, objVar): cast(Node, f)} |
120 self._current.addNewBindings(CandidateBinding(newBindings)) | 135 self._current.addNewBindings(CandidateBinding(newBindings)) |
121 if newBindings not in self._seenBindings: | 136 if newBindings not in self._seenBindings: |
122 self._seenBindings.append(newBindings) | 137 self._seenBindings.append(newBindings) |
123 self._current = CandidateBinding(newBindings) | 138 self._current = CandidateBinding(newBindings) |
139 return | |
124 | 140 |
125 log.debug(f'{INDENT*6} {self} is past end') | 141 log.debug(f'{INDENT*6} {self} is past end') |
126 self._pastEnd = True | 142 self._pastEnd = True |
127 | 143 |
128 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: | 144 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: |
182 return f"Lhs({graphDump(self.graph)})" | 198 return f"Lhs({graphDump(self.graph)})" |
183 | 199 |
184 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: | 200 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: |
185 """bindings that fit the LHS of a rule, using statements from workingSet and functions | 201 """bindings that fit the LHS of a rule, using statements from workingSet and functions |
186 from LHS""" | 202 from LHS""" |
203 if self.graph.__len__() == 0: | |
204 # special case- no LHS! | |
205 yield BoundLhs(self, CandidateBinding({})) | |
206 return | |
207 | |
187 log.debug(f'{INDENT*4} build new StmtLooper stack') | 208 log.debug(f'{INDENT*4} build new StmtLooper stack') |
188 | 209 |
189 stmtStack: List[StmtLooper] = [] | |
190 try: | 210 try: |
191 prev: Optional[StmtLooper] = None | 211 stmtStack = self._assembleRings(knownTrue) |
192 for s in sorted(self.graph): # order of this matters! :( | |
193 stmtStack.append(StmtLooper(s, prev, knownTrue)) | |
194 prev = stmtStack[-1] | |
195 except NoOptions: | 212 except NoOptions: |
196 log.debug(f'{INDENT*5} start up with no options; 0 bindings') | 213 log.debug(f'{INDENT*5} start up with no options; 0 bindings') |
197 return | 214 return |
198 self._debugStmtStack('initial odometer', stmtStack) | 215 self._debugStmtStack('initial odometer', stmtStack) |
199 | 216 self._assertAllRingsAreValid(stmtStack) |
200 | 217 |
201 if any(ring.pastEnd() for ring in stmtStack): | 218 lastRing = stmtStack[-1] |
202 log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}') | |
203 | |
204 raise NoOptions() | |
205 sl = stmtStack[-1] | |
206 iterCount = 0 | 219 iterCount = 0 |
207 while True: | 220 while True: |
208 iterCount += 1 | 221 iterCount += 1 |
209 if iterCount > 10: | 222 if iterCount > 10: |
210 raise ValueError('stuck') | 223 raise ValueError('stuck') |
211 | 224 |
212 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') | 225 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') |
213 | 226 |
214 yield BoundLhs(self, sl.currentBinding()) | 227 yield BoundLhs(self, lastRing.currentBinding()) |
215 | 228 |
216 self._debugStmtStack('odometer', stmtStack) | 229 self._debugStmtStack('odometer', stmtStack) |
217 | 230 |
218 done = self._advanceAll(stmtStack) | 231 done = self._advanceAll(stmtStack) |
219 | 232 |
225 | 238 |
226 def _debugStmtStack(self, label, stmtStack): | 239 def _debugStmtStack(self, label, stmtStack): |
227 log.debug(f'{INDENT*5} {label}:') | 240 log.debug(f'{INDENT*5} {label}:') |
228 for l in stmtStack: | 241 for l in stmtStack: |
229 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') | 242 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') |
243 | |
244 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: | |
245 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all | |
246 start out valid (or else raise NoOptions)""" | |
247 | |
248 stmtsToAdd = list(self.graph) | |
249 | |
250 for perm in itertools.permutations(stmtsToAdd): | |
251 stmtStack: List[StmtLooper] = [] | |
252 prev: Optional[StmtLooper] = None | |
253 log.debug(f'{INDENT*5} try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') | |
254 | |
255 for s in perm: | |
256 try: | |
257 elem = StmtLooper(s, prev, knownTrue) | |
258 except NoOptions: | |
259 log.debug(f'{INDENT*6} permutation didnt work, try another') | |
260 break | |
261 stmtStack.append(elem) | |
262 prev = stmtStack[-1] | |
263 else: | |
264 return stmtStack | |
265 log.debug(f'{INDENT*6} no perms worked- rule cannot match anything') | |
266 | |
267 raise NoOptions() | |
230 | 268 |
231 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: | 269 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: |
232 carry = True # 1st elem always must advance | 270 carry = True # 1st elem always must advance |
233 for i, ring in enumerate(stmtStack): | 271 for i, ring in enumerate(stmtStack): |
234 # unlike normal odometer, advancing any earlier ring could invalidate later ones | 272 # unlike normal odometer, advancing any earlier ring could invalidate later ones |
242 return True | 280 return True |
243 log.debug(f'{INDENT*5} advanceAll [{i}] {ring} restart') | 281 log.debug(f'{INDENT*5} advanceAll [{i}] {ring} restart') |
244 ring.restart() | 282 ring.restart() |
245 carry = True | 283 carry = True |
246 return False | 284 return False |
285 | |
286 def _assertAllRingsAreValid(self, stmtStack): | |
287 if any(ring.pastEnd() for ring in stmtStack): # this is an unexpected debug assertion | |
288 log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}') | |
289 raise NoOptions() | |
247 | 290 |
248 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: | 291 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: |
249 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's | 292 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's |
250 # static stmt is matched by a non-static stmt in the rule itself | 293 # static stmt is matched by a non-static stmt in the rule itself |
251 for ruleStmt in self.staticRuleStmts: | 294 for ruleStmt in self.staticRuleStmts: |