Mercurial > code > home > repos > homeauto
annotate service/mqtt_to_rdf/lhs_evaluation.py @ 1648:3059f31b2dfa
more performance work
author | drewp@bigasterisk.com |
---|---|
date | Fri, 17 Sep 2021 11:10:18 -0700 |
parents | 4bb6f593ebf3 |
children | 20474ad4968e |
rev | line source |
---|---|
1637 | 1 from dataclasses import dataclass |
1605 | 2 import logging |
3 from decimal import Decimal | |
1637 | 4 from candidate_binding import CandidateBinding |
1640 | 5 from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast |
1605 | 6 |
7 from prometheus_client import Summary | |
1636 | 8 from rdflib import RDF, Literal, Namespace, URIRef |
1637 | 9 from rdflib.graph import Graph |
1648 | 10 from rdflib.term import BNode, Node, Variable |
1605 | 11 |
1637 | 12 from inference_types import BindableTerm, Triple |
1607
b21885181e35
more modules, types. Maybe less repeated computation on BoundLhs
drewp@bigasterisk.com
parents:
1605
diff
changeset
|
13 |
# RDF namespaces for the rule predicates recognized in this module.
ROOM = Namespace("http://projects.bigasterisk.com/room/")
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')

# Logger for the inference code (channel name 'infer').
log: logging.Logger = logging.getLogger('infer')

# Indentation string, presumably for nested debug output (not used in this chunk).
# NOTE(review): the source shown here has a single space; the page extraction may
# have collapsed whitespace, so confirm the intended width.
INDENT = ' '
21 | |
22 | |
def numericNode(n: Node) -> Union[int, float, Decimal]:
    """Return the Python number held by the RDF literal *n*.

    Raises:
        TypeError: if *n* is not a Literal, or if its ``toPython()`` value is
            not an int, float or Decimal. Booleans are rejected as well: bool
            is a subclass of int, so without the explicit check a Literal
            whose Python value is True/False would be treated as 1/0 by the
            math functions.
    """
    if not isinstance(n, Literal):
        raise TypeError(f'expected Literal, got {n=}')
    val = n.toPython()
    if isinstance(val, bool) or not isinstance(val, (int, float, Decimal)):
        raise TypeError(f'expected number, got {val=}')
    return val
30 | |
31 | |
1634
ba59cfc3c747
hack math:sum in there. Test suite is passing except some slow performers
drewp@bigasterisk.com
parents:
1607
diff
changeset
|
def parseList(graph: Graph, subj: Node) -> Tuple[List[Node], Set[Triple]]:
    """Do like Collection(graph, subj), but also return all the triples that
    are involved in the list.

    Walks the rdf:first / rdf:rest chain from *subj* until rdf:nil.

    Returns:
        (elements, usedTriples): the list members in order, and every
        rdf:first / rdf:rest triple that was read while walking the list.

    Raises:
        ValueError('bad list'): if a list cell lacks rdf:first or rdf:rest,
            or if the rdf:rest chain loops back onto a cell already visited
            (without this check a cyclic list would loop forever).
    """
    out: List[Node] = []
    used: Set[Triple] = set()
    visited: Set[Node] = set()
    cur = subj
    while cur != RDF.nil:
        if cur in visited:
            raise ValueError('bad list')
        visited.add(cur)

        elem = graph.value(cur, RDF.first)
        if elem is None:
            raise ValueError('bad list')
        out.append(elem)
        used.add((cur, RDF.first, elem))

        rest = graph.value(cur, RDF.rest)
        if rest is None:
            raise ValueError('bad list')
        used.add((cur, RDF.rest, rest))

        cur = rest
    return out, used
1637 | 52 |
53 | |
# Every Function subclass decorated with @register, in decoration order.
registeredFunctionTypes: List[Type['Function']] = []


def register(cls: Type['Function']):
    """Class decorator: record *cls* in registeredFunctionTypes and hand it back unchanged."""
    registeredFunctionTypes.extend((cls,))
    return cls
