Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1634:ba59cfc3c747
hack math:sum in there. Test suite is passing except some slow performers
author | drewp@bigasterisk.com |
---|---|
date | Sun, 12 Sep 2021 23:48:43 -0700 |
parents | 6107603ed455 |
children | 22d481f0a924 |
comparison
equal
deleted
inserted
replaced
1633:6107603ed455 | 1634:ba59cfc3c747 |
---|---|
5 import itertools | 5 import itertools |
6 import logging | 6 import logging |
7 import time | 7 import time |
8 from collections import defaultdict | 8 from collections import defaultdict |
9 from dataclasses import dataclass | 9 from dataclasses import dataclass |
10 from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast | 10 from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast) |
11 | 11 |
12 from prometheus_client import Histogram, Summary | 12 from prometheus_client import Histogram, Summary |
13 from rdflib import BNode, Graph, Namespace | 13 from rdflib import RDF, BNode, Graph, Namespace |
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate | 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate |
15 from rdflib.term import Literal, Node, Variable | 15 from rdflib.term import Literal, Node, Variable |
16 | 16 |
17 from candidate_binding import CandidateBinding | 17 from candidate_binding import CandidateBinding |
18 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) | 18 from inference_types import (BindableTerm, BindingUnknown, EvaluationFailed, ReadOnlyWorkingSet, Triple) |
19 from lhs_evaluation import Decimal, Evaluation, numericNode | 19 from lhs_evaluation import Decimal, Evaluation, numericNode, parseList |
20 | 20 |
21 log = logging.getLogger('infer') | 21 log = logging.getLogger('infer') |
22 INDENT = ' ' | 22 INDENT = ' ' |
23 | 23 |
24 INFER_CALLS = Summary('inference_infer_calls', 'calls') | 24 INFER_CALLS = Summary('inference_infer_calls', 'calls') |
56 | 56 |
57 This iterator is restartable.""" | 57 This iterator is restartable.""" |
58 lhsStmt: Triple | 58 lhsStmt: Triple |
59 prev: Optional['StmtLooper'] | 59 prev: Optional['StmtLooper'] |
60 workingSet: ReadOnlyWorkingSet | 60 workingSet: ReadOnlyWorkingSet |
61 parent: 'Lhs' # just for lhs.graph, really | |
61 | 62 |
62 def __repr__(self): | 63 def __repr__(self): |
63 return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' | 64 return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' |
64 | 65 |
65 def __post_init__(self): | 66 def __post_init__(self): |
96 augmentedWorkingSet = self._myWorkingSetMatches | 97 augmentedWorkingSet = self._myWorkingSetMatches |
97 else: | 98 else: |
98 augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches, | 99 augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches, |
99 returnBoundStatementsOnly=False)) | 100 returnBoundStatementsOnly=False)) |
100 | 101 |
101 log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}') | 102 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}') |
102 | 103 |
103 log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented statements') | 104 if self._advanceWithPlainMatches(augmentedWorkingSet): |
105 return | |
106 | |
107 if self._advanceWithBoolRules(): | |
108 return | |
109 | |
110 curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) | |
111 [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) | |
112 | |
113 fullWorkingSet = self.workingSet + self.parent.graph | |
114 boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False)) | |
115 log.debug(f'{fullWorkingSet.__len__()=} {len(boundFullWorkingSet)=}') | |
116 | |
117 if self._advanceWithFunctions(augmentedWorkingSet, boundFullWorkingSet, lhsStmtBound): | |
118 return | |
119 | |
120 log.debug(f'{INDENT*6} {self} is past end') | |
121 self._pastEnd = True | |
122 | |
123 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool: | |
124 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements') | |
104 for s in augmentedWorkingSet: | 125 for s in augmentedWorkingSet: |
105 log.debug(f'{INDENT*7} {s}') | 126 log.debug(f'{INDENT*7} {s}') |
106 | 127 |
107 for i, stmt in enumerate(augmentedWorkingSet): | 128 for i, stmt in enumerate(augmentedWorkingSet): |
108 try: | 129 try: |
109 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) | 130 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) |
110 except Inconsistent: | 131 except Inconsistent: |
111 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') | 132 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') |
112 continue | 133 continue |
113 | 134 |
114 log.debug(f'{INDENT*6} {outBinding=} {self._seenBindings=}') | 135 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') |
115 if outBinding.binding not in self._seenBindings: | 136 if outBinding.binding not in self._seenBindings: |
116 self._seenBindings.append(outBinding.binding.copy()) | 137 self._seenBindings.append(outBinding.binding.copy()) |
117 self._current = outBinding | 138 self._current = outBinding |
118 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') | 139 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') |
119 return | 140 return True |
120 log.debug(f'yes we saw') | 141 return False |
121 | 142 |
122 log.debug(f'{INDENT*6} {self} mines rules') | 143 def _advanceWithBoolRules(self) -> bool: |
144 log.debug(f'{INDENT*7} {self} mines bool rules') | |
145 if self.lhsStmt[1] == MATH['greaterThan']: | |
146 operands = [self.lhsStmt[0], self.lhsStmt[2]] | |
147 try: | |
148 boundOperands = self._boundOperands(operands) | |
149 except BindingUnknown: | |
150 return False | |
151 if numericNode(boundOperands[0]) > numericNode(boundOperands[1]): | |
152 bindingDict: Dict[BindableTerm, | |
153 Node] = self._prevBindings().copy() # no new values; just allow matching to keep going | |
154 if bindingDict not in self._seenBindings: | |
155 self._seenBindings.append(bindingDict) | |
156 self._current = CandidateBinding(bindingDict) | |
157 log.debug(f'{INDENT*7} new binding from {self} -> {bindingDict}') | |
158 return True | |
159 return False | |
160 | |
161 def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool: | |
162 log.debug(f'{INDENT*7} {self} mines rules') | |
123 | 163 |
124 if self.lhsStmt[1] == ROOM['asFarenheit']: | 164 if self.lhsStmt[1] == ROOM['asFarenheit']: |
125 pb: Dict[BindableTerm, Node] = self._prevBindings() | 165 pb: Dict[BindableTerm, Node] = self._prevBindings() |
126 log.debug(f'{INDENT*6} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') | 166 log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') |
127 | 167 |
128 if self.lhsStmt[0] in pb: | 168 if self.lhsStmt[0] in pb: |
129 operands = [pb[cast(BindableTerm, self.lhsStmt[0])]] | 169 operands = [pb[cast(BindableTerm, self.lhsStmt[0])]] |
130 f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) | 170 f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) |
131 objVar = self.lhsStmt[2] | 171 objVar = self.lhsStmt[2] |
134 newBindings = {cast(BindableTerm, objVar): cast(Node, f)} | 174 newBindings = {cast(BindableTerm, objVar): cast(Node, f)} |
135 self._current.addNewBindings(CandidateBinding(newBindings)) | 175 self._current.addNewBindings(CandidateBinding(newBindings)) |
136 if newBindings not in self._seenBindings: | 176 if newBindings not in self._seenBindings: |
137 self._seenBindings.append(newBindings) | 177 self._seenBindings.append(newBindings) |
138 self._current = CandidateBinding(newBindings) | 178 self._current = CandidateBinding(newBindings) |
139 return | 179 return True |
140 | 180 elif self.lhsStmt[1] == MATH['sum']: |
141 log.debug(f'{INDENT*6} {self} is past end') | 181 |
142 self._pastEnd = True | 182 g = Graph() |
183 for s in boundFullWorkingSet: | |
184 g.add(s) | |
185 log.debug(f' boundWorkingSet graph: {s}') | |
186 log.debug(f'_parseList subj = {lhsStmtBound[0]}') | |
187 operands, _ = parseList(g, lhsStmtBound[0]) | |
188 log.debug(f'********* {INDENT*7} {self} found list {operands=}') | |
189 try: | |
190 obj = Literal(sum(map(numericNode, operands))) | |
191 except TypeError: | |
192 log.debug('typeerr in operands') | |
193 pass | |
194 else: | |
195 objVar = lhsStmtBound[2] | |
196 log.debug(f'{objVar=}') | |
197 | |
198 if not isinstance(objVar, Variable): | |
199 raise TypeError(f'expected Variable, got {objVar!r}') | |
200 newBindings: Dict[BindableTerm, Node] = {objVar: obj} | |
201 log.debug(f'{newBindings=}') | |
202 | |
203 self._current.addNewBindings(CandidateBinding(newBindings)) | |
204 log.debug(f'{self._seenBindings=}') | |
205 if newBindings not in self._seenBindings: | |
206 self._seenBindings.append(newBindings) | |
207 self._current = CandidateBinding(newBindings) | |
208 return True | |
209 | |
210 return False | |
211 | |
212 def _boundOperands(self, operands) -> List[Node]: | |
213 pb: Dict[BindableTerm, Node] = self._prevBindings() | |
214 | |
215 boundOperands: List[Node] = [] | |
216 for op in operands: | |
217 if isinstance(op, (Variable, BNode)): | |
218 if op in pb: | |
219 boundOperands.append(pb[op]) | |
220 else: | |
221 raise BindingUnknown() | |
222 else: | |
223 boundOperands.append(op) | |
224 return boundOperands | |
143 | 225 |
144 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: | 226 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: |
145 outBinding = self._prevBindings().copy() | 227 outBinding = self._prevBindings().copy() |
146 for rt, ct in zip(self.lhsStmt, newStmt): | 228 for rt, ct in zip(self.lhsStmt, newStmt): |
147 if isinstance(rt, (Variable, BNode)): | 229 if isinstance(rt, (Variable, BNode)): |
148 if rt in outBinding and outBinding[rt] != ct: | 230 if rt in outBinding and outBinding[rt] != ct: |
149 raise Inconsistent() | 231 raise Inconsistent(f'{rt=} {ct=} {outBinding=}') |
150 outBinding[rt] = ct | 232 outBinding[rt] = ct |
151 return CandidateBinding(outBinding) | 233 return CandidateBinding(outBinding) |
152 | 234 |
153 def currentBinding(self) -> CandidateBinding: | 235 def currentBinding(self) -> CandidateBinding: |
154 if self.pastEnd(): | 236 if self.pastEnd(): |
243 | 325 |
244 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: | 326 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: |
245 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all | 327 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all |
246 start out valid (or else raise NoOptions)""" | 328 start out valid (or else raise NoOptions)""" |
247 | 329 |
248 stmtsToAdd = list(self.graph) | 330 usedByFuncs: Set[Triple] = set() # don't worry about matching these |
331 stmtsToResolve = list(self.graph) | |
332 for i, s in enumerate(stmtsToResolve): | |
333 if s[1] == MATH['sum']: | |
334 _, used = parseList(self.graph, s[0]) | |
335 usedByFuncs.update(used) | |
336 | |
337 stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs] | |
338 | |
339 # sort them by variable dependencies; don't just try all perms! | |
340 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. | |
341 (s, p, o) = stmt | |
342 return p == MATH['sum'], p, s, o | |
343 | |
344 stmtsToAdd.sort(key=lightSortKey) | |
249 | 345 |
250 for perm in itertools.permutations(stmtsToAdd): | 346 for perm in itertools.permutations(stmtsToAdd): |
251 stmtStack: List[StmtLooper] = [] | 347 stmtStack: List[StmtLooper] = [] |
252 prev: Optional[StmtLooper] = None | 348 prev: Optional[StmtLooper] = None |
253 log.debug(f'{INDENT*5} try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') | 349 log.debug(f'{INDENT*5} try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') |
254 | 350 |
255 for s in perm: | 351 for s in perm: |
256 try: | 352 try: |
257 elem = StmtLooper(s, prev, knownTrue) | 353 elem = StmtLooper(s, prev, knownTrue, parent=self) |
258 except NoOptions: | 354 except NoOptions: |
259 log.debug(f'{INDENT*6} permutation didnt work, try another') | 355 log.debug(f'{INDENT*6} permutation didnt work, try another') |
260 break | 356 break |
261 stmtStack.append(elem) | 357 stmtStack.append(elem) |
262 prev = stmtStack[-1] | 358 prev = stmtStack[-1] |
538 log.debug(f'{INDENT*3} ({j}) {stmt}') | 634 log.debug(f'{INDENT*3} ({j}) {stmt}') |
539 | 635 |
540 log.debug('') | 636 log.debug('') |
541 log.debug(f'{INDENT*2}-applying rule {i}') | 637 log.debug(f'{INDENT*2}-applying rule {i}') |
542 log.debug(f'{INDENT*3} rule def lhs:') | 638 log.debug(f'{INDENT*3} rule def lhs:') |
543 for stmt in r.lhsGraph: | 639 for stmt in sorted(r.lhsGraph, reverse=True): |
544 log.debug(f'{INDENT*4} {stmt}') | 640 log.debug(f'{INDENT*4} {stmt}') |
545 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') | 641 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') |
546 | 642 |
547 | 643 |
548 def graphDump(g: Union[Graph, List[Triple]]): | 644 def graphDump(g: Union[Graph, List[Triple]]): |