comparison service/mqtt_to_rdf/inference/inference.py @ 1727:23e6154e6c11

file moves
author drewp@bigasterisk.com
date Tue, 20 Jun 2023 23:26:24 -0700
parents service/mqtt_to_rdf/inference.py@88f6e9bf69d1
children
comparison
equal deleted inserted replaced
1726:7d3797ed6681 1727:23e6154e6c11
1 """
2 copied from reasoning 2021-08-29. probably same api. should
3 be able to lib/ this out
4 """
5 import itertools
6 import logging
7 import time
8 from collections import defaultdict
9 from dataclasses import dataclass
10 from pathlib import Path
11 from typing import Dict, Iterator, List, Optional, Tuple, Union, cast
12
13 from prometheus_client import Histogram, Summary
14 from rdflib import Graph, Namespace
15 from rdflib.graph import ConjunctiveGraph
16 from rdflib.term import Node, URIRef
17
18 from inference.candidate_binding import CandidateBinding
19 from inference.inference_types import (BindingUnknown, Inconsistent, RhsBnode, RuleUnboundBnode, Triple, WorkingSetBnode)
20 from inference.lhs_evaluation import functionsFor
21 from inference.rdf_debug import graphDump
22 from inference.stmt_chunk import (AlignedRuleChunk, Chunk, ChunkedGraph, applyChunky)
23 from inference.structured_log import StructuredLog
24
25 log = logging.getLogger('infer')
26 odolog = logging.getLogger('infer.odo') # the "odometer" logic
27 ringlog = logging.getLogger('infer.ring') # for ChunkLooper
28
29 INDENT = ' '
30
31 INFER_CALLS = Summary('inference_infer_calls', 'calls')
32 INFER_GRAPH_SIZE = Histogram('inference_graph_size', 'statements', buckets=[2**x for x in range(2, 20, 2)])
33
34 ROOM = Namespace("http://projects.bigasterisk.com/room/")
35 LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
36 MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
37
38
class NoOptions(ValueError):
    """ChunkLooper has no possibilities to add to the binding; the whole rule must therefore not apply."""
41
42
def debug(logger, slog: Optional[StructuredLog], msg):
    """Emit msg on the given stdlib logger at DEBUG level, mirroring it into
    the structured log when one is active."""
    logger.debug(msg)
    if not slog:
        return
    slog.say(msg)