60 | |
61 | |
class Function:
    """any rule stmt that runs a function (not just a statement match)

    Subclasses set the class attribute ``pred`` to the predicate they handle
    and override getOperandNodes() and bind(). Subclasses that consume extra
    triples from the rule graph also override usedStatements().
    """
    pred: URIRef  # predicate this function type handles; set by each subclass

    def __init__(self, stmt: Triple, ruleGraph: Graph):
        """Wrap rule statement *stmt* taken from *ruleGraph*.

        Raises:
            TypeError: if the predicate of *stmt* is not ``self.pred``.
        """
        self.stmt = stmt
        if stmt[1] != self.pred:
            raise TypeError(f'{type(self).__name__} handles {self.pred!r}, got predicate {stmt[1]!r}')
        self.ruleGraph = ruleGraph

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        """Return this function's input terms with *existingBinding* applied."""
        raise NotImplementedError

    def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]:
        """Return getOperandNodes() converted to Python numbers.

        Raises:
            TypeError: (from numericNode) if an operand is not a numeric Literal.
        """
        return [numericNode(op) for op in self.getOperandNodes(existingBinding)]

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        """either any new bindings this function makes (could be 0), or None if it doesn't match"""
        raise NotImplementedError

    def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]:
        """Return a binding of this stmt's object variable to *value*.

        Raises:
            TypeError: if the object term of ``self.stmt`` is not a Variable.
        """
        objVar = self.stmt[2]
        if not isinstance(objVar, Variable):
            raise TypeError(f'expected Variable, got {objVar!r}')
        return CandidateBinding({cast(BindableTerm, objVar): value})

    def usedStatements(self) -> Set[Triple]:
        '''stmts in self.ruleGraph (not including self.stmt itself) that are
        part of this function setup and aren't to be matched literally.

        The base implementation uses no extra statements.'''
        return set()
96 | |
1637 | 97 |
class SubjectFunction(Function):
    """function that depends only on the subject term"""

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        """Single operand: the statement's subject, resolved through the binding."""
        subject = self.stmt[0]
        resolved = existingBinding.applyTerm(subject)
        return [resolved]
103 | |
104 | |
class SubjectObjectFunction(Function):
    """a filter function that depends on the subject and object terms"""

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        """Two operands, in order: subject then object, each resolved through the binding."""
        subjectTerm, objectTerm = self.stmt[0], self.stmt[2]
        resolvedSubject = existingBinding.applyTerm(subjectTerm)
        resolvedObject = existingBinding.applyTerm(objectTerm)
        return [resolvedSubject, resolvedObject]
110 | |
111 | |
class ListFunction(Function):
    """function that takes an rdf list as input"""

    def usedStatements(self) -> Set[Triple]:
        """The rdf:first/rdf:rest triples that make up the subject's list."""
        return parseList(self.ruleGraph, self.stmt[0])[1]

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        """The members of the subject's rdf list, each resolved through the binding."""
        members = parseList(self.ruleGraph, self.stmt[0])[0]
        return list(map(existingBinding.applyTerm, members))
122 | |
123 | |
@register
class Gt(SubjectObjectFunction):
    """math:greaterThan filter: matches only when subject > object."""
    pred = MATH['greaterThan']

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        x, y = self.getNumericOperands(existingBinding)
        if not x > y:
            return None
        # no new values; just allow matching to keep going
        return CandidateBinding({})
132 | |
133 | |
@register
class AsFarenheit(SubjectFunction):
    """room:asFarenheit: binds the object variable to the subject's number
    converted with F = C * 9/5 + 32, as a Decimal-valued Literal."""
    pred = ROOM['asFarenheit']

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        [x] = self.getNumericOperands(existingBinding)
        # Decimal(float) copies the float's exact binary value
        # (21.3 -> 21.300000000000000710542735760...), and those spurious
        # digits leak into the result. Going through str() keeps the float's
        # shortest repr instead. ints and Decimals convert exactly either way.
        celsius = Decimal(str(x)) if isinstance(x, float) else Decimal(x)
        f = Literal(celsius * 9 / 5 + 32)
        return self.valueInObjectTerm(f)
142 | |
143 | |
@register
class Sum(ListFunction):
    """math:sum: binds the object variable to the total of the list's numbers."""
    pred = MATH['sum']

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        operands = self.getNumericOperands(existingBinding)
        total = sum(operands)
        return self.valueInObjectTerm(Literal(total))
151 | |
### registration is done: every @register above has run, so this table
### captures all the Function subclasses registered so far.
_byPred: Dict[URIRef, Type[Function]] = {cls.pred: cls for cls in registeredFunctionTypes}


def functionsFor(pred: URIRef) -> Iterator[Type[Function]]:
    """Yield the Function subclass registered for *pred*, if there is one."""
    found = _byPred.get(pred)
    if found is not None:
        yield found
160 | |
1637 | 161 |
def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]:
    """Collect the statements in *graph* that function statements claim via
    usedStatements(); callers don't need to worry about matching these."""
    return {
        usedStmt
        for stmt in graph
        for funcType in functionsFor(pred=stmt[1])
        for usedStmt in funcType(stmt, graph).usedStatements()
    }
1640 | 168 |
169 | |
def rulePredicates() -> Set[URIRef]:
    """Return the predicate of every registered Function subclass."""
    return {funcType.pred for funcType in registeredFunctionTypes}