comparison service/mqtt_to_rdf/inference.py @ 1631:2c85a4f5dd9c

big rewrite of infer() using statements not variables as the things to iterate over
author drewp@bigasterisk.com
date Sun, 12 Sep 2021 04:32:52 -0700
parents ea559a846714
children bd79a2941cab
comparison
equal deleted inserted replaced
1630:b3132cd02686 1631:2c85a4f5dd9c
5 import itertools 5 import itertools
6 import logging 6 import logging
7 import time 7 import time
8 from collections import defaultdict 8 from collections import defaultdict
9 from dataclasses import dataclass 9 from dataclasses import dataclass
10 from typing import Dict, Iterator, List, Set, Tuple, Union, cast 10 from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast
11 11
12 from prometheus_client import Summary, Histogram 12 from prometheus_client import Histogram, Summary
13 from rdflib import BNode, Graph, Namespace, URIRef 13 from rdflib import BNode, Graph, Namespace
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
15 from rdflib.term import Node, Variable 15 from rdflib.term import Literal, Node, Variable
16 16
17 from candidate_binding import CandidateBinding 17 from candidate_binding import CandidateBinding
18 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) 18 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
19 from lhs_evaluation import Evaluation 19 from lhs_evaluation import Decimal, Evaluation, numericNode
20 20
# module logger and the per-level indent used in its debug output
log = logging.getLogger('infer')
INDENT = ' '

# prometheus metric
INFER_CALLS = Summary('inference_infer_calls', 'calls')

# RDF namespaces referenced by rules
ROOM = Namespace("http://projects.bigasterisk.com/room/")
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
30 30
31 31
def stmtTemplate(stmt: Triple) -> Tuple[Optional[Node], Optional[Node], Optional[Node]]:
    """Turn a rule statement into a pattern for Graph.triples().

    Variables and BNodes become None (wildcards); every other term is kept as-is.
    """

    def asPatternTerm(term):
        if isinstance(term, (Variable, BNode)):
            return None
        return term

    subj, pred, obj = stmt[0], stmt[1], stmt[2]
    return (asPatternTerm(subj), asPatternTerm(pred), asPatternTerm(obj))
38
39
class NoOptions(ValueError):
    """Raised when a StmtLooper has no possibilities to add to the binding, so the whole rule cannot apply."""
42
43
class Inconsistent(ValueError):
    """Raised when adding a statement would contradict a binding that already exists."""