47
48
# Shared counter handing out small integer ids so ChunkLooper instances can be
# told apart in log output (see ChunkLooper.__repr__).
_chunkLooperShortId = itertools.count()
50
51
@dataclass
class ChunkLooper:
    """given one LHS Chunk, iterate through the possible matches for it,
    returning what bindings they would imply. Only distinct bindings are
    returned. The bindings build on any `prev` ChunkLooper's results.

    In the odometer metaphor used below, this is one of the rings.

    This iterator is restartable."""
    lhsChunk: Chunk
    prev: Optional['ChunkLooper']
    workingSet: 'ChunkedGraph'
    slog: Optional[StructuredLog]

    def __repr__(self):
        return f'{self.__class__.__name__}{self._shortId}{"<pastEnd>" if self.pastEnd() else ""}'

    def __post_init__(self):
        self._shortId = next(_chunkLooperShortId)
        # All workingSet chunks that could match our lhsChunk, computed once up front.
        self._alignedMatches = list(self.lhsChunk.ruleMatchesFrom(self.workingSet))
        # The graph itself is no longer needed; drop it so it can't change under us.
        del self.workingSet

        # only ours- do not store prev, since it could change without us
        self._current = CandidateBinding({})
        self.currentSourceChunk: Optional[Chunk] = None  # for debugging only
        self._pastEnd = False
        self._seenBindings: List[CandidateBinding] = []  # combined bindings (up to our ring) that we've returned

        if ringlog.isEnabledFor(logging.DEBUG):
            ringlog.debug('')
            msg = f'{INDENT*6} introducing {self!r}({self.lhsChunk}, {self._alignedMatches=})'
            msg = msg.replace('AlignedRuleChunk', f'\n{INDENT*12}AlignedRuleChunk')
            ringlog.debug(msg)

        self.restart()

    def _prevBindings(self) -> CandidateBinding:
        """Bindings contributed by all earlier rings; empty if we're the first ring or prev is exhausted."""
        if not self.prev or self.prev.pastEnd():
            return CandidateBinding({})

        return self.prev.currentBinding()

    def advance(self):
        """update _current to a new set of valid bindings we haven't seen (since
        last restart), or go into pastEnd mode. Note that _current is just our
        contribution, but returned valid bindings include all prev rings."""
        if self._pastEnd:
            raise NotImplementedError('need restart')
        ringlog.debug('')
        debug(ringlog, self.slog, f'{INDENT*6} --> {self}.advance start:')

        self._currentIsFromFunc = None
        augmentedWorkingSet: List[AlignedRuleChunk] = []
        if self.prev is None:
            augmentedWorkingSet = self._alignedMatches
        else:
            # Narrow the precomputed matches to those consistent with prev's current binding.
            augmentedWorkingSet = list(applyChunky(self.prev.currentBinding(), self._alignedMatches))

        # Plain statement matches are tried first; function (e.g. math:) evaluation second.
        if self._advanceWithPlainMatches(augmentedWorkingSet):
            debug(ringlog, self.slog, f'{INDENT*6} <-- {self}.advance finished with plain matches')
            return

        if self._advanceWithFunctions():
            debug(ringlog, self.slog, f'{INDENT*6} <-- {self}.advance finished with function matches')
            return

        debug(ringlog, self.slog, f'{INDENT*6} <-- {self}.advance had nothing and is now past end')
        self._pastEnd = True

    def _advanceWithPlainMatches(self, augmentedWorkingSet: List[AlignedRuleChunk]) -> bool:
        """Try each candidate statement match; keep the first that yields an unseen combined binding."""
        # if augmentedWorkingSet:
        #     debug(ringlog, self.slog, f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
        # for s in augmentedWorkingSet:
        #     debug(ringlog, self.slog, f'{INDENT*8} {s}')

        for aligned in augmentedWorkingSet:
            try:
                newBinding = aligned.newBindingIfMatched(self._prevBindings())
            except Inconsistent as exc:
                debug(ringlog, self.slog, f'{INDENT*7} ChunkLooper{self._shortId} - {aligned} would be inconsistent with prev bindings ({exc})')
                continue

            if self._testAndKeepNewBinding(newBinding, aligned.workingSetChunk):
                return True
        return False

    def _advanceWithFunctions(self) -> bool:
        """Try the function implementations registered for our chunk's predicate."""
        pred: Node = self.lhsChunk.predicate
        if not isinstance(pred, URIRef):
            raise NotImplementedError

        for functionType in functionsFor(pred):
            fn = functionType(self.lhsChunk)
            # debug(ringlog, self.slog, f'{INDENT*7} ChunkLooper{self._shortId} advanceWithFunctions, {functionType=}')

            try:
                log.debug(f'fn.bind {self._prevBindings()} ...')
                #fullBinding = self._prevBindings().copy()
                newBinding = fn.bind(self._prevBindings())
                log.debug(f'...makes {newBinding=}')
            except BindingUnknown:
                # Not enough variables bound yet for this function to evaluate; skip it.
                pass
            else:
                if newBinding is not None:
                    self._currentIsFromFunc = fn
                    if self._testAndKeepNewBinding(newBinding, self.lhsChunk):
                        return True

        return False

    def _testAndKeepNewBinding(self, newBinding: CandidateBinding, sourceChunk: Chunk):
        """Accept newBinding iff the combined (prev + new) binding hasn't been returned since the last restart."""
        fullBinding: CandidateBinding = self._prevBindings().copy()
        fullBinding.addNewBindings(newBinding)
        isNew = fullBinding not in self._seenBindings

        if ringlog.isEnabledFor(logging.DEBUG):
            ringlog.debug(f'{INDENT*7} {self} considering {newBinding=} to make {fullBinding}. {isNew=}')
        # if self.slog:
        #     self.slog.looperConsider(self, newBinding, fullBinding, isNew)

        if isNew:
            self._seenBindings.append(fullBinding.copy())
            self._current = newBinding
            self.currentSourceChunk = sourceChunk
            return True
        return False

    def localBinding(self) -> CandidateBinding:
        """Just this ring's own contribution to the current binding."""
        if self.pastEnd():
            raise NotImplementedError()
        return self._current

    def currentBinding(self) -> CandidateBinding:
        """This ring's contribution merged with everything from prev rings."""
        if self.pastEnd():
            raise NotImplementedError()
        together = self._prevBindings().copy()
        together.addNewBindings(self._current)
        return together

    def pastEnd(self) -> bool:
        # True once advance() has run out of new bindings (until restart()).
        return self._pastEnd

    def restart(self):
        """Rewind this ring and advance to its first binding; raise NoOptions if it has none at all."""
        try:
            self._pastEnd = False
            self._seenBindings = []
            self.advance()
            if self.pastEnd():
                raise NoOptions()
        finally:
            debug(ringlog, self.slog, f'{INDENT*7} ChunkLooper{self._shortId} restarts: pastEnd={self.pastEnd()}')
