comparison service/mqtt_to_rdf/lhs_evaluation.py @ 1605:449746d1598f

WIP move evaluation to new file
author drewp@bigasterisk.com
date Mon, 06 Sep 2021 01:13:55 -0700
parents
children b21885181e35
comparison
equal deleted inserted replaced
1604:e78464befd24 1605:449746d1598f
1 import logging
2 from dataclasses import dataclass, field
3 from decimal import Decimal
4 from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union, cast
5
6 from prometheus_client import Summary
7 from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef
8 from rdflib.graph import ReadOnlyGraphAggregate
9 from rdflib.term import Node, Variable
10
log = logging.getLogger('infer')

# Vocabularies whose terms are recognized below.
ROOM = Namespace("http://projects.bigasterisk.com/room/")
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')

# Shared identifier for the small Graphs built in this module. Graph()
# mints a fresh BNode identifier when none is given, and that can be a
# bottleneck.
GRAPH_ID = URIRef('dont/care')

# Type aliases.
Triple = Tuple[Node, Node, Node]  # (subject, predicate, object)
Rule = Tuple[Graph, Node, Graph]  # presumably (lhs, predicate, rhs) -- confirm against callers
BindableTerm = Union[Variable, BNode]  # keys of a bindings dict
ReadOnlyWorkingSet = ReadOnlyGraphAggregate

# Single indentation unit; not used in this chunk (presumably for nested
# debug output -- confirm).
INDENT = ' '
27
28
class EvaluationFailed(ValueError):
    """Raised when an evaluated test statement turns out false.

    Example: being asked to evaluate (5 math:greaterThan 6).
    """
31
32
class Evaluation:
    """An LHS statement that is computed instead of matched.

    Some lhs statements need to be evaluated with a special function
    (e.g. math) and then not considered for the rest of the rule-firing
    process. It's like they already 'matched' something, so they don't need
    to match a statement from the known-true working set.

    Supported predicates: math:sum, math:greaterThan, room:asFarenheit.

    One Evaluation instance is for one function call.
    """

    @staticmethod
    def findEvals(graph: Graph) -> Iterator['Evaluation']:
        """Yield one Evaluation per supported function-call statement in `graph`."""
        # (a b ...) math:sum ?out -- the subject is an RDF list; the list's own
        # rdf:first/rdf:rest triples are consumed along with the call.
        for stmt in graph.triples((None, MATH['sum'], None)):
            operands, operandsStmts = _parseList(graph, stmt[0])
            yield Evaluation(operands, stmt, operandsStmts)

        # a math:greaterThan b -- a pure test; binds nothing.
        for stmt in graph.triples((None, MATH['greaterThan'], None)):
            yield Evaluation([stmt[0], stmt[2]], stmt, [])

        # c room:asFarenheit ?f  ('asFarenheit' spelling is part of the URI)
        for stmt in graph.triples((None, ROOM['asFarenheit'], None)):
            yield Evaluation([stmt[0]], stmt, [])

    # internal, use findEvals
    def __init__(self, operands: List[Node], mainStmt: Triple, otherStmts: Iterable[Triple]) -> None:
        self.operands = operands
        # Every statement this evaluation accounts for: the call itself plus
        # any supporting statements (e.g. the operand list structure).
        self.operandsStmts = Graph(identifier=GRAPH_ID)
        self.operandsStmts += otherStmts  # may grow
        self.operandsStmts.add(mainStmt)
        self.stmt = mainStmt

    @staticmethod
    def _requireVariable(term: Node) -> Variable:
        """Return `term` if it is a Variable (the slot a result is bound to), else raise TypeError."""
        if not isinstance(term, Variable):
            raise TypeError(f'expected Variable, got {term!r}')
        return term

    def resultBindings(self, inputBindings: Dict[BindableTerm, Node]) -> Tuple[Dict[BindableTerm, Node], Graph]:
        """Under the bindings so far, what would this evaluation tell us,
        and which stmts would be consumed from doing so?

        Args:
            inputBindings: current bindings; operands that are Variables are
                looked up here, other operands are used as-is.

        Returns:
            (newBindings, consumedStmts). newBindings maps the statement's
            object Variable to the computed value for math:sum and
            room:asFarenheit. It is empty for math:greaterThan, and also
            empty when some operand Variable is not bound yet.
            consumedStmts is always self.operandsStmts.

        Raises:
            EvaluationFailed: a math:greaterThan test does not hold.
            TypeError: an operand is not a numeric Literal, or the object of
                math:sum / room:asFarenheit is not a Variable.
            ValueError: the wrong number of operands for the predicate.
            NotImplementedError: the predicate is not supported.
        """
        pred = self.stmt[1]
        objVar: Node = self.stmt[2]

        boundOperands: List[Node] = []
        for op in self.operands:
            if isinstance(op, Variable):
                try:
                    op = inputBindings[op]
                except KeyError:
                    # Not everything is bound yet, so there is nothing to conclude.
                    return {}, self.operandsStmts
            boundOperands.append(op)

        res: Dict[BindableTerm, Node]
        if pred == MATH['sum']:
            # NOTE(review): sum() raises TypeError if Decimal and float
            # operands are mixed -- confirm inputs are never mixed.
            obj = Literal(sum(map(numericNode, boundOperands)))
            res = {self._requireVariable(objVar): obj}
        elif pred == ROOM['asFarenheit']:
            if len(boundOperands) != 1:
                raise ValueError(":asFarenheit takes 1 subject operand")
            # degrees C -> degrees F. NOTE(review): Decimal(float) keeps the float's
            # full binary expansion (Decimal(0.1) != Decimal('0.1')).
            f = Literal(Decimal(numericNode(boundOperands[0])) * 9 / 5 + 32)
            res = {self._requireVariable(objVar): f}
        elif pred == MATH['greaterThan']:
            if len(boundOperands) != 2:
                raise ValueError("math:greaterThan takes 2 operands")
            if not (numericNode(boundOperands[0]) > numericNode(boundOperands[1])):
                raise EvaluationFailed()
            res = {}
        else:
            raise NotImplementedError(repr(pred))

        return res, self.operandsStmts
