comparison service/mqtt_to_rdf/inference.py @ 1588:0757fafbfdab

WIP inferencer - partial var and function support
author drewp@bigasterisk.com
date Thu, 02 Sep 2021 01:58:31 -0700
parents 9a3a18c494f9
children 5c1055be3c36
comparison
equal deleted inserted replaced
1587:9a3a18c494f9 1588:0757fafbfdab
1 """ 1 """
2 copied from reasoning 2021-08-29. probably same api. should 2 copied from reasoning 2021-08-29. probably same api. should
3 be able to lib/ this out 3 be able to lib/ this out
4 """ 4 """
5 5 import itertools
6 import logging 6 import logging
7 from typing import Dict, Tuple
8 from dataclasses import dataclass 7 from dataclasses import dataclass
8 from decimal import Decimal
9 from typing import Dict, Iterator, List, Set, Tuple, cast
10 from urllib.request import OpenerDirector
11
9 from prometheus_client import Summary 12 from prometheus_client import Summary
10 from rdflib import Graph, Namespace 13 from rdflib import BNode, Graph, Literal, Namespace
14 from rdflib.collection import Collection
11 from rdflib.graph import ConjunctiveGraph 15 from rdflib.graph import ConjunctiveGraph
12 from rdflib.term import Node, Variable 16 from rdflib.term import Node, Variable
13 17
14 log = logging.getLogger('infer') 18 log = logging.getLogger('infer')
15 19
55 bailout_iterations = 100 59 bailout_iterations = 100
56 delta = 1 60 delta = 1
57 while delta > 0 and bailout_iterations > 0: 61 while delta > 0 and bailout_iterations > 0:
58 bailout_iterations -= 1 62 bailout_iterations -= 1
59 delta = -len(implied) 63 delta = -len(implied)
60 self._iterateRules(workingSet, implied) 64 self._iterateAllRules(workingSet, implied)
61 delta += len(implied) 65 delta += len(implied)
62 log.info(f' this inference round added {delta} more implied stmts') 66 log.info(f' this inference round added {delta} more implied stmts')
63 log.info(f'{len(implied)} stmts implied:') 67 log.info(f'{len(implied)} stmts implied:')
64 for st in implied: 68 for st in implied:
65 log.info(f' {st}') 69 log.info(f' {st}')
66 return implied 70 return implied
67 71
def _iterateAllRules(self, workingSet, implied):
    """Run one pass of every rule in self.rules against workingSet.

    Each entry r is indexed as (lhs, predicate, rhs); entries whose predicate
    is log:implies are handed to applyRule, anything else is logged and
    skipped. New statements land in both workingSet and implied.
    """
    for r in self.rules:
        if r[1] != LOG['implies']:
            log.info(f' {r} not a rule?')
            continue
        applyRule(r[0], r[2], workingSet, implied)
74 78
def applyRule(lhs: Graph, rhs: Graph, workingSet, implied):
    """Fire the rule `lhs => rhs` against workingSet.

    For every binding produced by findCandidateBindings, the bound rhs
    statements are added to workingSet and then to implied. The additions
    happen while the binding generator is still running, so its later
    verification steps see the statements added here.
    """
    candidates = findCandidateBindings(lhs, workingSet)
    for bindings in candidates:
        log.debug(f' - rule gave {bindings=}')
        for triple in withBinding(rhs, bindings):
            # order matters only for consistency with the original: workingSet first
            for target in (workingSet, implied):
                target.add(triple)
