comparison service/mqtt_to_rdf/inference.py @ 1637:ec3f98d0c1d8

refactor rules eval
author drewp@bigasterisk.com
date Mon, 13 Sep 2021 01:36:06 -0700
parents 3252bdc284bc
children 0ba1625037ae
comparison
equal deleted inserted replaced
1636:3252bdc284bc 1637:ec3f98d0c1d8
10 from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast) 10 from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast)
11 11
12 from prometheus_client import Histogram, Summary 12 from prometheus_client import Histogram, Summary
13 from rdflib import RDF, BNode, Graph, Namespace 13 from rdflib import RDF, BNode, Graph, Namespace
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
15 from rdflib.term import Literal, Node, Variable 15 from rdflib.term import Node, Variable
16 16
17 from candidate_binding import CandidateBinding 17 from candidate_binding import CandidateBinding
18 from inference_types import (BindableTerm, BindingUnknown, ReadOnlyWorkingSet, Triple) 18 from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple
19 from lhs_evaluation import Decimal, numericNode, parseList 19 from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs
20 20
21 log = logging.getLogger('infer') 21 log = logging.getLogger('infer')
22 INDENT = ' ' 22 INDENT = ' '
23 23
24 INFER_CALLS = Summary('inference_infer_calls', 'calls') 24 INFER_CALLS = Summary('inference_infer_calls', 'calls')
102 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}') 102 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}')
103 103
104 if self._advanceWithPlainMatches(augmentedWorkingSet): 104 if self._advanceWithPlainMatches(augmentedWorkingSet):
105 return 105 return
106 106
107 if self._advanceWithBoolRules():
108 return
109
110 curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) 107 curBind = self.prev.currentBinding() if self.prev else CandidateBinding({})
111 [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) 108 [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False)
112 109
113 fullWorkingSet = self.workingSet + self.parent.graph 110 fullWorkingSet = self.workingSet + self.parent.graph
114 boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False)) 111 boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False))
123 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool: 120 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool:
124 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements') 121 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
125 for s in augmentedWorkingSet: 122 for s in augmentedWorkingSet:
126 log.debug(f'{INDENT*7} {s}') 123 log.debug(f'{INDENT*7} {s}')
127 124
128 for i, stmt in enumerate(augmentedWorkingSet): 125 for stmt in augmentedWorkingSet:
129 try: 126 try:
130 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) 127 outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
131 except Inconsistent: 128 except Inconsistent:
132 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') 129 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
133 continue 130 continue
138 self._current = outBinding 135 self._current = outBinding
139 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') 136 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}')
140 return True 137 return True
141 return False 138 return False
142 139
143 def _advanceWithBoolRules(self) -> bool: 140 def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool:
144 log.debug(f'{INDENT*7} {self} mines bool rules') 141 pred: Node = self.lhsStmt[1]
145 if self.lhsStmt[1] == MATH['greaterThan']: 142
146 operands = [self.lhsStmt[0], self.lhsStmt[2]] 143 for functionType in functionsFor(pred):
144 fn = functionType(self.lhsStmt, self.parent.graph)
147 try: 145 try:
148 boundOperands = self._boundOperands(operands) 146 out = fn.bind(self._prevBindings())
149 except BindingUnknown: 147 except BindingUnknown:
150 return False
151 if numericNode(boundOperands[0]) > numericNode(boundOperands[1]):
152 binding: CandidateBinding = self._prevBindings().copy() # no new values; just allow matching to keep going
153 if binding not in self._seenBindings:
154 self._seenBindings.append(binding)
155 self._current = binding
156 log.debug(f'{INDENT*7} new binding from {self} -> {binding}')
157 return True
158 return False
159
160 def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool:
161 log.debug(f'{INDENT*7} {self} mines rules')
162
163 if self.lhsStmt[1] == ROOM['asFarenheit']:
164 pb: CandidateBinding = self._prevBindings()
165 log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')
166
167 if isinstance(self.lhsStmt[0], (Variable, BNode)) and pb.contains(self.lhsStmt[0]):
168 operands = [pb.applyTerm(self.lhsStmt[0])]
169 f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32))
170 objVar = self.lhsStmt[2]
171 if not isinstance(objVar, Variable):
172 raise TypeError(f'expected Variable, got {objVar!r}')
173 newBindings = CandidateBinding({cast(BindableTerm, objVar): cast(Node, f)})
174 self._current.addNewBindings(newBindings)
175 if newBindings not in self._seenBindings:
176 self._seenBindings.append(newBindings)
177 self._current = newBindings
178 return True
179 elif self.lhsStmt[1] == MATH['sum']:
180
181 g = Graph()
182 for s in boundFullWorkingSet:
183 g.add(s)
184 log.debug(f' boundWorkingSet graph: {s}')
185 log.debug(f'_parseList subj = {lhsStmtBound[0]}')
186 operands, _ = parseList(g, lhsStmtBound[0])
187 log.debug(f'********* {INDENT*7} {self} found list {operands=}')
188 try:
189 obj = Literal(sum(map(numericNode, operands)))
190 except TypeError:
191 log.debug('typeerr in operands')
192 pass 148 pass
193 else: 149 else:
194 objVar = lhsStmtBound[2] 150 if out is not None:
195 log.debug(f'{objVar=}') 151 binding: CandidateBinding = self._prevBindings().copy()
196 152 binding.addNewBindings(out)
197 if not isinstance(objVar, Variable): 153 if binding not in self._seenBindings:
198 raise TypeError(f'expected Variable, got {objVar!r}') 154 self._seenBindings.append(binding)
199 newBindings = CandidateBinding({objVar: obj}) 155 self._current = binding
200 log.debug(f'{newBindings=}') 156 log.debug(f'{INDENT*7} new binding from {self} -> {binding}')
201 157 return True
202 self._current.addNewBindings(newBindings)
203 log.debug(f'{self._seenBindings=}')
204 if newBindings not in self._seenBindings:
205 self._seenBindings.append(newBindings)
206 self._current = newBindings
207 return True
208 158
209 return False 159 return False
210 160
211 def _boundOperands(self, operands) -> List[Node]: 161 def _boundOperands(self, operands) -> List[Node]:
212 pb: CandidateBinding = self._prevBindings() 162 pb: CandidateBinding = self._prevBindings()
300 250
301 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: 251 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]:
302 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all 252 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
303 start out valid (or else raise NoOptions)""" 253 start out valid (or else raise NoOptions)"""
304 254
305 usedByFuncs: Set[Triple] = set() # don't worry about matching these 255 usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
306 stmtsToResolve = list(self.graph) 256
307 for i, s in enumerate(stmtsToResolve): 257 stmtsToAdd = list(self.graph - usedByFuncs)
308 if s[1] == MATH['sum']:
309 _, used = parseList(self.graph, s[0])
310 usedByFuncs.update(used)
311
312 stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs]
313 258
314 # sort them by variable dependencies; don't just try all perms! 259 # sort them by variable dependencies; don't just try all perms!
315 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. 260 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases.
316 (s, p, o) = stmt 261 (s, p, o) = stmt
317 return p == MATH['sum'], p, s, o 262 return p == MATH['sum'], p, s, o
476 log.debug(f'{INDENT*4} {stmt}') 421 log.debug(f'{INDENT*4} {stmt}')
477 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') 422 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
478 423
479 424
480 def graphDump(g: Union[Graph, List[Triple]]): 425 def graphDump(g: Union[Graph, List[Triple]]):
426 # this is very slow- debug only!
427 if not log.isEnabledFor(logging.DEBUG):
428 return "(skipped dump)"
481 if not isinstance(g, Graph): 429 if not isinstance(g, Graph):
482 g2 = Graph() 430 g2 = Graph()
483 g2 += g 431 g2 += g
484 g = g2 432 g = g2
485 g.bind('', ROOM) 433 g.bind('', ROOM)