96
97
def numericNode(n: Node) -> Union[int, float, Decimal]:
    """Return the Python number held by the RDF Literal `n`.

    Raises:
        TypeError: if `n` is not a Literal, or if its Python value is not an
            int, float or Decimal. Booleans are rejected even though `bool`
            subclasses `int`, so a literal whose value is True/False is not
            silently used as 1/0.
    """
    if not isinstance(n, Literal):
        raise TypeError(f'expected Literal, got {n=}')
    val = n.toPython()
    # isinstance(True, int) is True, so exclude bool explicitly.
    if isinstance(val, bool) or not isinstance(val, (int, float, Decimal)):
        raise TypeError(f'expected number, got {val=}')
    return val
105
106
def _parseList(graph: Graph, subj: Node) -> Tuple[List[Node], Set[Triple]]:
    """Do like Collection(graph, subj), but also return all the triples
    that are involved in the list.

    Walks rdf:first/rdf:rest starting at `subj` until rdf:nil.

    Returns:
        (items, used): the list items in order, and the set of
        rdf:first/rdf:rest triples that were traversed.

    Raises:
        ValueError: if a list node lacks rdf:first or rdf:rest, or if the
            rdf:rest chain loops back to a node already visited. Without these
            checks a malformed or cyclic list could make this walk spin
            forever.
    """
    out: List[Node] = []
    used: Set[Triple] = set()
    visited = set()
    cur = subj
    while cur != RDF.nil:
        if cur in visited:
            raise ValueError(f'cyclic RDF list at {cur!r}')
        visited.add(cur)

        first = graph.value(cur, RDF.first)
        rest = graph.value(cur, RDF.rest)
        if first is None or rest is None:
            raise ValueError(f'malformed RDF list node {cur!r} (missing rdf:first or rdf:rest)')

        out.append(first)
        used.add((cur, RDF.first, first))
        used.add((cur, RDF.rest, rest))
        cur = rest
    return out, used