46
47
@dataclass
class StmtLooper:
    """One "ring" of the odometer that Lhs.findCandidateBindings runs over a rule's LHS.

    A looper owns a single LHS statement (`lhsStmt`) and steps through the distinct
    bindings under which that statement would be true, given the current binding of
    the `prev` looper. Every binding it produces contains prev's current binding plus
    this statement's terms, so the last looper in a chain holds the combined binding.

    Candidate bindings come from statements in `workingSet` that match `lhsStmt`, plus
    (currently only for the room:asFarenheit predicate) a value computed from prev's
    bindings.
    """
    lhsStmt: Triple
    prev: Optional['StmtLooper']
    workingSet: ReadOnlyWorkingSet

    def __repr__(self):
        return f'StmtLooper({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'

    def __post_init__(self):
        # Matching statements are found once, here; restart() reuses this list.
        self._myWorkingSetMatches = self._myMatches(self.workingSet)

        self._current = CandidateBinding({})
        self._pastEnd = False
        # Copies of every binding produced since the last restart(), so advance() never repeats one.
        self._seenBindings: List[Dict[BindableTerm, Node]] = []
        # Raises NoOptions when there is no binding at all under prev's current binding.
        self.restart()

    def _myMatches(self, g: Graph) -> List[Triple]:
        """Return the statements of `g` that fit lhsStmt's pattern, sorted for a stable order.

        Variables and BNodes in lhsStmt act as wildcards; repeated variables are checked
        later, in _totalBindingIfThisStmtWereTrue.
        """
        template = stmtTemplate(self.lhsStmt)
        # TODO: plus new lhs possibilities...
        return sorted(cast(Iterator[Triple], g.triples(template)))

    def _prevBindings(self) -> Dict[BindableTerm, Node]:
        """prev's current binding dict, or {} if there is no prev or it is past its end."""
        if not self.prev or self.prev.pastEnd():
            return {}

        return self.prev.currentBinding().binding

    def advance(self):
        """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode"""
        # Logging uses lazy %-args: the reprs call graphDump and the seen-list can be long,
        # so they should only be built when debug output is actually enabled.
        log.debug('%s %s mines %d matching statements', INDENT * 6, self, len(self._myWorkingSetMatches))
        for stmt in self._myWorkingSetMatches:
            try:
                outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
            except Inconsistent:
                log.debug('%s %s - %s would be inconsistent with prev bindings', INDENT * 7, self, stmt)
                continue
            log.debug('seen %s in %s', outBinding.binding, self._seenBindings)
            if self._acceptIfUnseen(outBinding):
                log.debug('no, adding')
                log.debug('%s %s - Looper matches %s which tells us %s', INDENT * 7, self, stmt, outBinding)
                return
            log.debug('yes we saw')

        log.debug('%s %s mines rules', INDENT * 6, self)
        ruleBinding = self._bindingFromRules()
        if ruleBinding is not None and self._acceptIfUnseen(ruleBinding):
            # A rule-derived binding is a real option: stop here instead of going pastEnd.
            log.debug('%s %s - rule tells us %s', INDENT * 7, self, ruleBinding)
            return

        log.debug('%s %s is past end', INDENT * 6, self)
        self._pastEnd = True

    def _acceptIfUnseen(self, binding: CandidateBinding) -> bool:
        """If `binding` is new since the last restart, record it, make it current, and return True."""
        if binding.binding in self._seenBindings:
            return False
        self._seenBindings.append(binding.binding.copy())
        self._current = binding
        return True

    def _bindingFromRules(self) -> Optional[CandidateBinding]:
        """Binding implied by a built-in predicate on lhsStmt, or None if none applies.

        Only room:asFarenheit is handled: for `?c room:asFarenheit ?f` with ?c already
        bound by prev, ?f is bound to ?c * 9 / 5 + 32. Like _totalBindingIfThisStmtWereTrue,
        the result includes prev's bindings. It is None if ?f is already bound to a
        different value.

        Raises TypeError if the object of an asFarenheit statement is not a Variable.
        """
        if self.lhsStmt[1] != ROOM['asFarenheit']:
            return None
        pb = self._prevBindings()
        subj = self.lhsStmt[0]
        if subj not in pb:
            return None
        f = Literal(Decimal(numericNode(pb[cast(BindableTerm, subj)])) * 9 / 5 + 32)
        objVar = self.lhsStmt[2]
        if not isinstance(objVar, Variable):
            raise TypeError(f'expected Variable, got {objVar!r}')
        if objVar in pb and pb[objVar] != f:
            return None
        outBinding = dict(pb)
        outBinding[objVar] = f
        return CandidateBinding(outBinding)

    def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
        """prev's bindings plus what lhsStmt's Variables/BNodes would bind to if lhsStmt matched newStmt.

        Raises Inconsistent if a term would need two different values, either because prev
        already bound it or because it repeats within lhsStmt.
        """
        outBinding = self._prevBindings().copy()
        for rt, ct in zip(self.lhsStmt, newStmt):
            if isinstance(rt, (Variable, BNode)):
                if rt in outBinding and outBinding[rt] != ct:
                    raise Inconsistent()
                outBinding[rt] = ct
        return CandidateBinding(outBinding)

    def currentBinding(self) -> CandidateBinding:
        """The binding this looper is currently on; raises NotImplementedError once pastEnd."""
        if self.pastEnd():
            raise NotImplementedError()
        return self._current

    def newLhsStmts(self) -> List[Triple]:
        """under the current bindings, what new stmts beyond workingSet are also true? includes all `prev`

        Not implemented yet: always returns [].
        """
        return []

    def pastEnd(self) -> bool:
        """True once advance() has run out of unseen bindings (until the next restart)."""
        return self._pastEnd

    def restart(self):
        """Forget the seen bindings and move to the first binding; raise NoOptions if there is none."""
        self._pastEnd = False
        self._seenBindings = []
        self.advance()
        if self.pastEnd():
            raise NoOptions()
144
145
32 @dataclass 146 @dataclass
33 class Lhs: 147 class Lhs:
34 graph: Graph 148 graph: Graph
35 149
def __post_init__(self):
    # do precomputation in here that's not specific to the workingSet
    #
    # NOTE(review): this revision stopped computing staticRuleStmts, nonStaticRuleStmts,
    # nonStaticRuleStmtsSet, lhsBindables and lhsBnodes. _allStaticStatementsMatch still
    # reads self.staticRuleStmts, so calling it now raises AttributeError.
    self.evaluations = [*Evaluation.findEvals(self.graph)]
55 169
def __repr__(self):
    # e.g. "Lhs(<graphDump of the lhs graph>)"
    rendered = graphDump(self.graph)
    return "Lhs({})".format(rendered)
