Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1605:449746d1598f
WIP move evaluation to new file
author | drewp@bigasterisk.com |
---|---|
date | Mon, 06 Sep 2021 01:13:55 -0700 |
parents | 7f8bf68534ed |
children | b21885181e35 |
comparison
equal
deleted
inserted
replaced
1604:e78464befd24 | 1605:449746d1598f |
---|---|
13 from prometheus_client import Summary | 13 from prometheus_client import Summary |
14 from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef | 14 from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef |
15 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate | 15 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate |
16 from rdflib.term import Node, Variable | 16 from rdflib.term import Node, Variable |
17 | 17 |
18 from lhs_evaluation import EvaluationFailed, Evaluation | |
19 | |
18 log = logging.getLogger('infer') | 20 log = logging.getLogger('infer') |
19 INDENT = ' ' | 21 INDENT = ' ' |
20 | 22 |
21 Triple = Tuple[Node, Node, Node] | 23 Triple = Tuple[Node, Node, Node] |
22 Rule = Tuple[Graph, Node, Graph] | 24 Rule = Tuple[Graph, Node, Graph] |
30 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') | 32 MATH = Namespace('http://www.w3.org/2000/10/swap/math#') |
31 | 33 |
32 # Graph() makes a BNode if you don't pass | 34 # Graph() makes a BNode if you don't pass |
33 # identifier, which can be a bottleneck. | 35 # identifier, which can be a bottleneck. |
34 GRAPH_ID = URIRef('dont/care') | 36 GRAPH_ID = URIRef('dont/care') |
35 | |
36 class EvaluationFailed(ValueError): | |
37 """e.g. we were given (5 math:greaterThan 6)""" | |
38 | 37 |
39 | 38 |
40 class BindingUnknown(ValueError): | 39 class BindingUnknown(ValueError): |
41 """e.g. we were asked to make the bound version | 40 """e.g. we were asked to make the bound version |
42 of (A B ?c) and we don't have a binding for ?c | 41 of (A B ?c) and we don't have a binding for ?c |
154 if not varsAndBnodesInStmt: | 153 if not varsAndBnodesInStmt: |
155 self.staticRuleStmts.add(ruleStmt) | 154 self.staticRuleStmts.add(ruleStmt) |
156 | 155 |
157 self.evaluations = list(Evaluation.findEvals(self.graph)) | 156 self.evaluations = list(Evaluation.findEvals(self.graph)) |
158 | 157 |
159 | |
160 def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]: | 158 def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]: |
161 """bindings that fit the LHS of a rule, using statements from workingSet and functions | 159 """bindings that fit the LHS of a rule, using statements from workingSet and functions |
162 from LHS""" | 160 from LHS""" |
163 log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') | 161 log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') |
164 self.stats['findCandidateBindingsCalls'] += 1 | 162 self.stats['findCandidateBindingsCalls'] += 1 |
172 orderedVars, orderedValueSets = _organize(candidateTermMatches) | 170 orderedVars, orderedValueSets = _organize(candidateTermMatches) |
173 | 171 |
174 self._logCandidates(orderedVars, orderedValueSets) | 172 self._logCandidates(orderedVars, orderedValueSets) |
175 | 173 |
176 log.debug(f'{INDENT*3} trying all permutations:') | 174 log.debug(f'{INDENT*3} trying all permutations:') |
177 | |
178 | 175 |
179 for perm in itertools.product(*orderedValueSets): | 176 for perm in itertools.product(*orderedValueSets): |
180 binding = CandidateBinding(dict(zip(orderedVars, perm))) | 177 binding = CandidateBinding(dict(zip(orderedVars, perm))) |
181 log.debug('') | 178 log.debug('') |
182 log.debug(f'{INDENT*4}*trying {binding}') | 179 log.debug(f'{INDENT*4}*trying {binding}') |
256 log.debug(f'{INDENT*4} {v!r} could be:') | 253 log.debug(f'{INDENT*4} {v!r} could be:') |
257 for val in vals: | 254 for val in vals: |
258 log.debug(f'{INDENT*5}{val!r}') | 255 log.debug(f'{INDENT*5}{val!r}') |
259 | 256 |
260 | 257 |
261 class Evaluation: | |
262 """some lhs statements need to be evaluated with a special function | |
263 (e.g. math) and then not considered for the rest of the rule-firing | |
264 process. It's like they already 'matched' something, so they don't need | |
265 to match a statement from the known-true working set. | |
266 | |
267 One Evaluation instance is for one function call. | |
268 """ | |
269 | |
270 @staticmethod | |
271 def findEvals(graph: Graph) -> Iterator['Evaluation']: | |
272 for stmt in graph.triples((None, MATH['sum'], None)): | |
273 operands, operandsStmts = _parseList(graph, stmt[0]) | |
274 yield Evaluation(operands, stmt, operandsStmts) | |
275 | |
276 for stmt in graph.triples((None, MATH['greaterThan'], None)): | |
277 yield Evaluation([stmt[0], stmt[2]], stmt, []) | |
278 | |
279 for stmt in graph.triples((None, ROOM['asFarenheit'], None)): | |
280 yield Evaluation([stmt[0]], stmt, []) | |
281 | |
282 # internal, use findEvals | |
283 def __init__(self, operands: List[Node], mainStmt: Triple, otherStmts: Iterable[Triple]) -> None: | |
284 self.operands = operands | |
285 self.operandsStmts = Graph(identifier=GRAPH_ID) | |
286 self.operandsStmts += otherStmts # may grow | |
287 self.operandsStmts.add(mainStmt) | |
288 self.stmt = mainStmt | |
289 | |
290 def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]: | |
291 """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?""" | |
292 pred = self.stmt[1] | |
293 objVar: Node = self.stmt[2] | |
294 boundOperands = [] | |
295 for op in self.operands: | |
296 if isinstance(op, Variable): | |
297 try: | |
298 op = inputBindings[op] | |
299 except KeyError: | |
300 return {}, self.operandsStmts | |
301 | |
302 boundOperands.append(op) | |
303 | |
304 if pred == MATH['sum']: | |
305 obj = Literal(sum(map(numericNode, boundOperands))) | |
306 if not isinstance(objVar, Variable): | |
307 raise TypeError(f'expected Variable, got {objVar!r}') | |
308 res: Dict[BindableTerm, Node] = {objVar: obj} | |
309 elif pred == ROOM['asFarenheit']: | |
310 if len(boundOperands) != 1: | |
311 raise ValueError(":asFarenheit takes 1 subject operand") | |
312 f = Literal(Decimal(numericNode(boundOperands[0])) * 9 / 5 + 32) | |
313 if not isinstance(objVar, Variable): | |
314 raise TypeError(f'expected Variable, got {objVar!r}') | |
315 res: Dict[BindableTerm, Node] = {objVar: f} | |
316 elif pred == MATH['greaterThan']: | |
317 if not (numericNode(boundOperands[0]) > numericNode(boundOperands[1])): | |
318 raise EvaluationFailed() | |
319 res: Dict[BindableTerm, Node] = {} | |
320 else: | |
321 raise NotImplementedError(repr(pred)) | |
322 | |
323 return res, self.operandsStmts | |
324 | |
325 | |
326 def numericNode(n: Node): | |
327 if not isinstance(n, Literal): | |
328 raise TypeError(f'expected Literal, got {n=}') | |
329 val = n.toPython() | |
330 if not isinstance(val, (int, float, Decimal)): | |
331 raise TypeError(f'expected number, got {val=}') | |
332 return val | |
333 | |
334 | |
335 class Inference: | 258 class Inference: |
336 | 259 |
337 def __init__(self) -> None: | 260 def __init__(self) -> None: |
338 self.rules = ConjunctiveGraph() | 261 self.rules = ConjunctiveGraph() |
339 | 262 |
349 """ | 272 """ |
350 returns new graph of inferred statements. | 273 returns new graph of inferred statements. |
351 """ | 274 """ |
352 log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:') | 275 log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:') |
353 startTime = time.time() | 276 startTime = time.time() |
354 self.stats: Dict[str, Union[int,float]] = defaultdict(lambda: 0) | 277 self.stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0) |
355 # everything that is true: the input graph, plus every rule conclusion we can make | 278 # everything that is true: the input graph, plus every rule conclusion we can make |
356 workingSet = Graph() | 279 workingSet = Graph() |
357 workingSet += graph | 280 workingSet += graph |
358 | 281 |
359 # just the statements that came from RHS's of rules that fired. | 282 # just the statements that came from RHS's of rules that fired. |
408 log.debug(f'{INDENT*5} adding {newStmt=}') | 331 log.debug(f'{INDENT*5} adding {newStmt=}') |
409 workingSet.add(newStmt) | 332 workingSet.add(newStmt) |
410 implied.add(newStmt) | 333 implied.add(newStmt) |
411 | 334 |
412 | 335 |
413 def _parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]: | |
414 """"Do like Collection(g, subj) but also return all the | |
415 triples that are involved in the list""" | |
416 out = [] | |
417 used = set() | |
418 cur = subj | |
419 while cur != RDF.nil: | |
420 out.append(graph.value(cur, RDF.first)) | |
421 used.add((cur, RDF.first, out[-1])) | |
422 | |
423 next = graph.value(cur, RDF.rest) | |
424 used.add((cur, RDF.rest, next)) | |
425 | |
426 cur = next | |
427 return out, used | |
428 | |
429 | |
430 def graphDump(g: Union[Graph, List[Triple]]): | 336 def graphDump(g: Union[Graph, List[Triple]]): |
431 if not isinstance(g, Graph): | 337 if not isinstance(g, Graph): |
432 g2 = Graph() | 338 g2 = Graph() |
433 g2 += g | 339 g2 += g |
434 g = g2 | 340 g = g2 |