203
204
@dataclass
class Lhs:
    """The left-hand side of one rule, prepared for candidate-binding search."""
    graph: ChunkedGraph  # our full LHS graph, as input. See below for the statements partitioned into groups.

    def __post_init__(self):

        # Predicates we require in the working set (function predicates excluded).
        self.myPreds = self.graph.allPredicatesExceptFunctions()

    def __repr__(self):
        return f"Lhs({self.graph!r})"

    def findCandidateBindings(self, knownTrue: ChunkedGraph, stats, slog: Optional[StructuredLog], ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
        """distinct bindings that fit the LHS of a rule, using statements from
        workingSet and functions from LHS

        Raises ValueError if more than ruleStatementsIterationLimit bindings
        are attempted.
        """
        if not self.graph:
            # special case- no LHS!
            yield BoundLhs(self, CandidateBinding({}))
            return

        # Cheap culls before building any ChunkLoopers:
        if self._checkPredicateCounts(knownTrue):
            stats['_checkPredicateCountsCulls'] += 1
            return

        if not all(ch in knownTrue for ch in self.graph.staticChunks):
            stats['staticStmtCulls'] += 1
            return
        # After this point we don't need to consider self.graph.staticChunks.

        if not self.graph.patternChunks and not self.graph.chunksUsedByFuncs:
            # static only
            yield BoundLhs(self, CandidateBinding({}))
            return

        log.debug('')
        try:
            chunkStack = self._assembleRings(knownTrue, stats, slog)
        except NoOptions:
            ringlog.debug(f'{INDENT*5} start up with no options; 0 bindings')
            return
        log.debug('')
        log.debug('')
        self._debugChunkStack('time to spin: initial odometer is', chunkStack)

        if slog:
            slog.say('time to spin')
            slog.odometer(chunkStack)
        self._assertAllRingsAreValid(chunkStack)

        # Spin the odometer: the last ring's currentBinding() includes all rings.
        lastRing = chunkStack[-1]
        iterCount = 0
        while True:
            iterCount += 1
            if iterCount > ruleStatementsIterationLimit:
                raise ValueError('rule too complex')

            log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')

            yield BoundLhs(self, lastRing.currentBinding())

            # self._debugChunkStack('odometer', chunkStack)

            done = self._advanceTheStack(chunkStack)

            self._debugChunkStack(f'odometer after ({done=})', chunkStack)
            if slog:
                slog.odometer(chunkStack)

            log.debug(f'{INDENT*4} ^^ findCandBindings iteration done')
            if done:
                break

    def _debugChunkStack(self, label: str, chunkStack: List[ChunkLooper]):
        # Dump each ring's local binding (debug aid only).
        odolog.debug(f'{INDENT*4} {label}:')
        for i, l in enumerate(chunkStack):
            odolog.debug(f'{INDENT*5} [{i}] {l} curbind={l.localBinding() if not l.pastEnd() else "<end>"}')

    def _checkPredicateCounts(self, knownTrue):
        """Return True to cull this rule quickly, when knownTrue lacks predicates our LHS needs."""

        if self.graph.noPredicatesAppear(self.myPreds):
            log.debug(f'{INDENT*3} checkPredicateCounts does cull because not all {self.myPreds=} are in knownTrue')
            return True
        log.debug(f'{INDENT*3} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue')
        return False

    def _assembleRings(self, knownTrue: ChunkedGraph, stats, slog) -> List[ChunkLooper]:
        """make ChunkLooper for each stmt in our LHS graph, but do it in a way that they all
        start out valid (or else raise NoOptions). static chunks have already been confirmed.

        Tries chunk orderings (permutations) until one produces a stack of
        loopers that all start out with at least one binding."""

        log.debug(f'{INDENT*4} stats={dict(stats)}')
        odolog.debug(f'{INDENT*3} build new ChunkLooper stack')
        chunks = list(self.graph.patternChunks.union(self.graph.chunksUsedByFuncs))
        chunks.sort(key=None)
        odolog.info(f' {INDENT*3} taking permutations of {len(chunks)=}')

        permsTried = 0

        for perm in self._partitionedGraphPermutations():
            looperRings: List[ChunkLooper] = []
            prev: Optional[ChunkLooper] = None
            if odolog.isEnabledFor(logging.DEBUG):
                odolog.debug(f'{INDENT*4} [perm {permsTried}] try rule chunks in this order: {" THEN ".join(repr(p) for p in perm)}')

            for ruleChunk in perm:
                try:
                    # These are getting rebuilt a lot which takes time. It would
                    # be nice if they could accept a changing `prev` order
                    # (which might already be ok).
                    looper = ChunkLooper(ruleChunk, prev, knownTrue, slog)
                except NoOptions:
                    odolog.debug(f'{INDENT*5} permutation didnt work, try another')
                    break
                looperRings.append(looper)
                prev = looperRings[-1]
            else:
                # bug: At this point we've only shown that these are valid
                # starting rings. The rules might be tricky enough that this
                # permutation won't get us to the solution.
                return looperRings
            if permsTried > 50000:
                raise NotImplementedError(f'trying too many permutations {len(chunks)=}')
            permsTried += 1

        stats['permsTried'] += permsTried
        odolog.debug(f'{INDENT*5} no perms worked- rule cannot match anything')
        raise NoOptions()

    def _unpartitionedGraphPermutations(self) -> Iterator[Tuple[Chunk, ...]]:
        # Alternative to _partitionedGraphPermutations that permutes all chunks
        # together (not referenced elsewhere in this file).
        for perm in itertools.permutations(sorted(list(self.graph.patternChunks.union(self.graph.chunksUsedByFuncs)))):
            yield perm

    def _partitionedGraphPermutations(self) -> Iterator[Tuple[Chunk, ...]]:
        """always puts function chunks after pattern chunks

        (and, if we cared, static chunks could go before that. Currently they're
        culled out elsewhere, but that's done as a special case)
        """
        tupleOfNoChunks: Tuple[Chunk, ...] = ()
        pats = sorted(self.graph.patternChunks)
        funcs = sorted(self.graph.chunksUsedByFuncs)
        for patternPart in itertools.permutations(pats) if pats else [tupleOfNoChunks]:
            for funcPart in itertools.permutations(funcs) if funcs else [tupleOfNoChunks]:
                perm = patternPart + funcPart
                yield perm

    def _advanceTheStack(self, looperRings: List[ChunkLooper]) -> bool:
        """Odometer step: advance the last ring; each ring that rolls past its end
        makes the ring before it advance, then gets restarted. Return True when
        ring 0 itself rolls over, i.e. all bindings are exhausted."""
        toRestart: List[ChunkLooper] = []
        pos = len(looperRings) - 1
        while True:
            looperRings[pos].advance()
            if looperRings[pos].pastEnd():
                if pos == 0:
                    return True
                toRestart.append(looperRings[pos])
                pos -= 1
            else:
                break
        for ring in reversed(toRestart):
            ring.restart()
        return False

    def _assertAllRingsAreValid(self, looperRings):
        # Sanity check: _assembleRings should never hand us an exhausted ring.
        if any(ring.pastEnd() for ring in looperRings):  # this is an unexpected debug assertion
            odolog.warning(f'{INDENT*4} some rings started at pastEnd {looperRings}')
            raise NoOptions()
370
371
@dataclass
class BoundLhs:
    """An Lhs paired with one complete candidate binding that satisfies it."""
    lhs: Lhs
    binding: CandidateBinding
376
377
@dataclass
class Rule:
    """One rule (`lhsGraph log:implies rhsGraph` -- see Inference.setRules),
    precompiled for repeated application against a working set."""
    lhsGraph: Graph
    rhsGraph: Graph

    def __post_init__(self):
        # Hoisted out of the per-triple loop below; the module-level rdflib
        # import doesn't include BNode.
        from rdflib import BNode

        self.lhs = Lhs(ChunkedGraph(self.lhsGraph, RuleUnboundBnode, functionsFor))

        # binding key -> {RhsBnode: WorkingSetBnode}; lets a repeat of the same
        # binding reuse the same fresh bnodes (see generateImpliedFromRhs).
        self.maps: Dict = {}

        # RHS triples with their own BNodes wrapped as RhsBnode, so they can be
        # distinguished from bnodes that come from the working set.
        self.rhsGraphConvert: List[Triple] = []
        for s, p, o in self.rhsGraph:
            if isinstance(s, BNode):
                s = RhsBnode(s)
            if isinstance(p, BNode):
                p = RhsBnode(p)
            if isinstance(o, BNode):
                o = RhsBnode(o)
            self.rhsGraphConvert.append((s, p, o))

    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, slog: Optional[StructuredLog], ruleStatementsIterationLimit):
        """Find every binding of our LHS in workingSet and add the resulting RHS
        statements to both workingSet and implied (mutated in place)."""
        # this does not change for the current applyRule call. The rule will be
        # tried again in an outer loop, in case it can produce more.
        workingSetChunked = ChunkedGraph(workingSet, WorkingSetBnode, functionsFor)

        for bound in self.lhs.findCandidateBindings(workingSetChunked, stats, slog, ruleStatementsIterationLimit):
            if slog:
                slog.foundBinding(bound)
            log.debug(f'{INDENT*5} +rule has a working binding: {bound}')

            newStmts = self.generateImpliedFromRhs(bound.binding)

            for newStmt in newStmts:
                # log.debug(f'{INDENT*6} adding {newStmt=}')
                workingSet.add(newStmt)
                implied.add(newStmt)

    def generateImpliedFromRhs(self, binding: CandidateBinding) -> List[Triple]:
        """Return our RHS triples under `binding`, with each RhsBnode replaced by
        a fresh WorkingSetBnode (stable per binding across calls)."""
        out: List[Triple] = []

        # Each time the RHS is used (in a rule firing), its own BNodes (which
        # are subtype RhsBnode) need to be turned into distinct ones. Note that
        # bnodes that come from the working set should not be remapped. But the
        # iteration loop could come back with the same bindings again, so the
        # map is remembered per binding key across calls.
        rhsBnodeMap: Dict[RhsBnode, WorkingSetBnode] = self.maps.setdefault(binding.key(), {})

        for stmt in binding.apply(self.rhsGraphConvert):

            outStmt: List[Node] = []

            for t in stmt:
                if isinstance(t, RhsBnode):
                    if t not in rhsBnodeMap:
                        rhsBnodeMap[t] = WorkingSetBnode()
                    t = rhsBnodeMap[t]

                outStmt.append(t)

            log.debug(f'{INDENT*6} rhs stmt {stmt} became {outStmt}')
            out.append((outStmt[0], outStmt[1], outStmt[2]))

        return out
