Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1637:ec3f98d0c1d8
refactor rules eval
author | drewp@bigasterisk.com |
---|---|
date | Mon, 13 Sep 2021 01:36:06 -0700 |
parents | 3252bdc284bc |
children | 0ba1625037ae |
comparison
equal
deleted
inserted
replaced
1636:3252bdc284bc | 1637:ec3f98d0c1d8 |
---|---|
10 from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast) | 10 from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast) |
11 | 11 |
12 from prometheus_client import Histogram, Summary | 12 from prometheus_client import Histogram, Summary |
13 from rdflib import RDF, BNode, Graph, Namespace | 13 from rdflib import RDF, BNode, Graph, Namespace |
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate | 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate |
15 from rdflib.term import Literal, Node, Variable | 15 from rdflib.term import Node, Variable |
16 | 16 |
17 from candidate_binding import CandidateBinding | 17 from candidate_binding import CandidateBinding |
18 from inference_types import (BindableTerm, BindingUnknown, ReadOnlyWorkingSet, Triple) | 18 from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple |
19 from lhs_evaluation import Decimal, numericNode, parseList | 19 from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs |
20 | 20 |
21 log = logging.getLogger('infer') | 21 log = logging.getLogger('infer') |
22 INDENT = ' ' | 22 INDENT = ' ' |
23 | 23 |
24 INFER_CALLS = Summary('inference_infer_calls', 'calls') | 24 INFER_CALLS = Summary('inference_infer_calls', 'calls') |
102 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}') | 102 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}') |
103 | 103 |
104 if self._advanceWithPlainMatches(augmentedWorkingSet): | 104 if self._advanceWithPlainMatches(augmentedWorkingSet): |
105 return | 105 return |
106 | 106 |
107 if self._advanceWithBoolRules(): | |
108 return | |
109 | |
110 curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) | 107 curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) |
111 [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) | 108 [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) |
112 | 109 |
113 fullWorkingSet = self.workingSet + self.parent.graph | 110 fullWorkingSet = self.workingSet + self.parent.graph |
114 boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False)) | 111 boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False)) |
123 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool: | 120 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool: |
124 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements') | 121 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements') |
125 for s in augmentedWorkingSet: | 122 for s in augmentedWorkingSet: |
126 log.debug(f'{INDENT*7} {s}') | 123 log.debug(f'{INDENT*7} {s}') |
127 | 124 |
128 for i, stmt in enumerate(augmentedWorkingSet): | 125 for stmt in augmentedWorkingSet: |
129 try: | 126 try: |
130 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) | 127 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) |
131 except Inconsistent: | 128 except Inconsistent: |
132 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') | 129 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') |
133 continue | 130 continue |
138 self._current = outBinding | 135 self._current = outBinding |
139 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') | 136 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') |
140 return True | 137 return True |
141 return False | 138 return False |
142 | 139 |
143 def _advanceWithBoolRules(self) -> bool: | 140 def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool: |
144 log.debug(f'{INDENT*7} {self} mines bool rules') | 141 pred: Node = self.lhsStmt[1] |
145 if self.lhsStmt[1] == MATH['greaterThan']: | 142 |
146 operands = [self.lhsStmt[0], self.lhsStmt[2]] | 143 for functionType in functionsFor(pred): |
144 fn = functionType(self.lhsStmt, self.parent.graph) | |
147 try: | 145 try: |
148 boundOperands = self._boundOperands(operands) | 146 out = fn.bind(self._prevBindings()) |
149 except BindingUnknown: | 147 except BindingUnknown: |
150 return False | |
151 if numericNode(boundOperands[0]) > numericNode(boundOperands[1]): | |
152 binding: CandidateBinding = self._prevBindings().copy() # no new values; just allow matching to keep going | |
153 if binding not in self._seenBindings: | |
154 self._seenBindings.append(binding) | |
155 self._current = binding | |
156 log.debug(f'{INDENT*7} new binding from {self} -> {binding}') | |
157 return True | |
158 return False | |
159 | |
160 def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool: | |
161 log.debug(f'{INDENT*7} {self} mines rules') | |
162 | |
163 if self.lhsStmt[1] == ROOM['asFarenheit']: | |
164 pb: CandidateBinding = self._prevBindings() | |
165 log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') | |
166 | |
167 if isinstance(self.lhsStmt[0], (Variable, BNode)) and pb.contains(self.lhsStmt[0]): | |
168 operands = [pb.applyTerm(self.lhsStmt[0])] | |
169 f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) | |
170 objVar = self.lhsStmt[2] | |
171 if not isinstance(objVar, Variable): | |
172 raise TypeError(f'expected Variable, got {objVar!r}') | |
173 newBindings = CandidateBinding({cast(BindableTerm, objVar): cast(Node, f)}) | |
174 self._current.addNewBindings(newBindings) | |
175 if newBindings not in self._seenBindings: | |
176 self._seenBindings.append(newBindings) | |
177 self._current = newBindings | |
178 return True | |
179 elif self.lhsStmt[1] == MATH['sum']: | |
180 | |
181 g = Graph() | |
182 for s in boundFullWorkingSet: | |
183 g.add(s) | |
184 log.debug(f' boundWorkingSet graph: {s}') | |
185 log.debug(f'_parseList subj = {lhsStmtBound[0]}') | |
186 operands, _ = parseList(g, lhsStmtBound[0]) | |
187 log.debug(f'********* {INDENT*7} {self} found list {operands=}') | |
188 try: | |
189 obj = Literal(sum(map(numericNode, operands))) | |
190 except TypeError: | |
191 log.debug('typeerr in operands') | |
192 pass | 148 pass |
193 else: | 149 else: |
194 objVar = lhsStmtBound[2] | 150 if out is not None: |
195 log.debug(f'{objVar=}') | 151 binding: CandidateBinding = self._prevBindings().copy() |
196 | 152 binding.addNewBindings(out) |
197 if not isinstance(objVar, Variable): | 153 if binding not in self._seenBindings: |
198 raise TypeError(f'expected Variable, got {objVar!r}') | 154 self._seenBindings.append(binding) |
199 newBindings = CandidateBinding({objVar: obj}) | 155 self._current = binding |
200 log.debug(f'{newBindings=}') | 156 log.debug(f'{INDENT*7} new binding from {self} -> {binding}') |
201 | 157 return True |
202 self._current.addNewBindings(newBindings) | |
203 log.debug(f'{self._seenBindings=}') | |
204 if newBindings not in self._seenBindings: | |
205 self._seenBindings.append(newBindings) | |
206 self._current = newBindings | |
207 return True | |
208 | 158 |
209 return False | 159 return False |
210 | 160 |
211 def _boundOperands(self, operands) -> List[Node]: | 161 def _boundOperands(self, operands) -> List[Node]: |
212 pb: CandidateBinding = self._prevBindings() | 162 pb: CandidateBinding = self._prevBindings() |
300 | 250 |
301 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: | 251 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: |
302 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all | 252 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all |
303 start out valid (or else raise NoOptions)""" | 253 start out valid (or else raise NoOptions)""" |
304 | 254 |
305 usedByFuncs: Set[Triple] = set() # don't worry about matching these | 255 usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) |
306 stmtsToResolve = list(self.graph) | 256 |
307 for i, s in enumerate(stmtsToResolve): | 257 stmtsToAdd = list(self.graph - usedByFuncs) |
308 if s[1] == MATH['sum']: | |
309 _, used = parseList(self.graph, s[0]) | |
310 usedByFuncs.update(used) | |
311 | |
312 stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs] | |
313 | 258 |
314 # sort them by variable dependencies; don't just try all perms! | 259 # sort them by variable dependencies; don't just try all perms! |
315 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. | 260 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. |
316 (s, p, o) = stmt | 261 (s, p, o) = stmt |
317 return p == MATH['sum'], p, s, o | 262 return p == MATH['sum'], p, s, o |
476 log.debug(f'{INDENT*4} {stmt}') | 421 log.debug(f'{INDENT*4} {stmt}') |
477 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') | 422 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') |
478 | 423 |
479 | 424 |
480 def graphDump(g: Union[Graph, List[Triple]]): | 425 def graphDump(g: Union[Graph, List[Triple]]): |
426 # this is very slow- debug only! | |
427 if not log.isEnabledFor(logging.DEBUG): | |
428 return "(skipped dump)" | |
481 if not isinstance(g, Graph): | 429 if not isinstance(g, Graph): |
482 g2 = Graph() | 430 g2 = Graph() |
483 g2 += g | 431 g2 += g |
484 g = g2 | 432 g = g2 |
485 g.bind('', ROOM) | 433 g.bind('', ROOM) |