comparison service/mqtt_to_rdf/inference.py @ 1651:20474ad4968e

WIP - functions are broken as I move most layers to work in Chunks, not Triples. A Chunk is a Triple plus any RDF lists.
author drewp@bigasterisk.com
date Sat, 18 Sep 2021 23:57:20 -0700
parents 2061df259224
children dddfa09ea0b9
comparing 1650:2061df259224 (old, shown first on each line) with 1651:20474ad4968e (new, shown second)
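
For orientation while reading the hunks below: matching is moving from raw triples to Chunks. The real Chunk class lives in the new stmt_chunk.py module, which is not part of this comparison, so the sketch here only pictures the idea from the commit message (one triple plus the items of any rdf:List it points at); every field name in it is an assumption, not the real API.

# Illustration only; the actual Chunk is defined in stmt_chunk.py, outside this diff.
from dataclasses import dataclass
from typing import List, Optional, Tuple

from rdflib.term import Node


@dataclass
class ChunkSketch:
    """One statement plus the expanded items of any rdf:List it references."""
    primary: Tuple[Node, Node, Node]        # the (subject, predicate, object) triple
    subjList: Optional[List[Node]] = None   # list items, if the subject is an rdf:List head
    objList: Optional[List[Node]] = None    # list items, if the object is an rdf:List head

    @property
    def predicate(self) -> Node:
        return self.primary[1]
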
5 import itertools 5 import itertools
6 import logging 6 import logging
7 import time 7 import time
8 from collections import defaultdict 8 from collections import defaultdict
9 from dataclasses import dataclass 9 from dataclasses import dataclass
10 from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast) 10 from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Union, cast
11 11
12 from prometheus_client import Histogram, Summary 12 from prometheus_client import Histogram, Summary
13 from rdflib import RDF, BNode, Graph, Literal, Namespace 13 from rdflib import RDF, BNode, Graph, Namespace
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate 14 from rdflib.graph import ConjunctiveGraph
15 from rdflib.term import Node, URIRef, Variable 15 from rdflib.term import Node, URIRef, Variable
16 16
17 from candidate_binding import BindingConflict, CandidateBinding 17 from candidate_binding import BindingConflict, CandidateBinding
18 from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple 18 from inference_types import BindingUnknown, Inconsistent, Triple
19 from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs, rulePredicates 19 from lhs_evaluation import functionsFor
20 from rdf_debug import graphDump 20 from rdf_debug import graphDump
21 from stmt_chunk import Chunk, ChunkedGraph, applyChunky
21 22
22 log = logging.getLogger('infer') 23 log = logging.getLogger('infer')
23 INDENT = ' ' 24 INDENT = ' '
24 25
25 INFER_CALLS = Summary('inference_infer_calls', 'calls') 26 INFER_CALLS = Summary('inference_infer_calls', 'calls')
28 ROOM = Namespace("http://projects.bigasterisk.com/room/") 29 ROOM = Namespace("http://projects.bigasterisk.com/room/")
29 LOG = Namespace('http://www.w3.org/2000/10/swap/log#') 30 LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
30 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') 31 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
31 32
32 33
33 def stmtTemplate(stmt: Triple) -> Tuple[Optional[Node], Optional[Node], Optional[Node]]:
34 return (
35 None if isinstance(stmt[0], (Variable, BNode)) else stmt[0],
36 None if isinstance(stmt[1], (Variable, BNode)) else stmt[1],
37 None if isinstance(stmt[2], (Variable, BNode)) else stmt[2],
38 )
39
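
The deleted stmtTemplate helper above builds the wildcard form that rdflib's Graph.triples() expects: concrete terms are kept, while Variables and BNodes become None. A small usage sketch (the namespace and statements are invented):

from rdflib import Graph, Namespace, Variable

EX = Namespace('http://example.com/')
g = Graph()
g.add((EX.light1, EX.state, EX.on))
g.add((EX.light2, EX.state, EX.off))

lhsStmt = (Variable('light'), EX.state, EX.on)
template = stmtTemplate(lhsStmt)        # (None, EX.state, EX.on)
matches = list(g.triples(template))     # [(EX.light1, EX.state, EX.on)]
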
40
41 class NoOptions(ValueError): 34 class NoOptions(ValueError):
42 """stmtlooper has no possibilites to add to the binding; the whole rule must therefore not apply""" 35 """ChunkLooper has no possibilites to add to the binding; the whole rule must therefore not apply"""
43 36
44 37
45 class Inconsistent(ValueError): 38 _chunkLooperShortId = itertools.count()
46 """adding this stmt would be inconsistent with an existing binding"""
47
48
49 _stmtLooperShortId = itertools.count()
50 39
51 40
52 @dataclass 41 @dataclass
53 class StmtLooper: 42 class ChunkLooper:
54 """given one LHS stmt, iterate through the possible matches for it, 43 """given one LHS Chunk, iterate through the possible matches for it,
55 returning what bindings they would imply. Only distinct bindings are 44 returning what bindings they would imply. Only distinct bindings are
56 returned. The bindings build on any `prev` StmtLooper's results. 45 returned. The bindings build on any `prev` ChunkLooper's results.
57 46
58 This iterator is restartable.""" 47 This iterator is restartable."""
59 lhsStmt: Triple 48 lhsChunk: Chunk
60 prev: Optional['StmtLooper'] 49 prev: Optional['ChunkLooper']
61 workingSet: ReadOnlyWorkingSet 50 workingSet: 'ChunkedGraph'
62 parent: 'Lhs' # just for lhs.graph, really 51 parent: 'Lhs' # just for lhs.graph, really
63 52
64 def __repr__(self): 53 def __repr__(self):
65 return f'StmtLooper{self._shortId}{"<pastEnd>" if self.pastEnd() else ""})' 54 return f'{self.__class__.__name__}{self._shortId}{"<pastEnd>" if self.pastEnd() else ""}'
66 55
67 def __post_init__(self): 56 def __post_init__(self):
68 self._shortId = next(_stmtLooperShortId) 57 self._shortId = next(_chunkLooperShortId)
69 self._myWorkingSetMatches = self._myMatches(self.workingSet) 58 self._myWorkingSetMatches = self.lhsChunk.myMatches(self.workingSet)
70 59
71 self._current = CandidateBinding({}) 60 self._current = CandidateBinding({})
72 self._pastEnd = False 61 self._pastEnd = False
73 self._seenBindings: List[CandidateBinding] = [] 62 self._seenBindings: List[CandidateBinding] = []
74 63
75 log.debug(f'introducing {self!r}({graphDump([self.lhsStmt])})') 64 log.debug(f'{INDENT*6} introducing {self!r}({self.lhsChunk}, {self._myWorkingSetMatches=})')
76 65
77 self.restart() 66 self.restart()
78
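
A ChunkLooper is driven from outside through the methods that follow (currentBinding, advance, pastEnd, restart). A minimal hand-driven sketch, assuming lhsChunk, chunkedWorkingSet, and lhs objects already exist; real callers build these in _assembleRings below:

# Construction raises NoOptions if the chunk can never match, so callers catch that.
looper = ChunkLooper(lhsChunk, prev=None, workingSet=chunkedWorkingSet, parent=lhs)
while not looper.pastEnd():
    print(looper.currentBinding())      # one distinct CandidateBinding per step
    looper.advance()
looper.restart()                        # rewind; the iterator is restartable
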
79 def _myMatches(self, g: Graph) -> List[Triple]:
80 template = stmtTemplate(self.lhsStmt)
81
82 stmts = sorted(cast(Iterator[Triple], list(g.triples(template))))
83 # plus new lhs possibilties...
84 # log.debug(f'{INDENT*6} {self} find {len(stmts)=} in {len(self.workingSet)=}')
85
86 return stmts
87 67
88 def _prevBindings(self) -> CandidateBinding: 68 def _prevBindings(self) -> CandidateBinding:
89 if not self.prev or self.prev.pastEnd(): 69 if not self.prev or self.prev.pastEnd():
90 return CandidateBinding({}) 70 return CandidateBinding({})
91 71
94 def advance(self): 74 def advance(self):
95 """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode""" 75 """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode"""
96 if self._pastEnd: 76 if self._pastEnd:
97 raise NotImplementedError('need restart') 77 raise NotImplementedError('need restart')
98 log.debug('') 78 log.debug('')
99 augmentedWorkingSet: Sequence[Triple] = [] 79 augmentedWorkingSet: Sequence[Chunk] = []
100 if self.prev is None: 80 if self.prev is None:
101 augmentedWorkingSet = self._myWorkingSetMatches 81 augmentedWorkingSet = self._myWorkingSetMatches
102 else: 82 else:
103 augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches, 83 augmentedWorkingSet = list(
104 returnBoundStatementsOnly=False)) 84 applyChunky(self.prev.currentBinding(), self._myWorkingSetMatches, returnBoundStatementsOnly=False))
105 85
106 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}') 86 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}')
107 87
108 if self._advanceWithPlainMatches(augmentedWorkingSet): 88 if self._advanceWithPlainMatches(augmentedWorkingSet):
109 return 89 return
112 return 92 return
113 93
114 log.debug(f'{INDENT*6} {self} is past end') 94 log.debug(f'{INDENT*6} {self} is past end')
115 self._pastEnd = True 95 self._pastEnd = True
116 96
117 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool: 97 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Chunk]) -> bool:
118 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements') 98 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
119 for s in augmentedWorkingSet: 99 for s in augmentedWorkingSet:
120 log.debug(f'{INDENT*7} {s}') 100 log.debug(f'{INDENT*7} {s}')
121 101
122 for stmt in augmentedWorkingSet: 102 for chunk in augmentedWorkingSet:
123 try: 103 try:
124 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) 104 outBinding = self.lhsChunk.totalBindingIfThisStmtWereTrue(self._prevBindings(), chunk)
125 except Inconsistent: 105 except Inconsistent:
126 log.debug(f'{INDENT*7} StmtLooper{self._shortId} - {stmt} would be inconsistent with prev bindings') 106 log.debug(f'{INDENT*7} ChunkLooper{self._shortId} - {chunk} would be inconsistent with prev bindings')
127 continue 107 continue
128 108
129 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') 109 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
130 if outBinding not in self._seenBindings: 110 if outBinding not in self._seenBindings:
131 self._seenBindings.append(outBinding.copy()) 111 self._seenBindings.append(outBinding.copy())
133 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') 113 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}')
134 return True 114 return True
135 return False 115 return False
136 116
137 def _advanceWithFunctions(self) -> bool: 117 def _advanceWithFunctions(self) -> bool:
138 pred: Node = self.lhsStmt[1] 118 pred: Node = self.lhsChunk.predicate
139 if not isinstance(pred, URIRef): 119 if not isinstance(pred, URIRef):
140 raise NotImplementedError 120 raise NotImplementedError
141 121
142 for functionType in functionsFor(pred): 122 for functionType in functionsFor(pred):
143 fn = functionType(self.lhsStmt, self.parent.graph) 123 fn = functionType(self.lhsChunk, self.parent.graph)
144 try: 124 try:
145 out = fn.bind(self._prevBindings()) 125 out = fn.bind(self._prevBindings())
146 except BindingUnknown: 126 except BindingUnknown:
147 pass 127 pass
148 else: 128 else:
166 boundOperands.append(pb.applyTerm(op)) 146 boundOperands.append(pb.applyTerm(op))
167 else: 147 else:
168 boundOperands.append(op) 148 boundOperands.append(op)
169 return boundOperands 149 return boundOperands
170 150
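
For the function path above, a chunk whose predicate is a builtin (e.g. something in the math: namespace) is not matched against the working set; once earlier loopers have bound its operands it just computes one more binding. A toy math:sum-style illustration, standing in for the real lhs_evaluation code:

from rdflib import Literal, Variable

prevBindings = {Variable('x'): Literal(2), Variable('y'): Literal(3)}
operands = [Variable('x'), Variable('y')]                  # e.g. the ( ?x ?y ) subject list
boundOperands = [prevBindings.get(op, op) for op in operands]
total = Literal(sum(o.toPython() for o in boundOperands))  # Literal(5), which would bind ?result
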
171 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
172 outBinding = self._prevBindings().copy()
173 for rt, ct in zip(self.lhsStmt, newStmt):
174 if isinstance(rt, (Variable, BNode)):
175 if outBinding.contains(rt) and outBinding.applyTerm(rt) != ct:
176 msg = f'{rt=} {ct=} {outBinding=}' if log.isEnabledFor(logging.DEBUG) else ''
177 raise Inconsistent(msg)
178 outBinding.addNewBindings(CandidateBinding({rt: ct}))
179 return outBinding
180
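
The deleted _totalBindingIfThisStmtWereTrue above now lives on Chunk (it is called as lhsChunk.totalBindingIfThisStmtWereTrue in _advanceWithPlainMatches). The rule it enforces: a Variable (or BNode) may take a value from the candidate statement only if that value does not contradict an existing binding, otherwise Inconsistent is raised. A standalone illustration with a plain dict standing in for CandidateBinding:

from rdflib import Namespace, Variable

from inference_types import Inconsistent    # same exception this file now imports

EX = Namespace('http://example.com/')
binding = {Variable('x'): EX.a}                       # carried over from earlier loopers
lhsPattern = (Variable('x'), EX.p, Variable('y'))
candidate = (EX.b, EX.p, EX.c)                        # would require ?x = :b

for patternTerm, concreteTerm in zip(lhsPattern, candidate):
    if isinstance(patternTerm, Variable):
        if patternTerm in binding and binding[patternTerm] != concreteTerm:
            raise Inconsistent(f'{patternTerm} is already bound to {binding[patternTerm]}')
        binding[patternTerm] = concreteTerm
# Raises Inconsistent on the first term: ?x is already :a and cannot also be :b.
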
181 def currentBinding(self) -> CandidateBinding: 151 def currentBinding(self) -> CandidateBinding:
182 if self.pastEnd(): 152 if self.pastEnd():
183 raise NotImplementedError() 153 raise NotImplementedError()
184 return self._current 154 return self._current
185 155
194 raise NoOptions() 164 raise NoOptions()
195 165
196 166
197 @dataclass 167 @dataclass
198 class Lhs: 168 class Lhs:
199 graph: Graph # our full LHS graph, as input. See below for the statements partitioned into groups. 169 graph: ChunkedGraph # our full LHS graph, as input. See below for the statements partitioned into groups.
200 170
201 def __post_init__(self): 171 def __post_init__(self):
202 172
203 usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) 173 self.myPreds = self.graph.allPredicatesExceptFunctions()
204
205 stmtsToMatch = list(self.graph - usedByFuncs)
206 self.staticStmts = []
207 self.patternStmts = []
208 for st in stmtsToMatch:
209 if all(isinstance(term, (URIRef, Literal)) for term in st):
210 self.staticStmts.append(st)
211 else:
212 self.patternStmts.append(st)
213
214 # sort them by variable dependencies; don't just try all perms!
215 def lightSortKey(stmt): # Not this.
216 (s, p, o) = stmt
217 return p in rulePredicates(), p, s, o
218
219 self.patternStmts.sort(key=lightSortKey)
220
221 self.myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
222 self.myPreds -= rulePredicates()
223 self.myPreds -= {RDF.first, RDF.rest}
224 self.myPreds = set(self.myPreds)
225 174
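
The bookkeeping the deleted __post_init__ did by hand (splitting static from pattern statements, collecting predicates) has moved into ChunkedGraph, which is defined in stmt_chunk.py and not shown in this comparison. Inferred only from the call sites in this file, its surface looks roughly like the Protocol below; treat it as a guess, not the real definition:

from typing import Iterable, Protocol, Set

from rdflib.term import Node

from stmt_chunk import Chunk    # the new module this changeset introduces


class ChunkedGraphLike(Protocol):
    # attribute and method names taken from how this file uses them; container types are guesses
    staticChunks: Set[Chunk]        # fully-concrete chunks; must simply be present in knownTrue
    patternChunks: Set[Chunk]       # chunks that still contain Variables/BNodes

    def allPredicatesExceptFunctions(self) -> Set[Node]: ...
    def noPredicatesAppear(self, preds: Iterable[Node]) -> bool: ...
    def __bool__(self) -> bool: ...                 # the "if not self.graph" empty-LHS check
    def __contains__(self, chunk: Chunk) -> bool: ...
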
226 def __repr__(self): 175 def __repr__(self):
227 return f"Lhs({graphDump(self.graph)})" 176 return f"Lhs({self.graph!r})"
228 177
229 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']: 178 def findCandidateBindings(self, knownTrue: ChunkedGraph, stats, ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
230 """bindings that fit the LHS of a rule, using statements from workingSet and functions 179 """bindings that fit the LHS of a rule, using statements from workingSet and functions
231 from LHS""" 180 from LHS"""
232 if self.graph.__len__() == 0: 181 if not self.graph:
233 # special case- no LHS! 182 # special case- no LHS!
234 yield BoundLhs(self, CandidateBinding({})) 183 yield BoundLhs(self, CandidateBinding({}))
235 return 184 return
236 185
237 if self._checkPredicateCounts(knownTrue): 186 if self._checkPredicateCounts(knownTrue):
238 stats['_checkPredicateCountsCulls'] += 1 187 stats['_checkPredicateCountsCulls'] += 1
239 return 188 return
240 189
241 if not all(st in knownTrue for st in self.staticStmts): 190 if not all(ch in knownTrue for ch in self.graph.staticChunks):
242 stats['staticStmtCulls'] += 1 191 stats['staticStmtCulls'] += 1
243 return 192 return
244 193
245 if len(self.patternStmts) == 0: 194 if not self.graph.patternChunks:
246 # static only 195 # static only
247 yield BoundLhs(self, CandidateBinding({})) 196 yield BoundLhs(self, CandidateBinding({}))
248 return 197 return
249 198
250 log.debug(f'{INDENT*4} build new StmtLooper stack') 199 log.debug(f'{INDENT*4} build new ChunkLooper stack')
251 200
252 try: 201 try:
253 stmtStack = self._assembleRings(knownTrue, stats) 202 chunkStack = self._assembleRings(knownTrue, stats)
254 except NoOptions: 203 except NoOptions:
255 log.debug(f'{INDENT*5} start up with no options; 0 bindings') 204 log.debug(f'{INDENT*5} start up with no options; 0 bindings')
256 return 205 return
257 self._debugStmtStack('initial odometer', stmtStack) 206 self._debugChunkStack('initial odometer', chunkStack)
258 self._assertAllRingsAreValid(stmtStack) 207 self._assertAllRingsAreValid(chunkStack)
259 208
260 lastRing = stmtStack[-1] 209 lastRing = chunkStack[-1]
261 iterCount = 0 210 iterCount = 0
262 while True: 211 while True:
263 iterCount += 1 212 iterCount += 1
264 if iterCount > ruleStatementsIterationLimit: 213 if iterCount > ruleStatementsIterationLimit:
265 raise ValueError('rule too complex') 214 raise ValueError('rule too complex')
266 215
267 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') 216 log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')
268 217
269 yield BoundLhs(self, lastRing.currentBinding()) 218 yield BoundLhs(self, lastRing.currentBinding())
270 219
271 self._debugStmtStack('odometer', stmtStack) 220 self._debugChunkStack('odometer', chunkStack)
272 221
273 done = self._advanceAll(stmtStack) 222 done = self._advanceAll(chunkStack)
274 223
275 self._debugStmtStack('odometer after ({done=})', stmtStack) 224 self._debugChunkStack(f'odometer after ({done=})', chunkStack)
276 225
277 log.debug(f'{INDENT*4} ^^ findCandBindings iteration done') 226 log.debug(f'{INDENT*4} ^^ findCandBindings iteration done')
278 if done: 227 if done:
279 break 228 break
280 229
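
The debug labels call the ChunkLooper stack an 'odometer', and the loop above behaves like one: yield the last ring's binding, then roll the rings until the carry falls off the end. A toy version of the plain odometer pattern follows; the real loop is trickier because advancing an earlier ring can invalidate later rings (their candidate matches depend on the earlier bindings), which is what _advanceAll below has to handle.

def odometer(rings):
    """toy analogue: plain iterables stand in for ChunkLoopers; every ring must be non-empty"""
    positions = [iter(r) for r in rings]
    current = [next(p) for p in positions]           # like _assembleRings: all rings start valid
    while True:
        yield tuple(current)
        for i, ring in enumerate(rings):
            try:
                current[i] = next(positions[i])      # this ring advanced; no carry needed
                break
            except StopIteration:
                positions[i] = iter(ring)            # restart this ring...
                current[i] = next(positions[i])      # ...and carry into the next one
        else:
            return                                   # carry fell off the last ring; all done

print(list(odometer(['ab', 'xy'])))
# [('a', 'x'), ('b', 'x'), ('a', 'y'), ('b', 'y')]
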
281 def _debugStmtStack(self, label, stmtStack): 230 def _debugChunkStack(self, label: str, chunkStack: List[ChunkLooper]):
282 log.debug(f'{INDENT*5} {label}:') 231 log.debug(f'{INDENT*5} {label}:')
283 for l in stmtStack: 232 for l in chunkStack:
284 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') 233 log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')
285 234
286 def _checkPredicateCounts(self, knownTrue): 235 def _checkPredicateCounts(self, knownTrue):
287 """raise NoOptions quickly in some cases""" 236 """raise NoOptions quickly in some cases"""
288 237
289 if any((None, p, None) not in knownTrue for p in self.myPreds): 238 if self.graph.noPredicatesAppear(self.myPreds):
239 log.info(f'{INDENT*2} checkPredicateCounts does cull because not all {self.myPreds=} are in knownTrue')
290 return True 240 return True
291 log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue') 241 log.info(f'{INDENT*2} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue')
292 return False 242 return False
293 243
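
The old-side check reads directly as "does every predicate my LHS needs occur at least once in knownTrue?"; the new side hands the same question to ChunkedGraph.noPredicatesAppear. Spelled out with a tiny made-up working set (rdflib graphs accept None wildcards in membership tests):

from rdflib import Graph, Namespace

ROOM = Namespace("http://projects.bigasterisk.com/room/")   # as defined near the top of this file
knownTrue = Graph()
knownTrue.add((ROOM.button1, ROOM.buttonState, ROOM.pressed))

myPreds = {ROOM.buttonState, ROOM.brightness}      # hypothetical predicates some rule needs
cull = any((None, p, None) not in knownTrue for p in myPreds)
# True: nothing in knownTrue uses ROOM.brightness, so the whole rule is culled early.
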
294 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet, stats) -> List[StmtLooper]: 244 def _assembleRings(self, knownTrue: ChunkedGraph, stats) -> List[ChunkLooper]:
295 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all 245 """make ChunkLooper for each stmt in our LHS graph, but do it in a way that they all
296 start out valid (or else raise NoOptions)""" 246 start out valid (or else raise NoOptions)"""
297 247
298 log.info(f'{INDENT*2} stats={dict(stats)}') 248 log.info(f'{INDENT*2} stats={dict(stats)}')
299 log.info(f'{INDENT*2} taking permutations of {len(self.patternStmts)=}') 249 log.info(f'{INDENT*2} taking permutations of {len(self.graph.patternChunks)=}')
300 for i, perm in enumerate(itertools.permutations(self.patternStmts)): 250 for i, perm in enumerate(itertools.permutations(self.graph.patternChunks)):
301 stmtStack: List[StmtLooper] = [] 251 stmtStack: List[ChunkLooper] = []
302 prev: Optional[StmtLooper] = None 252 prev: Optional[ChunkLooper] = None
303 if log.isEnabledFor(logging.DEBUG): 253 if log.isEnabledFor(logging.DEBUG):
304 log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') 254 log.debug(f'{INDENT*5} [perm {i}] try stmts in this order: {" -> ".join(repr(p) for p in perm)}')
305 255
306 for s in perm: 256 for s in perm:
307 try: 257 try:
308 elem = StmtLooper(s, prev, knownTrue, parent=self) 258 elem = ChunkLooper(s, prev, knownTrue, parent=self)
309 except NoOptions: 259 except NoOptions:
310 log.debug(f'{INDENT*6} permutation didnt work, try another') 260 log.debug(f'{INDENT*6} permutation didnt work, try another')
311 break 261 break
312 stmtStack.append(elem) 262 stmtStack.append(elem)
313 prev = stmtStack[-1] 263 prev = stmtStack[-1]
314 else: 264 else:
315 return stmtStack 265 return stmtStack
316 if i > 5000: 266 if i > 5000:
317 raise NotImplementedError(f'trying too many permutations {len(self.patternStmts)=}') 267 raise NotImplementedError(f'trying too many permutations {len(self.graph.patternChunks)=}')
318 268
319 log.debug(f'{INDENT*6} no perms worked- rule cannot match anything') 269 log.debug(f'{INDENT*6} no perms worked- rule cannot match anything')
320 raise NoOptions() 270 raise NoOptions()
321 271
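
The 5000-permutation cap above bites quickly, since the number of orderings grows factorially with the number of pattern chunks in the LHS:

import math

print([math.factorial(n) for n in range(1, 9)])
# [1, 2, 6, 24, 120, 720, 5040, 40320]; seven pattern chunks already give 5040 orderings
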
322 def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: 272 def _advanceAll(self, stmtStack: List[ChunkLooper]) -> bool:
323 carry = True # 1st elem always must advance 273 carry = True # 1st elem always must advance
324 for i, ring in enumerate(stmtStack): 274 for i, ring in enumerate(stmtStack):
325 # unlike normal odometer, advancing any earlier ring could invalidate later ones 275 # unlike normal odometer, advancing any earlier ring could invalidate later ones
326 if carry: 276 if carry:
327 log.debug(f'{INDENT*5} advanceAll [{i}] {ring} carry/advance') 277 log.debug(f'{INDENT*5} advanceAll [{i}] {ring} carry/advance')
352 class Rule: 302 class Rule:
353 lhsGraph: Graph 303 lhsGraph: Graph
354 rhsGraph: Graph 304 rhsGraph: Graph
355 305
356 def __post_init__(self): 306 def __post_init__(self):
357 self.lhs = Lhs(self.lhsGraph) 307 self.lhs = Lhs(ChunkedGraph(self.lhsGraph, functionsFor))
358 # 308 #
359 self.rhsBnodeMap = {} 309 self.rhsBnodeMap = {}
360 310
361 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit): 311 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, ruleStatementsIterationLimit):
362 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats, ruleStatementsIterationLimit): 312 # this does not change for the current applyRule call. The rule will be
313 # tried again in an outer loop, in case it can produce more.
314 workingSetChunked = ChunkedGraph(workingSet, functionsFor)
315
316 for bound in self.lhs.findCandidateBindings(workingSetChunked, stats, ruleStatementsIterationLimit):
363 log.debug(f'{INDENT*5} +rule has a working binding: {bound}') 317 log.debug(f'{INDENT*5} +rule has a working binding: {bound}')
364 318
365 # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do 319 # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do
366 existingRhsBnodes = set() 320 existingRhsBnodes = set()
367 for stmt in self.rhsGraph: 321 for stmt in self.rhsGraph:
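
As the comment in applyRule notes, the working set is re-chunked on every call because the rule is retried by an outer loop (which lives elsewhere in inference.py). A minimal driving sketch; the graphs, stats counter, and the 5000 limit are placeholders:

from collections import defaultdict

from rdflib import Graph

workingSet = Graph()     # facts accumulated so far (placeholder)
implied = Graph()        # statements newly implied by rules (placeholder)
stats = defaultdict(int)

rule = Rule(lhsGraph, rhsGraph)    # graphs come from wherever the rules are parsed, not shown here
rule.applyRule(workingSet, implied, stats, ruleStatementsIterationLimit=5000)
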