comparison service/mqtt_to_rdf/inference.py @ 1607:b21885181e35

more modules, types. Maybe less repeated computation on BoundLhs
author drewp@bigasterisk.com
date Mon, 06 Sep 2021 15:38:48 -0700
parents 449746d1598f
children f928eb06a4f6
comparison
equal deleted inserted replaced
1606:6cf39d43fd40 1607:b21885181e35
5 import itertools 5 import itertools
6 import logging 6 import logging
7 import time 7 import time
8 from collections import defaultdict 8 from collections import defaultdict
9 from dataclasses import dataclass, field 9 from dataclasses import dataclass, field
10 from decimal import Decimal 10 from typing import Dict, Iterator, List, Set, Tuple, Union, cast
11 from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union, cast
12 11
13 from prometheus_client import Summary 12 from prometheus_client import Summary
14 from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef 13 from rdflib import BNode, Graph, Namespace, URIRef
15 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
16 from rdflib.term import Node, Variable 15 from rdflib.term import Node, Variable
17 16
18 from lhs_evaluation import EvaluationFailed, Evaluation 17 from candidate_binding import CandidateBinding
18 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
19 from lhs_evaluation import Evaluation
19 20
20 log = logging.getLogger('infer') 21 log = logging.getLogger('infer')
21 INDENT = ' ' 22 INDENT = ' '
22
23 Triple = Tuple[Node, Node, Node]
24 Rule = Tuple[Graph, Node, Graph]
25 BindableTerm = Union[Variable, BNode]
26 ReadOnlyWorkingSet = ReadOnlyGraphAggregate
27 23
28 INFER_CALLS = Summary('read_rules_calls', 'calls') 24 INFER_CALLS = Summary('read_rules_calls', 'calls')
29 25
30 ROOM = Namespace("http://projects.bigasterisk.com/room/") 26 ROOM = Namespace("http://projects.bigasterisk.com/room/")
31 LOG = Namespace('http://www.w3.org/2000/10/swap/log#') 27 LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
34 # Graph() makes a BNode if you don't pass 30 # Graph() makes a BNode if you don't pass
35 # identifier, which can be a bottleneck. 31 # identifier, which can be a bottleneck.
36 GRAPH_ID = URIRef('dont/care') 32 GRAPH_ID = URIRef('dont/care')
37 33
38 34
39 class BindingUnknown(ValueError):
40 """e.g. we were asked to make the bound version
41 of (A B ?c) and we don't have a binding for ?c
42 """
43
44
45 @dataclass
46 class CandidateBinding:
47 binding: Dict[BindableTerm, Node]
48
49 def __repr__(self):
50 b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items()))
51 return f'CandidateBinding({b})'
52
53 def apply(self, g: Graph) -> Iterator[Triple]:
54 for stmt in g:
55 try:
56 bound = (self._applyTerm(stmt[0]), self._applyTerm(stmt[1]), self._applyTerm(stmt[2]))
57 except BindingUnknown:
58 continue
59 yield bound
60
61 def _applyTerm(self, term: Node):
62 if isinstance(term, (Variable, BNode)):
63 if term in self.binding:
64 return self.binding[term]
65 else:
66 raise BindingUnknown()
67 return term
68
69 def applyFunctions(self, lhs) -> Graph:
70 """may grow the binding with some results"""
71 usedByFuncs = Graph(identifier=GRAPH_ID)
72 while True:
73 delta = self._applyFunctionsIteration(lhs, usedByFuncs)
74 if delta == 0:
75 break
76 return usedByFuncs
77
78 def _applyFunctionsIteration(self, lhs, usedByFuncs: Graph):
79 before = len(self.binding)
80 delta = 0
81 for ev in lhs.evaluations:
82 log.debug(f'{INDENT*3} found Evaluation')
83
84 newBindings, usedGraph = ev.resultBindings(self.binding)
85 usedByFuncs += usedGraph
86 self._addNewBindings(newBindings)
87 delta = len(self.binding) - before
88 if log.isEnabledFor(logging.DEBUG):
89 dump = "(...)"
90 if cast(int, usedGraph.__len__()) < 20:
91 dump = graphDump(usedGraph)
92 log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings')
93 return delta
94
95 def _addNewBindings(self, newBindings):
96 for k, v in newBindings.items():
97 if k in self.binding and self.binding[k] != v:
98 raise ValueError(f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}')
99 self.binding[k] = v
100
101 def verify(self, lhs: 'Lhs', workingSet: ReadOnlyWorkingSet, usedByFuncs: Graph) -> bool:
102 """Can this lhs be true all at once in workingSet? Does it match with these bindings?"""
103 boundLhs = list(self.apply(lhs.graph))
104 boundUsedByFuncs = list(self.apply(usedByFuncs))
105
106 self._logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs)
107
108 for stmt in boundLhs:
109 log.debug(f'{INDENT*4} check for {stmt}')
110
111 if stmt in boundUsedByFuncs:
112 pass
113 elif stmt in workingSet:
114 pass
115 else:
116 log.debug(f'{INDENT*5} stmt not known to be true')
117 return False
118 return True
119
120 def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs):
121 if not log.isEnabledFor(logging.DEBUG):
122 return
123 log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:')
124 for stmt in sorted(boundLhs):
125 log.debug(f'{INDENT*4}|{INDENT} {stmt}')
126
127 # log.debug(f'{INDENT*4}| and against this workingSet:')
128 # for stmt in sorted(workingSet):
129 # log.debug(f'{INDENT*4}|{INDENT} {stmt}')
130
131 stmts = sorted(boundUsedByFuncs)
132 if stmts:
133 log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:')
134 for stmt in stmts:
135 log.debug(f'{INDENT*4}|{INDENT} {stmt}')
136 log.debug(f'{INDENT*4}\\')
137
138
139 @dataclass 35 @dataclass
140 class Lhs: 36 class Lhs:
141 graph: Graph 37 graph: Graph
142 stats: Dict
143 38
144 staticRuleStmts: Graph = field(default_factory=Graph) 39 staticRuleStmts: Graph = field(default_factory=Graph)
145 lhsBindables: Set[BindableTerm] = field(default_factory=set) 40 lhsBindables: Set[BindableTerm] = field(default_factory=set)
146 lhsBnodes: Set[BNode] = field(default_factory=set) 41 lhsBnodes: Set[BNode] = field(default_factory=set)
147 42
153 if not varsAndBnodesInStmt: 48 if not varsAndBnodesInStmt:
154 self.staticRuleStmts.add(ruleStmt) 49 self.staticRuleStmts.add(ruleStmt)
155 50
156 self.evaluations = list(Evaluation.findEvals(self.graph)) 51 self.evaluations = list(Evaluation.findEvals(self.graph))
157 52
158 def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]: 53 def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
159 """bindings that fit the LHS of a rule, using statements from workingSet and functions 54 """bindings that fit the LHS of a rule, using statements from workingSet and functions
160 from LHS""" 55 from LHS"""
161 log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') 56 log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}')
162 self.stats['findCandidateBindingsCalls'] += 1 57 stats['findCandidateBindingsCalls'] += 1
163 58
164 if not self._allStaticStatementsMatch(workingSet): 59 if not self._allStaticStatementsMatch(workingSet):
165 self.stats['findCandidateBindingEarlyExits'] += 1 60 stats['findCandidateBindingEarlyExits'] += 1
166 return 61 return
167 62
63 for binding in self._possibleBindings(workingSet, stats):
64 log.debug('')
65 log.debug(f'{INDENT*4}*trying {binding.binding}')
66
67 if not binding.verify(workingSet):
68 log.debug(f'{INDENT*4} this binding did not verify')
69 stats['permCountFailingVerify'] += 1
70 continue
71
72 stats['permCountSucceeding'] += 1
73 yield binding
74
75 def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']:
76 """this yields at least the working bindings, and possibly others"""
168 candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet) 77 candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet)
169 78
170 orderedVars, orderedValueSets = _organize(candidateTermMatches) 79 orderedVars, orderedValueSets = _organize(candidateTermMatches)
171
172 self._logCandidates(orderedVars, orderedValueSets) 80 self._logCandidates(orderedVars, orderedValueSets)
173 81
174 log.debug(f'{INDENT*3} trying all permutations:') 82 log.debug(f'{INDENT*3} trying all permutations:')
175
176 for perm in itertools.product(*orderedValueSets): 83 for perm in itertools.product(*orderedValueSets):
177 binding = CandidateBinding(dict(zip(orderedVars, perm)))
178 log.debug('')
179 log.debug(f'{INDENT*4}*trying {binding}')
180
181 try: 84 try:
182 usedByFuncs = binding.applyFunctions(self) 85 yield BoundLhs(self, CandidateBinding(dict(zip(orderedVars, perm))))
183 except EvaluationFailed: 86 except EvaluationFailed:
184 self.stats['permCountFailingEval'] += 1 87 stats['permCountFailingEval'] += 1
185 continue
186
187 if not binding.verify(self, workingSet, usedByFuncs):
188 log.debug(f'{INDENT*4} this binding did not verify')
189 self.stats['permCountFailingVerify'] += 1
190 continue
191
192 self.stats['permCountSucceeding'] += 1
193 yield binding
194 88
195 def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool: 89 def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool:
196 for ruleStmt in self.staticRuleStmts: 90 for ruleStmt in self.staticRuleStmts:
197 if ruleStmt not in workingSet: 91 if ruleStmt not in workingSet:
198 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule') 92 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule')
234 else: 128 else:
235 for v, vals in bindingsFromStatement.items(): 129 for v, vals in bindingsFromStatement.items():
236 log.debug(f'{INDENT*5} {v=} {vals=}') 130 log.debug(f'{INDENT*5} {v=} {vals=}')
237 yield v, vals 131 yield v, vals
238 132
239 def graphWithoutEvals(self, binding: CandidateBinding) -> Graph:
240 g = Graph(identifier=GRAPH_ID)
241 usedByFuncs = binding.applyFunctions(self)
242
243 for stmt in self.graph:
244 if stmt not in usedByFuncs:
245 g.add(stmt)
246 return g
247
248 def _logCandidates(self, orderedVars, orderedValueSets): 133 def _logCandidates(self, orderedVars, orderedValueSets):
249 if not log.isEnabledFor(logging.DEBUG): 134 if not log.isEnabledFor(logging.DEBUG):
250 return 135 return
251 log.debug(f'{INDENT*3} resulting candidate terms:') 136 log.debug(f'{INDENT*3} resulting candidate terms:')
252 for v, vals in zip(orderedVars, orderedValueSets): 137 for v, vals in zip(orderedVars, orderedValueSets):
253 log.debug(f'{INDENT*4} {v!r} could be:') 138 log.debug(f'{INDENT*4} {v!r} could be:')
254 for val in vals: 139 for val in vals:
255 log.debug(f'{INDENT*5}{val!r}') 140 log.debug(f'{INDENT*5}{val!r}')
256 141
257 142
143 @dataclass
144 class BoundLhs:
145 lhs: Lhs
146 binding: CandidateBinding
147
148 def __post_init__(self):
149 self.usedByFuncs = Graph(identifier=GRAPH_ID)
150 self.graphWithoutEvals = self._graphWithoutEvals()
151
152 def _graphWithoutEvals(self) -> Graph:
153 g = Graph(identifier=GRAPH_ID)
154 self._applyFunctions()
155
156 for stmt in self.lhs.graph:
157 if stmt not in self.usedByFuncs:
158 g.add(stmt)
159 return g
160
161 def _applyFunctions(self):
162 """may grow the binding with some results"""
163 while True:
164 delta = self._applyFunctionsIteration()
165 if delta == 0:
166 break
167
168 def _applyFunctionsIteration(self):
169 before = len(self.binding.binding)
170 delta = 0
171 for ev in self.lhs.evaluations:
172 log.debug(f'{INDENT*3} found Evaluation')
173
174 newBindings, usedGraph = ev.resultBindings(self.binding)
175 self.usedByFuncs += usedGraph
176 self.binding.addNewBindings(newBindings)
177 delta = len(self.binding.binding) - before
178 if log.isEnabledFor(logging.DEBUG):
179 dump = "(...)"
180 if cast(int, usedGraph.__len__()) < 20:
181 dump = graphDump(usedGraph)
182 log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings')
183 return delta
184
185
186 def verify(self, workingSet: ReadOnlyWorkingSet) -> bool:
187 """Can this bound lhs be true all at once in workingSet?"""
188 boundLhs = list(self.binding.apply(self.lhs.graph))
189 boundUsedByFuncs = list(self.binding.apply(self.usedByFuncs))
190
191 self._logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs)
192
193 for stmt in boundLhs:
194 log.debug(f'{INDENT*4} check for {stmt}')
195
196 if stmt in boundUsedByFuncs:
197 pass
198 elif stmt in workingSet:
199 pass
200 else:
201 log.debug(f'{INDENT*5} stmt not known to be true')
202 return False
203 return True
204
205 def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs):
206 if not log.isEnabledFor(logging.DEBUG):
207 return
208 log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:')
209 for stmt in sorted(boundLhs):
210 log.debug(f'{INDENT*4}|{INDENT} {stmt}')
211
212 # log.debug(f'{INDENT*4}| and against this workingSet:')
213 # for stmt in sorted(workingSet):
214 # log.debug(f'{INDENT*4}|{INDENT} {stmt}')
215
216 stmts = sorted(boundUsedByFuncs)
217 if stmts:
218 log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:')
219 for stmt in stmts:
220 log.debug(f'{INDENT*4}|{INDENT} {stmt}')
221 log.debug(f'{INDENT*4}\\')
222
223
224 @dataclass
225 class Rule:
226 lhsGraph: Graph
227 rhsGraph: Graph
228
229 def __post_init__(self):
230 self.lhs = Lhs(self.lhsGraph)
231
232
258 class Inference: 233 class Inference:
259 234
260 def __init__(self) -> None: 235 def __init__(self) -> None:
261 self.rules = ConjunctiveGraph() 236 self.rules = []
262 237
263 def setRules(self, g: ConjunctiveGraph): 238 def setRules(self, g: ConjunctiveGraph):
264 self.rules = ConjunctiveGraph() 239 self.rules: List[Rule] = []
265 for stmt in g: 240 for stmt in g:
266 if stmt[1] == LOG['implies']: 241 if stmt[1] == LOG['implies']:
267 self.rules.add(stmt) 242 self.rules.append(Rule(stmt[0], stmt[2]))
268 # others should go to a default working set? 243 # others should go to a default working set?
269 244
270 @INFER_CALLS.time() 245 @INFER_CALLS.time()
271 def infer(self, graph: Graph): 246 def infer(self, graph: Graph):
272 """ 247 """
273 returns new graph of inferred statements. 248 returns new graph of inferred statements.
274 """ 249 """
275 log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:') 250 log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:')
276 startTime = time.time() 251 startTime = time.time()
277 self.stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0) 252 stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0)
278 # everything that is true: the input graph, plus every rule conclusion we can make 253 # everything that is true: the input graph, plus every rule conclusion we can make
279 workingSet = Graph() 254 workingSet = Graph()
280 workingSet += graph 255 workingSet += graph
281 256
282 # just the statements that came from RHS's of rules that fired. 257 # just the statements that came from RHS's of rules that fired.
283 implied = ConjunctiveGraph() 258 implied = ConjunctiveGraph()
284 259
285 bailout_iterations = 100 260 bailout_iterations = 100
286 delta = 1 261 delta = 1
287 self.stats['initWorkingSet'] = cast(int, workingSet.__len__()) 262 stats['initWorkingSet'] = cast(int, workingSet.__len__())
288 while delta > 0 and bailout_iterations > 0: 263 while delta > 0 and bailout_iterations > 0:
289 log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)') 264 log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
290 bailout_iterations -= 1 265 bailout_iterations -= 1
291 delta = -len(implied) 266 delta = -len(implied)
292 self._iterateAllRules(workingSet, implied) 267 self._iterateAllRules(workingSet, implied, stats)
293 delta += len(implied) 268 delta += len(implied)
294 self.stats['iterations'] += 1 269 stats['iterations'] += 1
295 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts') 270 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
296 self.stats['timeSpent'] = round(time.time() - startTime, 3) 271 stats['timeSpent'] = round(time.time() - startTime, 3)
297 self.stats['impliedStmts'] = len(implied) 272 stats['impliedStmts'] = len(implied)
298 log.info(f'{INDENT*0} Inference done {dict(self.stats)}. Implied:') 273 log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:')
299 for st in implied: 274 for st in implied:
300 log.info(f'{INDENT*1} {st}') 275 log.info(f'{INDENT*1} {st}')
301 return implied 276 return implied
302 277
303 def _iterateAllRules(self, workingSet: Graph, implied: Graph): 278 def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats):
304 for i, r in enumerate(self.rules): 279 for i, r in enumerate(self.rules):
305 self._logRuleApplicationHeader(workingSet, i, r) 280 self._logRuleApplicationHeader(workingSet, i, r)
306 _applyRule(Lhs(r[0], self.stats), r[2], workingSet, implied, self.stats) 281 _applyRule(r.lhs, r.rhsGraph, workingSet, implied, stats)
307 282
308 def _logRuleApplicationHeader(self, workingSet, i, r): 283 def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
309 if not log.isEnabledFor(logging.DEBUG): 284 if not log.isEnabledFor(logging.DEBUG):
310 return 285 return
311 286
312 log.debug('') 287 log.debug('')
313 log.debug(f'{INDENT*2} workingSet:') 288 log.debug(f'{INDENT*2} workingSet:')
314 for j, stmt in enumerate(sorted(workingSet)): 289 for j, stmt in enumerate(sorted(workingSet)):
315 log.debug(f'{INDENT*3} ({j}) {stmt}') 290 log.debug(f'{INDENT*3} ({j}) {stmt}')
316 291
317 log.debug('') 292 log.debug('')
318 log.debug(f'{INDENT*2}-applying rule {i}') 293 log.debug(f'{INDENT*2}-applying rule {i}')
319 log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}') 294 log.debug(f'{INDENT*3} rule def lhs: {graphDump(r.lhsGraph)}')
320 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}') 295 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
321 296
322 297
323 def _applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph, stats: Dict): 298 def _applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph, stats: Dict):
324 for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])): 299 for bound in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats):
325 log.debug(f'{INDENT*3} rule has a working binding:') 300 log.debug(f'{INDENT*3} rule has a working binding:')
326 301
327 for lhsBoundStmt in binding.apply(lhs.graphWithoutEvals(binding)): 302 for lhsBoundStmt in bound.binding.apply(bound.graphWithoutEvals):
328 log.debug(f'{INDENT*5} adding {lhsBoundStmt=}') 303 log.debug(f'{INDENT*5} adding {lhsBoundStmt=}')
329 workingSet.add(lhsBoundStmt) 304 workingSet.add(lhsBoundStmt)
330 for newStmt in binding.apply(rhs): 305 for newStmt in bound.binding.apply(rhs):
331 log.debug(f'{INDENT*5} adding {newStmt=}') 306 log.debug(f'{INDENT*5} adding {newStmt=}')
332 workingSet.add(newStmt) 307 workingSet.add(newStmt)
333 implied.add(newStmt) 308 implied.add(newStmt)
334 309
335 310
336 def graphDump(g: Union[Graph, List[Triple]]): 311 def graphDump(g: Union[Graph, List[Triple]]):
337 if not isinstance(g, Graph): 312 if not isinstance(g, Graph):
313 log.warning(f"it's a {type(g)}")
338 g2 = Graph() 314 g2 = Graph()
339 g2 += g 315 g2 += g
340 g = g2 316 g = g2
341 g.bind('', ROOM) 317 g.bind('', ROOM)
342 g.bind('ex', Namespace('http://example.com/')) 318 g.bind('ex', Namespace('http://example.com/'))