58 172
def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
    """bindings that fit the LHS of a rule, using statements from workingSet and functions
    from LHS

    One StmtLooper ("ring") is built per LHS statement, each chained to the one before it,
    and the rings are stepped like an odometer by _advanceAll. For each combination, the
    last ring's binding (which includes every earlier ring's bindings) is yielded.

    An empty LHS yields exactly one BoundLhs with an empty binding.
    `stats` is accepted for compatibility with callers but is not updated.
    """
    log.debug('%s build new StmtLooper stack', INDENT * 4)

    stmtStack: List[StmtLooper] = []
    try:
        prev: Optional[StmtLooper] = None
        for s in sorted(self.graph):  # order of this matters! :(
            stmtStack.append(StmtLooper(s, prev, knownTrue))
            prev = stmtStack[-1]
    except NoOptions:
        log.debug('%s no options; 0 bindings', INDENT * 5)
        return

    if not stmtStack:
        # An empty LHS constrains nothing, so it matches exactly once and binds nothing.
        log.debug('%s empty lhs; 1 empty binding', INDENT * 5)
        yield BoundLhs(self, CandidateBinding({}))
        return

    def logOdometer(title: str) -> None:
        # Skip entirely when debug is off: each ring's repr calls graphDump.
        if not log.isEnabledFor(logging.DEBUG):
            return
        log.debug('%s %s', INDENT * 5, title)
        for ring in stmtStack:
            log.debug('%s %s curbind=%s', INDENT * 6, ring,
                      '<end>' if ring.pastEnd() else ring.currentBinding())

    logOdometer('initial odometer:')

    if any(ring.pastEnd() for ring in stmtStack):
        # StmtLooper() raises NoOptions rather than starting pastEnd, so this shouldn't happen.
        # If it does, yield nothing (like the NoOptions case above) instead of raising
        # NoOptions out of this generator into applyRule, which doesn't catch it.
        log.debug('%s some rings started at pastEnd %s', INDENT * 5, stmtStack)
        return

    lastRing = stmtStack[-1]
    # NOTE(review): debugging guard kept from the original. A rule with more than this
    # many candidate bindings raises ValueError('stuck').
    maxIterations = 10
    iterCount = 0
    while True:
        iterCount += 1
        if iterCount > maxIterations:
            raise ValueError('stuck')

        log.debug('%s vv findCandBindings iteration %d', INDENT * 4, iterCount)

        log.debug('%s <<<', INDENT * 5)
        yield BoundLhs(self, lastRing.currentBinding())
        log.debug('%s >>>', INDENT * 5)

        logOdometer('odometer:')
        done = self._advanceAll(stmtStack)
        logOdometer(f'odometer after (done={done}):')

        log.debug('%s ^^ findCandBindings iteration done', INDENT * 4)
        if done:
            break
222
def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
    """Step the ring odometer to its next combination of bindings.

    Each ring computes its bindings from its `prev` ring's current binding, so the LAST
    ring is the fastest-moving digit. We advance the last ring. When a ring runs past its
    end, we back up and advance the ring before it. Whenever a ring lands on a new
    binding, every later ring is restarted under it. A later ring with no options under
    that binding (restart raises NoOptions and leaves it pastEnd) sends us back up again.

    Returns True when all combinations are exhausted (or there are no rings), False when
    the rings hold a fresh combination, readable from the last ring's currentBinding().
    """
    if not stmtStack:
        return True
    lastIndex = len(stmtStack) - 1
    i = lastIndex
    log.debug('%s advanceAll [%d] %s advance', INDENT * 5, i, stmtStack[i])
    stmtStack[i].advance()
    while True:
        ring = stmtStack[i]
        if ring.pastEnd():
            if i == 0:
                log.debug('%s advanceAll [%d] %s says we done', INDENT * 5, i, ring)
                return True
            i -= 1
            log.debug('%s advanceAll [%d] %s advance', INDENT * 5, i, stmtStack[i])
            stmtStack[i].advance()
        elif i == lastIndex:
            return False
        else:
            i += 1
            log.debug('%s advanceAll [%d] %s restart', INDENT * 5, i, stmtStack[i])
            try:
                stmtStack[i].restart()
            except NoOptions:
                # restart() left this ring pastEnd; the next pass backs up past it.
                log.debug('%s advanceAll [%d] %s has no options here', INDENT * 5, i, stmtStack[i])
