comparison service/mqtt_to_rdf/inference.py @ 1634:ba59cfc3c747

hack math:sum in there. Test suite is passing except some slow performers
author drewp@bigasterisk.com
date Sun, 12 Sep 2021 23:48:43 -0700
parents 6107603ed455
children 22d481f0a924
comparison
equal deleted inserted replaced
1633:6107603ed455 1634:ba59cfc3c747
5 import itertools 5 import itertools
6 import logging 6 import logging
7 import time 7 import time
8 from collections import defaultdict 8 from collections import defaultdict
9 from dataclasses import dataclass 9 from dataclasses import dataclass
10 from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast 10 from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast)
11 11
12 from prometheus_client import Histogram, Summary 12 from prometheus_client import Histogram, Summary
13 from rdflib import BNode, Graph, Namespace 13 from rdflib import RDF, BNode, Graph, Namespace
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
15 from rdflib.term import Literal, Node, Variable 15 from rdflib.term import Literal, Node, Variable
16 16
17 from candidate_binding import CandidateBinding 17 from candidate_binding import CandidateBinding
18 from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) 18 from inference_types import (BindableTerm, BindingUnknown, EvaluationFailed, ReadOnlyWorkingSet, Triple)
19 from lhs_evaluation import Decimal, Evaluation, numericNode 19 from lhs_evaluation import Decimal, Evaluation, numericNode, parseList
20 20
21 log = logging.getLogger('infer') 21 log = logging.getLogger('infer')
22 INDENT = ' ' 22 INDENT = ' '
23 23
24 INFER_CALLS = Summary('inference_infer_calls', 'calls') 24 INFER_CALLS = Summary('inference_infer_calls', 'calls')
56 56
57 This iterator is restartable.""" 57 This iterator is restartable."""
58 lhsStmt: Triple 58 lhsStmt: Triple
59 prev: Optional['StmtLooper'] 59 prev: Optional['StmtLooper']
60 workingSet: ReadOnlyWorkingSet 60 workingSet: ReadOnlyWorkingSet
61 parent: 'Lhs' # just for lhs.graph, really
61 62
62 def __repr__(self): 63 def __repr__(self):
63 return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' 64 return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'
64 65
65 def __post_init__(self): 66 def __post_init__(self):
96 augmentedWorkingSet = self._myWorkingSetMatches 97 augmentedWorkingSet = self._myWorkingSetMatches
97 else: 98 else:
98 augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches, 99 augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches,
99 returnBoundStatementsOnly=False)) 100 returnBoundStatementsOnly=False))
100 101
101 log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}') 102 log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}')
102 103
103 log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented statements') 104 if self._advanceWithPlainMatches(augmentedWorkingSet):
105 return
106
107 if self._advanceWithBoolRules():
108 return
109
110 curBind = self.prev.currentBinding() if self.prev else CandidateBinding({})
111 [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False)
112
113 fullWorkingSet = self.workingSet + self.parent.graph
114 boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False))
115 log.debug(f'{fullWorkingSet.__len__()=} {len(boundFullWorkingSet)=}')
116
117 if self._advanceWithFunctions(augmentedWorkingSet, boundFullWorkingSet, lhsStmtBound):
118 return
119
120 log.debug(f'{INDENT*6} {self} is past end')
121 self._pastEnd = True
122
123 def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool:
124 log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
104 for s in augmentedWorkingSet: 125 for s in augmentedWorkingSet:
105 log.debug(f'{INDENT*7} {s}') 126 log.debug(f'{INDENT*7} {s}')
106 127
107 for i, stmt in enumerate(augmentedWorkingSet): 128 for i, stmt in enumerate(augmentedWorkingSet):
108 try: 129 try:
109 outBinding = self._totalBindingIfThisStmtWereTrue(stmt) 130 outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
110 except Inconsistent: 131 except Inconsistent:
111 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') 132 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
112 continue 133 continue
113 134
114 log.debug(f'{INDENT*6} {outBinding=} {self._seenBindings=}') 135 log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
115 if outBinding.binding not in self._seenBindings: 136 if outBinding.binding not in self._seenBindings:
116 self._seenBindings.append(outBinding.binding.copy()) 137 self._seenBindings.append(outBinding.binding.copy())
117 self._current = outBinding 138 self._current = outBinding
118 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') 139 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}')
119 return 140 return True
120 log.debug(f'yes we saw') 141 return False
121 142
122 log.debug(f'{INDENT*6} {self} mines rules') 143 def _advanceWithBoolRules(self) -> bool:
144 log.debug(f'{INDENT*7} {self} mines bool rules')
145 if self.lhsStmt[1] == MATH['greaterThan']:
146 operands = [self.lhsStmt[0], self.lhsStmt[2]]
147 try:
148 boundOperands = self._boundOperands(operands)
149 except BindingUnknown:
150 return False
151 if numericNode(boundOperands[0]) > numericNode(boundOperands[1]):
152 bindingDict: Dict[BindableTerm,
153 Node] = self._prevBindings().copy() # no new values; just allow matching to keep going
154 if bindingDict not in self._seenBindings:
155 self._seenBindings.append(bindingDict)
156 self._current = CandidateBinding(bindingDict)
157 log.debug(f'{INDENT*7} new binding from {self} -> {bindingDict}')
158 return True
159 return False
160
161 def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool:
162 log.debug(f'{INDENT*7} {self} mines rules')
123 163
124 if self.lhsStmt[1] == ROOM['asFarenheit']: 164 if self.lhsStmt[1] == ROOM['asFarenheit']:
125 pb: Dict[BindableTerm, Node] = self._prevBindings() 165 pb: Dict[BindableTerm, Node] = self._prevBindings()
126 log.debug(f'{INDENT*6} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') 166 log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')
127 167
128 if self.lhsStmt[0] in pb: 168 if self.lhsStmt[0] in pb:
129 operands = [pb[cast(BindableTerm, self.lhsStmt[0])]] 169 operands = [pb[cast(BindableTerm, self.lhsStmt[0])]]
130 f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) 170 f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32))
131 objVar = self.lhsStmt[2] 171 objVar = self.lhsStmt[2]
134 newBindings = {cast(BindableTerm, objVar): cast(Node, f)} 174 newBindings = {cast(BindableTerm, objVar): cast(Node, f)}
135 self._current.addNewBindings(CandidateBinding(newBindings)) 175 self._current.addNewBindings(CandidateBinding(newBindings))
136 if newBindings not in self._seenBindings: 176 if newBindings not in self._seenBindings:
137 self._seenBindings.append(newBindings) 177 self._seenBindings.append(newBindings)
138 self._current = CandidateBinding(newBindings) 178 self._current = CandidateBinding(newBindings)
139 return 179 return True
140 180 elif self.lhsStmt[1] == MATH['sum']:
141 log.debug(f'{INDENT*6} {self} is past end') 181
142 self._pastEnd = True 182 g = Graph()
183 for s in boundFullWorkingSet:
184 g.add(s)
185 log.debug(f' boundWorkingSet graph: {s}')
186 log.debug(f'_parseList subj = {lhsStmtBound[0]}')
187 operands, _ = parseList(g, lhsStmtBound[0])
188 log.debug(f'********* {INDENT*7} {self} found list {operands=}')
189 try:
190 obj = Literal(sum(map(numericNode, operands)))
191 except TypeError:
192 log.debug('typeerr in operands')
193 pass
194 else:
195 objVar = lhsStmtBound[2]
196 log.debug(f'{objVar=}')
197
198 if not isinstance(objVar, Variable):
199 raise TypeError(f'expected Variable, got {objVar!r}')
200 newBindings: Dict[BindableTerm, Node] = {objVar: obj}
201 log.debug(f'{newBindings=}')
202
203 self._current.addNewBindings(CandidateBinding(newBindings))
204 log.debug(f'{self._seenBindings=}')
205 if newBindings not in self._seenBindings:
206 self._seenBindings.append(newBindings)
207 self._current = CandidateBinding(newBindings)
208 return True
209
210 return False
211
212 def _boundOperands(self, operands) -> List[Node]:
213 pb: Dict[BindableTerm, Node] = self._prevBindings()
214
215 boundOperands: List[Node] = []
216 for op in operands:
217 if isinstance(op, (Variable, BNode)):
218 if op in pb:
219 boundOperands.append(pb[op])
220 else:
221 raise BindingUnknown()
222 else:
223 boundOperands.append(op)
224 return boundOperands
143 225
144 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: 226 def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
145 outBinding = self._prevBindings().copy() 227 outBinding = self._prevBindings().copy()
146 for rt, ct in zip(self.lhsStmt, newStmt): 228 for rt, ct in zip(self.lhsStmt, newStmt):
147 if isinstance(rt, (Variable, BNode)): 229 if isinstance(rt, (Variable, BNode)):
148 if rt in outBinding and outBinding[rt] != ct: 230 if rt in outBinding and outBinding[rt] != ct:
149 raise Inconsistent() 231 raise Inconsistent(f'{rt=} {ct=} {outBinding=}')
150 outBinding[rt] = ct 232 outBinding[rt] = ct
151 return CandidateBinding(outBinding) 233 return CandidateBinding(outBinding)
152 234
153 def currentBinding(self) -> CandidateBinding: 235 def currentBinding(self) -> CandidateBinding:
154 if self.pastEnd(): 236 if self.pastEnd():
243 325
244 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]: 326 def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]:
245 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all 327 """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
246 start out valid (or else raise NoOptions)""" 328 start out valid (or else raise NoOptions)"""
247 329
248 stmtsToAdd = list(self.graph) 330 usedByFuncs: Set[Triple] = set() # don't worry about matching these
331 stmtsToResolve = list(self.graph)
332 for i, s in enumerate(stmtsToResolve):
333 if s[1] == MATH['sum']:
334 _, used = parseList(self.graph, s[0])
335 usedByFuncs.update(used)
336
337 stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs]
338
339 # sort them by variable dependencies; don't just try all perms!
340 def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases.
341 (s, p, o) = stmt
342 return p == MATH['sum'], p, s, o
343
344 stmtsToAdd.sort(key=lightSortKey)
249 345
250 for perm in itertools.permutations(stmtsToAdd): 346 for perm in itertools.permutations(stmtsToAdd):
251 stmtStack: List[StmtLooper] = [] 347 stmtStack: List[StmtLooper] = []
252 prev: Optional[StmtLooper] = None 348 prev: Optional[StmtLooper] = None
253 log.debug(f'{INDENT*5} try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}') 349 log.debug(f'{INDENT*5} try stmts in this order: {" -> ".join(graphDump([p]) for p in perm)}')
254 350
255 for s in perm: 351 for s in perm:
256 try: 352 try:
257 elem = StmtLooper(s, prev, knownTrue) 353 elem = StmtLooper(s, prev, knownTrue, parent=self)
258 except NoOptions: 354 except NoOptions:
259 log.debug(f'{INDENT*6} permutation didnt work, try another') 355 log.debug(f'{INDENT*6} permutation didnt work, try another')
260 break 356 break
261 stmtStack.append(elem) 357 stmtStack.append(elem)
262 prev = stmtStack[-1] 358 prev = stmtStack[-1]
538 log.debug(f'{INDENT*3} ({j}) {stmt}') 634 log.debug(f'{INDENT*3} ({j}) {stmt}')
539 635
540 log.debug('') 636 log.debug('')
541 log.debug(f'{INDENT*2}-applying rule {i}') 637 log.debug(f'{INDENT*2}-applying rule {i}')
542 log.debug(f'{INDENT*3} rule def lhs:') 638 log.debug(f'{INDENT*3} rule def lhs:')
543 for stmt in r.lhsGraph: 639 for stmt in sorted(r.lhsGraph, reverse=True):
544 log.debug(f'{INDENT*4} {stmt}') 640 log.debug(f'{INDENT*4} {stmt}')
545 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') 641 log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
546 642
547 643
548 def graphDump(g: Union[Graph, List[Triple]]): 644 def graphDump(g: Union[Graph, List[Triple]]):