Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1588:0757fafbfdab
WIP inferencer - partial var and function support
author | drewp@bigasterisk.com |
---|---|
date | Thu, 02 Sep 2021 01:58:31 -0700 |
parents | 9a3a18c494f9 |
children | 5c1055be3c36 |
comparison
equal
deleted
inserted
replaced
1587:9a3a18c494f9 | 1588:0757fafbfdab |
---|---|
1 """ | 1 """ |
2 copied from reasoning 2021-08-29. probably same api. should | 2 copied from reasoning 2021-08-29. probably same api. should |
3 be able to lib/ this out | 3 be able to lib/ this out |
4 """ | 4 """ |
5 | 5 import itertools |
6 import logging | 6 import logging |
7 from typing import Dict, Tuple | |
8 from dataclasses import dataclass | 7 from dataclasses import dataclass |
8 from decimal import Decimal | |
9 from typing import Dict, Iterator, List, Set, Tuple, cast | |
10 from urllib.request import OpenerDirector | |
11 | |
9 from prometheus_client import Summary | 12 from prometheus_client import Summary |
10 from rdflib import Graph, Namespace | 13 from rdflib import BNode, Graph, Literal, Namespace |
14 from rdflib.collection import Collection | |
11 from rdflib.graph import ConjunctiveGraph | 15 from rdflib.graph import ConjunctiveGraph |
12 from rdflib.term import Node, Variable | 16 from rdflib.term import Node, Variable |
13 | 17 |
14 log = logging.getLogger('infer') | 18 log = logging.getLogger('infer') |
15 | 19 |
55 bailout_iterations = 100 | 59 bailout_iterations = 100 |
56 delta = 1 | 60 delta = 1 |
57 while delta > 0 and bailout_iterations > 0: | 61 while delta > 0 and bailout_iterations > 0: |
58 bailout_iterations -= 1 | 62 bailout_iterations -= 1 |
59 delta = -len(implied) | 63 delta = -len(implied) |
60 self._iterateRules(workingSet, implied) | 64 self._iterateAllRules(workingSet, implied) |
61 delta += len(implied) | 65 delta += len(implied) |
62 log.info(f' this inference round added {delta} more implied stmts') | 66 log.info(f' this inference round added {delta} more implied stmts') |
63 log.info(f'{len(implied)} stmts implied:') | 67 log.info(f'{len(implied)} stmts implied:') |
64 for st in implied: | 68 for st in implied: |
65 log.info(f' {st}') | 69 log.info(f' {st}') |
66 return implied | 70 return implied |
67 | 71 |
68 def _iterateRules(self, workingSet, implied): | 72 def _iterateAllRules(self, workingSet, implied): |
69 for r in self.rules: | 73 for r in self.rules: |
70 if r[1] == LOG['implies']: | 74 if r[1] == LOG['implies']: |
71 self._applyRule(r[0], r[2], workingSet, implied) | 75 applyRule(r[0], r[2], workingSet, implied) |
72 else: | 76 else: |
73 log.info(f' {r} not a rule?') | 77 log.info(f' {r} not a rule?') |
74 | 78 |
75 def _applyRule(self, lhs, rhs, workingSet, implied): | 79 |
76 containsSetup = self._containsSetup(lhs, workingSet) | 80 def applyRule(lhs: Graph, rhs: Graph, workingSet, implied): |
77 if containsSetup: | 81 for bindings in findCandidateBindings(lhs, workingSet): |
78 for st in rhs: | 82 log.debug(f' - rule gave {bindings=}') |
79 workingSet.add(st) | 83 for newStmt in withBinding(rhs, bindings): |
80 implied.add(st) | 84 workingSet.add(newStmt) |
81 | 85 implied.add(newStmt) |
82 def _containsSetup(self, lhs, workingSet): | 86 |
83 return all(st in workingSet for st in lhs) | 87 |
88 def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]: | |
89 varsToBind: Set[Variable] = set() | |
90 staticRuleStmts = [] | |
91 for ruleStmt in lhs: | |
92 varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)] | |
93 varsToBind.update(varsInStmt) | |
94 if (not varsInStmt # ok | |
95 and not any(isinstance(t, BNode) for t in ruleStmt) # approx | |
96 ): | |
97 staticRuleStmts.append(ruleStmt) | |
98 | |
99 if someStaticStmtDoesntMatch(staticRuleStmts, workingSet): | |
100 log.debug('static shortcircuit') | |
101 return | |
102 | |
103 # the total set of terms each variable could possibly match | |
104 candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet) | |
105 | |
106 orderedVars, orderedValueSets = organize(candidateTermMatches) | |
107 | |
108 log.debug(f' {orderedVars=}') | |
109 log.debug(f'{orderedValueSets=}') | |
110 | |
111 for perm in itertools.product(*orderedValueSets): | |
112 binding: Dict[Variable, Node] = dict(zip(orderedVars, perm)) | |
113 log.debug(f'{binding=} but lets look for funcs') | |
114 for v, val in inferredFuncBindings(lhs, binding): # loop this until it's done | |
115 log.debug(f'ifb tells us {v}={val}') | |
116 binding[v] = val | |
117 if not verifyBinding(lhs, binding, workingSet): # fix this | |
118 log.debug(f'verify culls') | |
119 continue | |
120 yield binding | |
121 | |
122 | |
123 def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]: | |
124 for stmt in lhs: | |
125 if stmt[1] not in inferredFuncs: | |
126 continue | |
127 if not isinstance(stmt[2], Variable): | |
128 continue | |
129 | |
130 x = stmt[0] | |
131 if isinstance(x, Variable): | |
132 x = bindingsBefore[x] | |
133 yield stmt[2], inferredFuncObject(x, stmt[1], lhs, bindingsBefore) | |
134 | |
135 | |
136 def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]: | |
137 candidateTermMatches: Dict[Variable, Set[Node]] = {} | |
138 | |
139 for r in lhs: | |
140 for w in workingSet: | |
141 bindingsFromStatement: Dict[Variable, Set[Node]] = {} | |
142 for rterm, wterm in zip(r, w): | |
143 if isinstance(rterm, Variable): | |
144 bindingsFromStatement.setdefault(rterm, set()).add(wterm) | |
145 elif rterm != wterm: | |
146 break | |
147 else: | |
148 for v, vals in bindingsFromStatement.items(): | |
149 candidateTermMatches.setdefault(v, set()).update(vals) | |
150 return candidateTermMatches | |
151 | |
152 | |
153 def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]: | |
154 for stmt in rhs: | |
155 stmt = list(stmt) | |
156 for i, t in enumerate(stmt): | |
157 if isinstance(t, Variable): | |
158 try: | |
159 stmt[i] = bindings[t] | |
160 except KeyError: | |
161 # stmt is from another rule that we're not applying right now | |
162 break | |
163 else: | |
164 yield cast(Triple, stmt) | |
165 | |
166 | |
167 def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool: | |
168 for stmt in withBinding(lhs, binding): | |
169 log.debug(f'lhs verify {stmt}') | |
170 if stmt[1] in filterFuncs: | |
171 if not mathTest(*stmt): | |
172 return False | |
173 elif stmt not in workingSet and stmt[1] not in inferredFuncs: | |
174 log.debug(f' ver culls here') | |
175 return False | |
176 return True | |
177 | |
178 | |
179 inferredFuncs = { | |
180 ROOM['asFarenheit'], | |
181 MATH['sum'], | |
182 } | |
183 filterFuncs = { | |
184 MATH['greaterThan'], | |
185 } | |
186 | |
187 | |
188 def inferredFuncObject(subj, pred, graph, bindings): | |
189 if pred == ROOM['asFarenheit']: | |
190 return Literal(Decimal(subj.toPython()) * 9 / 5 + 32) | |
191 elif pred == MATH['sum']: | |
192 operands = Collection(graph, subj) | |
193 # shouldn't be redoing this here | |
194 operands = [bindings[o] if isinstance(o, Variable) else o for o in operands] | |
195 log.debug(f' sum {list(operands)}') | |
196 return Literal(sum(op.toPython() for op in operands)) | |
197 | |
198 else: | |
199 raise NotImplementedError(pred) | |
200 | |
201 | |
202 def mathTest(subj, pred, obj): | |
203 x = subj.toPython() | |
204 y = obj.toPython() | |
205 if pred == MATH['greaterThan']: | |
206 return x > y | |
207 else: | |
208 raise NotImplementedError(pred) | |
209 | |
210 | |
211 def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]: | |
212 items = list(candidateTermMatches.items()) | |
213 items.sort() | |
214 orderedVars: List[Variable] = [] | |
215 orderedValueSets: List[List[Node]] = [] | |
216 for v, vals in items: | |
217 orderedVars.append(v) | |
218 orderedValues: List[Node] = list(vals) | |
219 orderedValues.sort(key=str) | |
220 orderedValueSets.append(orderedValues) | |
221 | |
222 return orderedVars, orderedValueSets | |
223 | |
224 | |
225 def someStaticStmtDoesntMatch(staticRuleStmts, workingSet): | |
226 for ruleStmt in staticRuleStmts: | |
227 if ruleStmt not in workingSet: | |
228 return True | |
229 return False |