445
446
@dataclass
class Inference:
    """Forward-chaining inference: repeatedly apply all rules until no new
    statements are implied (or a complexity limit trips)."""
    # Safety valves: max full passes over the rules, and max binding iterations
    # within a single rule application.
    rulesIterationLimit = 4
    ruleStatementsIterationLimit = 5000

    def __init__(self) -> None:
        self.rules: List[Rule] = []
        self._nonRuleStmts: List[Triple] = []

    def setRules(self, g: ConjunctiveGraph):
        """Split g into rules (log:implies statements) and plain statements."""
        self.rules = []
        self._nonRuleStmts = []
        for stmt in g:
            if stmt[1] == LOG['implies']:
                self.rules.append(Rule(stmt[0], stmt[2]))
            else:
                self._nonRuleStmts.append(stmt)

    def nonRuleStatements(self) -> List[Triple]:
        """Statements from setRules' graph that were not rules."""
        return self._nonRuleStmts

    @INFER_CALLS.time()
    def infer(self, graph: Graph, htmlLog: Optional[Path] = None):
        """
        returns new graph of inferred statements.

        graph: input statements (not mutated). htmlLog: optional path for a
        StructuredLog debug rendering.
        """
        n = len(graph)
        INFER_GRAPH_SIZE.observe(n)
        log.info(f'{INDENT*0} Begin inference of graph len={n} with rules len={len(self.rules)}:')
        startTime = time.time()
        stats: Dict[str, Union[int, float]] = defaultdict(int)

        # everything that is true: the input graph, plus every rule conclusion we can make
        workingSet = Graph()
        workingSet += self._nonRuleStmts
        workingSet += graph

        # just the statements that came from RHS's of rules that fired.
        implied = ConjunctiveGraph()

        slog = StructuredLog(htmlLog) if htmlLog else None

        rulesIterations = 0
        delta = 1  # statements added by the last pass; loop until a pass adds none
        stats['initWorkingSet'] = len(workingSet)
        if slog:
            slog.workingSet = workingSet

        while delta > 0:
            log.debug('')
            log.info(f'{INDENT*1}*iteration {rulesIterations}')
            if slog:
                slog.startIteration(rulesIterations)

            delta = -len(implied)
            self._iterateAllRules(workingSet, implied, stats, slog)
            delta += len(implied)
            rulesIterations += 1
            log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
            # Only bail out if we'd actually keep iterating; previously a run
            # that converged (delta == 0) exactly at the limit raised too.
            if delta > 0 and rulesIterations >= self.rulesIterationLimit:
                raise ValueError(f"rule too complex after {rulesIterations=}")
        stats['iterations'] = rulesIterations
        stats['timeSpent'] = round(time.time() - startTime, 3)
        stats['impliedStmts'] = len(implied)
        log.info(f'{INDENT*0} Inference done {dict(stats)}.')
        log.debug('Implied:')
        log.debug(graphDump(implied))

        if slog:
            slog.render()
            log.info(f'wrote {htmlLog}')

        return implied

    def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats, slog: Optional[StructuredLog]):
        """One full pass: apply every rule once against the current workingSet."""
        for i, rule in enumerate(self.rules):
            self._logRuleApplicationHeader(workingSet, i, rule)
            if slog:
                slog.rule(workingSet, i, rule)
            rule.applyRule(workingSet, implied, stats, slog, self.ruleStatementsIterationLimit)

    def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
        # Verbose dump of the working set and the rule about to run; DEBUG only.
        if not log.isEnabledFor(logging.DEBUG):
            return

        log.debug('')
        log.debug(f'{INDENT*2} workingSet:')
        # for j, stmt in enumerate(sorted(workingSet)):
        #     log.debug(f'{INDENT*3} ({j}) {stmt}')
        log.debug(f'{INDENT*3} {graphDump(workingSet, oneLine=False)}')

        log.debug('')
        log.debug(f'{INDENT*2}-applying rule {i}')
        log.debug(f'{INDENT*3} rule def lhs:')
        for stmt in sorted(r.lhs.graph.allChunks()):
            log.debug(f'{INDENT*4} {stmt}')
        log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')