80 239
81 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: 240 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool:
82 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's 241 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's
83 # static stmt is matched by a non-static stmt in the rule itself 242 # static stmt is matched by a non-static stmt in the rule itself
84 for ruleStmt in self.staticRuleStmts: 243 for ruleStmt in self.staticRuleStmts:
160 log.debug(f'{INDENT*4} {v!r} could be:') 319 log.debug(f'{INDENT*4} {v!r} could be:')
161 for val in vals: 320 for val in vals:
162 log.debug(f'{INDENT*5}{val!r}') 321 log.debug(f'{INDENT*5}{val!r}')
163 322
164 323
165 # @dataclass
166 # class CandidateTermMatches:
167 # """lazily find the possible matches for this term"""
168 # terms: List[BindableTerm]
169 # lhs: Lhs
170 # knownTrue: Graph
171 # boundSoFar: CandidateBinding
172
173 # def __post_init__(self):
174 # self.results: List[Node] = [] # we have to be able to repeat the results
175
176 # res: Set[Node] = set()
177 # for trueStmt in self.knownTrue: # all bound
178 # lStmts = list(self.lhsStmtsContainingTerm())
179 # log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}')
180 # for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False):
181 # log.debug(f'{INDENT*4} {pat=}')
182 # implied = self._stmtImplies(pat, trueStmt)
183 # if implied is not None:
184 # res.add(implied)
185 # self.results = list(res)
186 # # self.results.sort()
187
188 # log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}')
189
190 # def lhsStmtsContainingTerm(self):
191 # # lhs could precompute this
192 # for lhsStmt in self.lhs.graph:
193 # if self.term in lhsStmt:
194 # yield lhsStmt
195
196 # def __iter__(self):
197 # return iter(self.results)
198
199
200 @dataclass 324 @dataclass
201 class BoundLhs: 325 class BoundLhs:
202 lhs: Lhs 326 lhs: Lhs
203 binding: CandidateBinding 327 binding: CandidateBinding
204 328
def __post_init__(self):
    # usedByFuncs: LHS statements consumed by functions; lhsStmtsWithoutEvals skips these.
    # NOTE(review): the self._applyFunctions() call is disabled in this revision, so nothing
    # is added here at construction and lhsStmtsWithoutEvals' `stmt in self.usedByFuncs`
    # check won't skip anything unless something else fills it. Confirm this is intended.
    self.usedByFuncs: Graph = Graph()
208 332
209 def lhsStmtsWithoutEvals(self): 333 def lhsStmtsWithoutEvals(self):
210 for stmt in self.lhs.graph: 334 for stmt in self.lhs.graph:
211 if stmt in self.usedByFuncs: 335 if stmt in self.usedByFuncs:
212 continue 336 continue
261 385
@dataclass
class Rule:
    """A rule `lhsGraph => rhsGraph`: wherever the LHS matches the working set, add the bound RHS."""
    lhsGraph: Graph
    rhsGraph: Graph

    def __post_init__(self):
        self.lhs = Lhs(self.lhsGraph)
        # (sorted LHS binding items, RHS bnode) -> fresh BNode. Firing again with the same
        # LHS binding reuses the same fresh bnodes instead of minting new ones each time.
        self.rhsBnodeMap: Dict[Tuple[tuple, BNode], BNode] = {}

    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict):
        """Fire this rule for every LHS binding found in workingSet.

        Each RHS statement, with the binding applied, is added to both workingSet and
        implied. BNodes in the RHS are replaced by fresh BNodes, one set per distinct LHS
        binding (tracked in rhsBnodeMap).
        """
        # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do.
        # The loop below only adds to workingSet/implied, so collect these once, not per binding.
        rhsBnodes = {t for stmt in self.rhsGraph for t in stmt if isinstance(t, BNode)}

        for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats):
            log.debug('%s +rule has a working binding: %s', INDENT * 5, bound)

            # Key on the LHS binding alone, computed before any rhs bnodes are added to it.
            lhsKey = tuple(sorted(bound.binding.binding.items()))
            # Extend a copy: bound.binding can be the LHS search's own looper state.
            binding = CandidateBinding(dict(bound.binding.binding))
            for b in rhsBnodes:
                fresh = self.rhsBnodeMap.setdefault((lhsKey, b), BNode())
                binding.addNewBindings(CandidateBinding({b: fresh}))

            for newStmt in binding.apply(self.rhsGraph):
                workingSet.add(newStmt)
                implied.add(newStmt)
281 426
282 427
283 class Inference: 428 class Inference:
348 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') 493 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
349 494
350 495
351 def graphDump(g: Union[Graph, List[Triple]]): 496 def graphDump(g: Union[Graph, List[Triple]]):
352 if not isinstance(g, Graph): 497 if not isinstance(g, Graph):
353 log.warning(f"it's a {type(g)}")
354 g2 = Graph() 498 g2 = Graph()
355 g2 += g 499 g2 += g
356 g = g2 500 g = g2
357 g.bind('', ROOM) 501 g.bind('', ROOM)
358 g.bind('ex', Namespace('http://example.com/')) 502 g.bind('ex', Namespace('http://example.com/'))