Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1631:2c85a4f5dd9c
big rewrite of infer() using statements not variables as the things to iterate over
author | drewp@bigasterisk.com |
---|---|
date | Sun, 12 Sep 2021 04:32:52 -0700 |
parents | ea559a846714 |
children | bd79a2941cab |
comparison
equal
deleted
inserted
replaced
1630:b3132cd02686 | 1631:2c85a4f5dd9c |
---|---|
5 import itertools | 5 import itertools |
6 import logging | 6 import logging |
7 import time | 7 import time |
8 from collections import defaultdict | 8 from collections import defaultdict |
9 from dataclasses import dataclass | 9 from dataclasses import dataclass |
10 from typing import Dict, Iterator, List, Set, Tuple, Union, cast | 10 from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast |
11 | 11 |
12 from prometheus_client import Summary, Histogram | 12 from prometheus_client import Histogram, Summary |
13 from rdflib import BNode, Graph, Namespace, URIRef | 13 from rdflib import BNode, Graph, Namespace |
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate | 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate |
15 from rdflib.term import Node, Variable | 15 from rdflib.term import Literal, Node, Variable |
16 | 16 |
17 from candidate_binding import CandidateBinding | 17 from candidate_binding import CandidateBinding |
18 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) | 18 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) |
19 from lhs_evaluation import Evaluation | 19 from lhs_evaluation import Decimal, Evaluation, numericNode |
20 | 20 |
21 log = logging.getLogger('infer') | 21 log = logging.getLogger('infer') |
22 INDENT = ' ' | 22 INDENT = ' ' |
23 | 23 |
24 INFER_CALLS = Summary('inference_infer_calls', 'calls') | 24 INFER_CALLS = Summary('inference_infer_calls', 'calls') |
27 ROOM = Namespace("http://projects.bigasterisk.com/room/") | 27 ROOM = Namespace("http://projects.bigasterisk.com/room/") |
28 LOG = Namespace('http://www.w3.org/2000/10/swap/log#') | 28 LOG = Namespace('http://www.w3.org/2000/10/swap/log#') |
29 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') | 29 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') |
30 | 30 |
31 | 31 |
32 def stmtTemplate(stmt: Triple) -> Tuple[Optional[Node], Optional[Node], Optional[Node]]: | |
33 return ( | |
34 None if isinstance(stmt[0], (Variable, BNode)) else stmt[0], | |
35 None if isinstance(stmt[1], (Variable, BNode)) else stmt[1], | |
36 None if isinstance(stmt[2], (Variable, BNode)) else stmt[2], | |
37 ) | |
38 | |
39 | |
40 class NoOptions(ValueError): | |
41 """stmtlooper has no possibilities to add to the binding; the whole rule must therefore not apply""" | |
42 | |
43 | |
44 class Inconsistent(ValueError): | |
45 """adding this stmt would be inconsistent with an existing binding""" | |
46 | |
47 | |
48 @dataclass | |
49 class StmtLooper: | |
50 lhsStmt: Triple | |
51 prev: Optional['StmtLooper'] | |
52 workingSet: ReadOnlyWorkingSet | |
53 | |
54 def __repr__(self): | |
55 return f'StmtLooper({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' | |
56 | |
57 def __post_init__(self): | |
58 self._myWorkingSetMatches = self._myMatches(self.workingSet) | |
59 | |
60 self._current = CandidateBinding({}) | |
61 self._pastEnd = False | |
62 self._seenBindings: List[Dict[BindableTerm, Node]] = [] | |
63 self.restart() | |
64 | |
65 def _myMatches(self, g: Graph) -> List[Triple]: | |
66 template = stmtTemplate(self.lhsStmt) | |
67 | |
68 stmts = sorted(cast(Iterator[Triple], list(g.triples(template)))) | |
69 # plus new lhs possibilities... | |
70 # log.debug(f'{INDENT*6} {self} find {len(stmts)=} in {len(self.workingSet)=}') | |
71 | |
72 return stmts | |
73 | |
74 def _prevBindings(self) -> Dict[BindableTerm, Node]: | |
75 if not self.prev or self.prev.pastEnd(): | |
76 return {} | |
77 | |
78 return self.prev.currentBinding().binding | |
79 | |
80 def advance(self): | |
81 """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode""" | |
82 log.debug(f'{INDENT*6} {self} mines {len(self._myWorkingSetMatches)} matching statements') | |
83 for i, stmt in enumerate(self._myWorkingSetMatches): | |
84 try: | |
85 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) | |
86 except Inconsistent: | |
87 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') | |
88 continue | |
89 log.debug(f'seen {outBinding.binding} in {self._seenBindings}') | |
90 if outBinding.binding not in self._seenBindings: | |
91 self._seenBindings.append(outBinding.binding.copy()) | |
92 log.debug(f'no, adding') | |
93 self._current = outBinding | |
94 log.debug(f'{INDENT*7} {self} - Looper matches {stmt} which tells us {outBinding}') | |
95 return | |
96 log.debug(f'yes we saw') | |
97 | |
98 log.debug(f'{INDENT*6} {self} mines rules') | |
99 | |
100 if self.lhsStmt[1] == ROOM['asFarenheit']: | |
101 pb: Dict[BindableTerm, Node] = self._prevBindings() | |
102 if self.lhsStmt[0] in pb: | |
103 operands = [pb[cast(BindableTerm, self.lhsStmt[0])]] | |
104 f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) | |
105 objVar = self.lhsStmt[2] | |
106 if not isinstance(objVar, Variable): | |
107 raise TypeError(f'expected Variable, got {objVar!r}') | |
108 newBindings = {cast(BindableTerm, objVar): cast(Node, f)} | |
109 self._current.addNewBindings(CandidateBinding(newBindings)) | |
110 if newBindings not in self._seenBindings: | |
111 self._seenBindings.append(newBindings) | |
112 self._current = CandidateBinding(newBindings) | |
113 | |
114 log.debug(f'{INDENT*6} {self} is past end') | |
115 self._pastEnd = True | |
116 | |
117 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: | |
118 outBinding = self._prevBindings().copy() | |
119 for rt, ct in zip(self.lhsStmt, newStmt): | |
120 if isinstance(rt, (Variable, BNode)): | |
121 if rt in outBinding and outBinding[rt] != ct: | |
122 raise Inconsistent() | |
123 outBinding[rt] = ct | |
124 return CandidateBinding(outBinding) | |
125 | |
126 def currentBinding(self) -> CandidateBinding: | |
127 if self.pastEnd(): | |
128 raise NotImplementedError() | |
129 return self._current | |
130 | |
131 def newLhsStmts(self) -> List[Triple]: | |
132 """under the current bindings, what new stmts beyond workingSet are also true? includes all `prev`""" | |
133 return [] | |
134 | |
135 def pastEnd(self) -> bool: | |
136 return self._pastEnd | |
137 | |
138 def restart(self): | |
139 self._pastEnd = False | |
140 self._seenBindings = [] | |
141 self.advance() | |
142 if self.pastEnd(): | |
143 raise NoOptions() | |
144 | |
145 | |
32 @dataclass | 146 @dataclass |
33 class Lhs: | 147 class Lhs: |
34 graph: Graph | 148 graph: Graph |
35 | 149 |
36 def __post_init__(self): | 150 def __post_init__(self): |
37 # do precomputation in here that's not specific to the workingSet | 151 # do precomputation in here that's not specific to the workingSet |
38 self.staticRuleStmts = Graph() | 152 # self.staticRuleStmts = Graph() |
39 self.nonStaticRuleStmts = Graph() | 153 # self.nonStaticRuleStmts = Graph() |
40 | 154 |
41 self.lhsBindables: Set[BindableTerm] = set() | 155 # self.lhsBindables: Set[BindableTerm] = set() |
42 self.lhsBnodes: Set[BNode] = set() | 156 # self.lhsBnodes: Set[BNode] = set() |
43 for ruleStmt in self.graph: | 157 # for ruleStmt in self.graph: |
44 varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))] | 158 # varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))] |
45 self.lhsBindables.update(varsAndBnodesInStmt) | 159 # self.lhsBindables.update(varsAndBnodesInStmt) |
46 self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode)) | 160 # self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode)) |
47 if not varsAndBnodesInStmt: | 161 # if not varsAndBnodesInStmt: |
48 self.staticRuleStmts.add(ruleStmt) | 162 # self.staticRuleStmts.add(ruleStmt) |
49 else: | 163 # else: |
50 self.nonStaticRuleStmts.add(ruleStmt) | 164 # self.nonStaticRuleStmts.add(ruleStmt) |
51 | 165 |
52 self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts) | 166 # self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts) |
53 | 167 |
54 self.evaluations = list(Evaluation.findEvals(self.graph)) | 168 self.evaluations = list(Evaluation.findEvals(self.graph)) |
55 | 169 |
56 def __repr__(self): | 170 def __repr__(self): |
57 return f"Lhs({graphDump(self.graph)})" | 171 return f"Lhs({graphDump(self.graph)})" |
58 | 172 |
59 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: | 173 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: |
60 """bindings that fit the LHS of a rule, using statements from workingSet and functions | 174 """bindings that fit the LHS of a rule, using statements from workingSet and functions |
61 from LHS""" | 175 from LHS""" |
62 log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') | 176 log.debug(f'{INDENT*4} build new StmtLooper stack') |
63 stats['findCandidateBindingsCalls'] += 1 | 177 |
64 | 178 stmtStack: List[StmtLooper] = [] |
65 if not self._allStaticStatementsMatch(knownTrue): | 179 try: |
66 stats['findCandidateBindingEarlyExits'] += 1 | 180 prev: Optional[StmtLooper] = None |
181 for s in sorted(self.graph): # order of this matters! :( | |
182 stmtStack.append(StmtLooper(s, prev, knownTrue)) | |
183 prev = stmtStack[-1] | |
184 except NoOptions: | |
185 log.debug(f'{INDENT*5} no options; 0 bindings') | |
67 return | 186 return |
68 | 187 |
69 for binding in self._possibleBindings(knownTrue, stats): | 188 log.debug(f'{INDENT*5} initial odometer:') |
70 log.debug('') | 189 for l in stmtStack: |
71 log.debug(f'{INDENT*4}*trying {binding.binding}') | 190 log.debug(f'{INDENT*6} {l}') |
72 | 191 |
73 if not binding.verify(knownTrue): | 192 if any(ring.pastEnd() for ring in stmtStack): |
74 log.debug(f'{INDENT*4} this binding did not verify') | 193 log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}') |
75 stats['permCountFailingVerify'] += 1 | 194 |
76 continue | 195 raise NoOptions() |
77 | 196 sl = stmtStack[-1] |
78 stats['permCountSucceeding'] += 1 | 197 iterCount = 0 |
79 yield binding | 198 while True: |
199 iterCount += 1 | |
200 if iterCount > 10: | |
201 raise ValueError('stuck') | |
202 | |
203 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') | |
204 | |
205 log.debug(f'{INDENT*5} <<<') | |
206 yield BoundLhs(self, sl.currentBinding()) | |
207 log.debug(f'{INDENT*5} >>>') | |
208 | |
209 log.debug(f'{INDENT*5} odometer:') | |
210 for l in stmtStack: | |
211 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') | |
212 | |
213 done = self._advanceAll(stmtStack) | |
214 | |
215 log.debug(f'{INDENT*5} odometer after ({done=}):') | |
216 for l in stmtStack: | |
217 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') | |
218 | |
219 log.debug(f'{INDENT*4} ^^ findCandBindings iteration done') | |
220 if done: | |
221 break | |
222 | |
223 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: | |
224 carry = True # 1st elem always must advance | |
225 for i, ring in enumerate(stmtStack): | |
226 # unlike normal odometer, advancing any earlier ring could invalidate later ones | |
227 if carry: | |
228 log.debug(f'{INDENT*5} advanceAll [{i}] {ring} carry/advance') | |
229 ring.advance() | |
230 carry = False | |
231 if ring.pastEnd(): | |
232 if ring is stmtStack[-1]: | |
233 log.debug(f'{INDENT*5} advanceAll [{i}] {ring} says we done') | |
234 return True | |
235 log.debug(f'{INDENT*5} advanceAll [{i}] {ring} restart') | |
236 ring.restart() | |
237 carry = True | |
238 return False | |
80 | 239 |
81 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: | 240 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: |
82 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's | 241 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's |
83 # static stmt is matched by a non-static stmt in the rule itself | 242 # static stmt is matched by a non-static stmt in the rule itself |
84 for ruleStmt in self.staticRuleStmts: | 243 for ruleStmt in self.staticRuleStmts: |
160 log.debug(f'{INDENT*4} {v!r} could be:') | 319 log.debug(f'{INDENT*4} {v!r} could be:') |
161 for val in vals: | 320 for val in vals: |
162 log.debug(f'{INDENT*5}{val!r}') | 321 log.debug(f'{INDENT*5}{val!r}') |
163 | 322 |
164 | 323 |
165 # @dataclass | |
166 # class CandidateTermMatches: | |
167 # """lazily find the possible matches for this term""" | |
168 # terms: List[BindableTerm] | |
169 # lhs: Lhs | |
170 # knownTrue: Graph | |
171 # boundSoFar: CandidateBinding | |
172 | |
173 # def __post_init__(self): | |
174 # self.results: List[Node] = [] # we have to be able to repeat the results | |
175 | |
176 # res: Set[Node] = set() | |
177 # for trueStmt in self.knownTrue: # all bound | |
178 # lStmts = list(self.lhsStmtsContainingTerm()) | |
179 # log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}') | |
180 # for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False): | |
181 # log.debug(f'{INDENT*4} {pat=}') | |
182 # implied = self._stmtImplies(pat, trueStmt) | |
183 # if implied is not None: | |
184 # res.add(implied) | |
185 # self.results = list(res) | |
186 # # self.results.sort() | |
187 | |
188 # log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}') | |
189 | |
190 # def lhsStmtsContainingTerm(self): | |
191 # # lhs could precompute this | |
192 # for lhsStmt in self.lhs.graph: | |
193 # if self.term in lhsStmt: | |
194 # yield lhsStmt | |
195 | |
196 # def __iter__(self): | |
197 # return iter(self.results) | |
198 | |
199 | |
200 @dataclass | 324 @dataclass |
201 class BoundLhs: | 325 class BoundLhs: |
202 lhs: Lhs | 326 lhs: Lhs |
203 binding: CandidateBinding | 327 binding: CandidateBinding |
204 | 328 |
205 def __post_init__(self): | 329 def __post_init__(self): |
206 self.usedByFuncs = Graph() | 330 self.usedByFuncs = Graph() |
207 self._applyFunctions() | 331 # self._applyFunctions() |
208 | 332 |
209 def lhsStmtsWithoutEvals(self): | 333 def lhsStmtsWithoutEvals(self): |
210 for stmt in self.lhs.graph: | 334 for stmt in self.lhs.graph: |
211 if stmt in self.usedByFuncs: | 335 if stmt in self.usedByFuncs: |
212 continue | 336 continue |
261 | 385 |
262 @dataclass | 386 @dataclass |
263 class Rule: | 387 class Rule: |
264 lhsGraph: Graph | 388 lhsGraph: Graph |
265 rhsGraph: Graph | 389 rhsGraph: Graph |
266 | 390 |
267 def __post_init__(self): | 391 def __post_init__(self): |
268 self.lhs = Lhs(self.lhsGraph) | 392 self.lhs = Lhs(self.lhsGraph) |
393 # | |
394 self.rhsBnodeMap = {} | |
269 | 395 |
270 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): | 396 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): |
271 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): | 397 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): |
272 log.debug(f'{INDENT*3} rule has a working binding:') | 398 log.debug(f'{INDENT*5} +rule has a working binding: {bound}') |
273 | 399 |
274 for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()): | 400 # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do |
275 log.debug(f'{INDENT*4} adding {lhsBoundStmt=}') | 401 existingRhsBnodes = set() |
276 workingSet.add(lhsBoundStmt) | 402 for stmt in self.rhsGraph: |
403 for t in stmt: | |
404 if isinstance(t, BNode): | |
405 existingRhsBnodes.add(t) | |
406 # if existingRhsBnodes: | |
407 # log.debug(f'{INDENT*6} mapping rhs bnodes {existingRhsBnodes} to new ones') | |
408 | |
409 for b in existingRhsBnodes: | |
410 | |
411 key = tuple(sorted(bound.binding.binding.items())), b | |
412 self.rhsBnodeMap.setdefault(key, BNode()) | |
413 | |
414 | |
415 bound.binding.addNewBindings(CandidateBinding({b: self.rhsBnodeMap[key]})) | |
416 | |
417 # for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()): | |
418 # log.debug(f'{INDENT*6} adding to workingSet {lhsBoundStmt=}') | |
419 # workingSet.add(lhsBoundStmt) | |
420 # log.debug(f'{INDENT*6} rhsGraph is good: {list(self.rhsGraph)}') | |
421 | |
277 for newStmt in bound.binding.apply(self.rhsGraph): | 422 for newStmt in bound.binding.apply(self.rhsGraph): |
278 log.debug(f'{INDENT*4} adding {newStmt=}') | 423 # log.debug(f'{INDENT*6} adding {newStmt=}') |
279 workingSet.add(newStmt) | 424 workingSet.add(newStmt) |
280 implied.add(newStmt) | 425 implied.add(newStmt) |
281 | 426 |
282 | 427 |
283 class Inference: | 428 class Inference: |
348 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') | 493 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') |
349 | 494 |
350 | 495 |
351 def graphDump(g: Union[Graph, List[Triple]]): | 496 def graphDump(g: Union[Graph, List[Triple]]): |
352 if not isinstance(g, Graph): | 497 if not isinstance(g, Graph): |
353 log.warning(f"it's a {type(g)}") | |
354 g2 = Graph() | 498 g2 = Graph() |
355 g2 += g | 499 g2 += g |
356 g = g2 | 500 g = g2 |
357 g.bind('', ROOM) | 501 g.bind('', ROOM) |
358 g.bind('ex', Namespace('http://example.com/')) | 502 g.bind('ex', Namespace('http://example.com/')) |