def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]:
    """Yield each variable binding under which the rule body `lhs` holds in workingSet.

    1. lhs statements with no Variables and no BNodes must appear verbatim in
       workingSet; if any is missing, nothing is yielded.
    2. Every other lhs Variable must either get at least one candidate value
       from workingSet (see findCandidateTermMatches) or be the object of an
       inferredFuncs statement (computed by inferredFuncBindings). If some
       Variable has neither, no complete binding exists and nothing is yielded.
    3. Every combination of candidate values is tried; inferred-function
       outputs are filled in and the binding is yielded if verifyBinding
       accepts it.

    This is a generator. Candidate collection happens once, before the first
    yield, but verifyBinding runs per combination, so statements a caller adds
    to workingSet between yields are seen by later checks.
    """
    varsToBind: Set[Variable] = set()
    staticRuleStmts = []
    for ruleStmt in lhs:
        varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)]
        varsToBind.update(varsInStmt)
        hasBNode = any(isinstance(t, BNode) for t in ruleStmt)  # approx
        if not varsInStmt and not hasBNode:
            staticRuleStmts.append(ruleStmt)

    if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
        log.debug('static shortcircuit')
        return

    # the total set of terms each variable could possibly match
    candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet)

    # Variables an inferred function will compute (mirrors the conditions in
    # inferredFuncBindings); every other variable must come from workingSet.
    inferredVars = {
        stmt[2] for stmt in lhs if stmt[1] in inferredFuncs and isinstance(stmt[2], Variable)
    }
    # Without this, a variable with no candidates stays unbound: it either
    # makes inferredFuncBindings raise KeyError (when it is a function input)
    # or makes verifyBinding silently skip every lhs statement mentioning it
    # (withBinding drops statements with unbound Variables), so the rule
    # would fire without those statements being checked.
    missing = varsToBind.difference(inferredVars).difference(candidateTermMatches)
    if missing:
        log.debug(f'no candidates for {missing}')
        return

    orderedVars, orderedValueSets = organize(candidateTermMatches)

    log.debug(f' {orderedVars=}')
    log.debug(f'{orderedValueSets=}')

    for perm in itertools.product(*orderedValueSets):
        binding: Dict[Variable, Node] = dict(zip(orderedVars, perm))
        log.debug(f'{binding=} but lets look for funcs')
        for v, val in inferredFuncBindings(lhs, binding):  # loop this until it's done
            log.debug(f'ifb tells us {v}={val}')
            binding[v] = val
        if not verifyBinding(lhs, binding, workingSet):  # fix this
            log.debug(f'verify culls')
            continue
        yield binding
121
122
def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]:
    """Yield (variable, value) for each lhs statement computed by an inferred function.

    Only statements whose predicate is in inferredFuncs and whose object is a
    Variable are considered. A Variable subject is looked up in
    bindingsBefore (KeyError if it is unbound); the value comes from
    inferredFuncObject. bindingsBefore is read live, so if the caller stores
    each yielded value into it before resuming, later statements can use it.
    """
    for stmt in lhs:
        subj, pred, obj = stmt[0], stmt[1], stmt[2]
        if pred in inferredFuncs and isinstance(obj, Variable):
            resolvedSubj = bindingsBefore[subj] if isinstance(subj, Variable) else subj
            yield obj, inferredFuncObject(resolvedSubj, pred, lhs, bindingsBefore)
134
135
def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]:
    """Collect, per lhs Variable, every workingSet term it could stand for.

    Every lhs statement is compared with every workingSet statement. When all
    non-Variable positions are equal, the workingSet term at each Variable
    position is recorded as a candidate for that Variable. This
    over-approximates: it only shows that one statement can match, not that a
    whole binding is consistent. (A Variable repeated within one statement
    can even collect differing values from a single workingSet statement.)
    Consistency is left to later verification.
    """
    candidateTermMatches: Dict[Variable, Set[Node]] = {}
    for ruleStmt in lhs:
        for workStmt in workingSet:
            pairs = list(zip(ruleStmt, workStmt))
            mismatched = any(
                not isinstance(ruleTerm, Variable) and ruleTerm != workTerm
                for ruleTerm, workTerm in pairs)
            if mismatched:
                continue
            for ruleTerm, workTerm in pairs:
                if isinstance(ruleTerm, Variable):
                    candidateTermMatches.setdefault(ruleTerm, set()).add(workTerm)
    return candidateTermMatches
151
152
def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]:
    """Yield each statement of `rhs` with its Variables replaced from `bindings`.

    A statement that mentions a Variable missing from `bindings` is skipped
    entirely rather than yielded half-substituted. Results are tuples, so
    they are hashable and actually match the Triple annotation. The original
    version yielded the mutable working list and only cast it.
    """
    for stmt in rhs:
        terms = list(stmt)
        for i, t in enumerate(terms):
            if isinstance(t, Variable):
                try:
                    terms[i] = bindings[t]
                except KeyError:
                    # unbound variable: drop this statement (e.g. it belongs
                    # to a rule we're not applying right now)
                    break
        else:
            yield cast(Triple, tuple(terms))
165
166
def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool:
    """Return True if every lhs statement holds once `binding` is substituted.

    - Statements whose predicate is in filterFuncs must pass mathTest.
    - Statements whose predicate is in inferredFuncs are accepted without lookup.
    - All other statements must be present in workingSet.

    Statements that still contain an unbound Variable are dropped by
    withBinding and therefore never checked here.
    """
    for stmt in withBinding(lhs, binding):
        log.debug(f'lhs verify {stmt}')
        pred = stmt[1]
        if pred in filterFuncs:
            if not mathTest(*stmt):
                return False
            continue
        if stmt in workingSet or pred in inferredFuncs:
            continue
        log.debug(f' ver culls here')
        return False
    return True
177
178
# Predicates whose object is computed from the subject by inferredFuncObject
# rather than looked up in the working set.
inferredFuncs = {ROOM['asFarenheit'], MATH['sum']}

# Predicates that act as boolean tests on an already-bound statement;
# evaluated by mathTest.
filterFuncs = {MATH['greaterThan']}
186
187
def inferredFuncObject(subj, pred, graph, bindings):
    """Compute the object value for an inferred-function statement `subj pred ?out`.

    room:asFarenheit -- subj is a numeric Literal; returns subj * 9/5 + 32 as a
        Decimal-valued Literal.
    math:sum -- subj is the head of an RDF list stored in `graph`; Variable
        members are resolved through `bindings` (KeyError if unbound). Returns
        a Literal of the sum of the members' toPython() values.

    Raises NotImplementedError for any other predicate.
    """
    if pred == ROOM['asFarenheit']:
        value = subj.toPython()
        if isinstance(value, float):
            # Decimal(float) keeps the float's full binary expansion
            # (Decimal(0.1) is 0.1000000000000000055511...), which would leak
            # into the result literal. repr() gives the shortest string that
            # round-trips, i.e. the value as written.
            value = repr(value)
        return Literal(Decimal(value) * 9 / 5 + 32)
    elif pred == MATH['sum']:
        operands = Collection(graph, subj)
        # shouldn't be redoing this here
        operands = [bindings[o] if isinstance(o, Variable) else o for o in operands]
        log.debug(f' sum {list(operands)}')
        return Literal(sum(op.toPython() for op in operands))
    else:
        raise NotImplementedError(pred)
200
201
def mathTest(subj, pred, obj):
    """Evaluate the filter statement `subj pred obj` on its terms' Python values.

    Both terms are converted with toPython() before the predicate is checked.
    Only math:greaterThan is supported; any other predicate raises
    NotImplementedError.
    """
    lhsValue, rhsValue = subj.toPython(), obj.toPython()
    if pred == MATH['greaterThan']:
        return lhsValue > rhsValue
    raise NotImplementedError(pred)
209
210
def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]:
    """Turn the candidate map into parallel lists ready for itertools.product.

    Returns (variables in sorted order, and for each variable its candidate
    values sorted by str()). The enumeration order is therefore deterministic
    regardless of set iteration order.
    """
    orderedVars: List[Variable] = sorted(candidateTermMatches)
    orderedValueSets: List[List[Node]] = [
        sorted(candidateTermMatches[var], key=str) for var in orderedVars
    ]
    return orderedVars, orderedValueSets
223
224
def someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
    """Return True if at least one of staticRuleStmts is absent from workingSet."""
    return any(ruleStmt not in workingSet for ruleStmt in staticRuleStmts)