Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/inference.py @ 1618:48bf62008c82
attempted to rewrite with CandidateTermMatches but it broke
author | drewp@bigasterisk.com |
---|---|
date | Wed, 08 Sep 2021 18:32:11 -0700 |
parents | 3a6ed545357f |
children |
comparison
equal
deleted
inserted
replaced
1617:e105032b0e3d | 1618:48bf62008c82 |
---|---|
5 import itertools | 5 import itertools |
6 import logging | 6 import logging |
7 import time | 7 import time |
8 from collections import defaultdict | 8 from collections import defaultdict |
9 from dataclasses import dataclass, field | 9 from dataclasses import dataclass, field |
10 from typing import Dict, Iterator, List, Set, Tuple, Union, cast | 10 from typing import Dict, Iterator, List, Literal, Optional, Set, Tuple, Union, cast |
11 | 11 |
12 from prometheus_client import Summary | 12 from prometheus_client import Summary |
13 from rdflib import BNode, Graph, Namespace, URIRef | 13 from rdflib import BNode, Graph, Namespace, URIRef |
14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate | 14 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate |
15 from rdflib.term import Node, Variable | 15 from rdflib.term import Node, Variable |
53 self.evaluations = list(Evaluation.findEvals(self.graph)) | 53 self.evaluations = list(Evaluation.findEvals(self.graph)) |
54 | 54 |
55 def __repr__(self): | 55 def __repr__(self): |
56 return f"Lhs({graphDump(self.graph)})" | 56 return f"Lhs({graphDump(self.graph)})" |
57 | 57 |
58 def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: | 58 def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: |
59 """bindings that fit the LHS of a rule, using statements from workingSet and functions | 59 """bindings that fit the LHS of a rule, using statements from workingSet and functions |
60 from LHS""" | 60 from LHS""" |
61 log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') | 61 log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') |
62 stats['findCandidateBindingsCalls'] += 1 | 62 stats['findCandidateBindingsCalls'] += 1 |
63 | 63 |
64 if not self._allStaticStatementsMatch(workingSet): | 64 if not self._allStaticStatementsMatch(knownTrue): |
65 stats['findCandidateBindingEarlyExits'] += 1 | 65 stats['findCandidateBindingEarlyExits'] += 1 |
66 return | 66 return |
67 | 67 |
68 for binding in self._possibleBindings(workingSet, stats): | 68 boundSoFar = CandidateBinding({}) |
69 for binding in self._possibleBindings(knownTrue, boundSoFar, stats): | |
69 log.debug('') | 70 log.debug('') |
70 log.debug(f'{INDENT*4}*trying {binding.binding}') | 71 log.debug(f'{INDENT*4}*trying {binding.binding}') |
71 | 72 |
72 if not binding.verify(workingSet): | 73 if not binding.verify(knownTrue): |
73 log.debug(f'{INDENT*4} this binding did not verify') | 74 log.debug(f'{INDENT*4} this binding did not verify') |
74 stats['permCountFailingVerify'] += 1 | 75 stats['permCountFailingVerify'] += 1 |
75 continue | 76 continue |
76 | 77 |
77 stats['permCountSucceeding'] += 1 | 78 stats['permCountSucceeding'] += 1 |
78 yield binding | 79 yield binding |
79 | 80 boundSoFar.addNewBindings(binding.binding) |
80 def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool: | 81 |
82 def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: | |
81 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's | 83 # bug: see TestSelfFulfillingRule.test3 for a case where this rule's |
82 # static stmt is matched by a non-static stmt in the rule itself | 84 # static stmt is matched by a non-static stmt in the rule itself |
83 for ruleStmt in self.staticRuleStmts: | 85 for ruleStmt in self.staticRuleStmts: |
84 if ruleStmt not in workingSet: | 86 if ruleStmt not in knownTrue: |
85 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule') | 87 log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule') |
86 return False | 88 return False |
87 return True | 89 return True |
88 | 90 |
89 def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']: | 91 def _possibleBindings(self, workingSet, boundSoFar, stats) -> Iterator['BoundLhs']: |
90 """this yields at least the working bindings, and possibly others""" | 92 """this yields at least the working bindings, and possibly others""" |
91 candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet) | 93 for bindRow in self._product(workingSet, boundSoFar): |
92 for bindRow in self._product(candidateTermMatches): | |
93 try: | 94 try: |
94 yield BoundLhs(self, bindRow) | 95 yield BoundLhs(self, bindRow) |
95 except EvaluationFailed: | 96 except EvaluationFailed: |
96 stats['permCountFailingEval'] += 1 | 97 stats['permCountFailingEval'] += 1 |
97 | 98 |
98 def _product(self, candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Iterator[CandidateBinding]: | 99 def _product(self, workingSet, boundSoFar: CandidateBinding) -> Iterator[CandidateBinding]: |
99 orderedVars, orderedValueSets = _organize(candidateTermMatches) | 100 orderedVars = [] |
101 for stmt in self.graph: | |
102 for t in stmt: | |
103 if isinstance(t, (Variable, BNode)): | |
104 orderedVars.append(t) | |
105 orderedVars = sorted(set(orderedVars)) | |
106 | |
107 orderedValueSets = [] | |
108 for v in orderedVars: | |
109 orderedValueSets.append(CandidateTermMatches(v, self, workingSet, boundSoFar).results) | |
100 self._logCandidates(orderedVars, orderedValueSets) | 110 self._logCandidates(orderedVars, orderedValueSets) |
101 log.debug(f'{INDENT*3} trying all permutations:') | 111 log.debug(f'{INDENT*3} trying all permutations:') |
102 if not orderedValueSets: | 112 if not orderedVars: |
103 yield CandidateBinding({}) | 113 yield CandidateBinding({}) |
104 return | 114 return |
115 | |
105 if not orderedValueSets or not all(orderedValueSets): | 116 if not orderedValueSets or not all(orderedValueSets): |
106 # some var or bnode has no options at all | 117 # some var or bnode has no options at all |
107 return | 118 return |
108 rings: List[Iterator[Node]] = [itertools.cycle(valSet) for valSet in orderedValueSets] | 119 rings: List[Iterator[Node]] = [itertools.cycle(valSet) for valSet in orderedValueSets] |
109 currentSet: List[Node] = [next(ring) for ring in rings] | 120 currentSet: List[Node] = [next(ring) for ring in rings] |
110 starts = [valSet[-1] for valSet in orderedValueSets] | 121 starts = [valSet[-1] for valSet in orderedValueSets] |
111 while True: | 122 while True: |
112 for col, curr in enumerate(currentSet): | 123 for col, curr in enumerate(currentSet): |
113 currentSet[col] = next(rings[col]) | 124 currentSet[col] = next(rings[col]) |
114 log.debug(repr(currentSet)) | 125 log.debug(f'{INDENT*4} currentSet: {repr(currentSet)}') |
115 yield CandidateBinding(dict(zip(orderedVars, currentSet))) | 126 yield CandidateBinding(dict(zip(orderedVars, currentSet))) |
116 if curr is not starts[col]: | 127 if curr is not starts[col]: |
117 break | 128 break |
118 if col == len(orderedValueSets) - 1: | 129 if col == len(orderedValueSets) - 1: |
119 return | 130 return |
122 """the total set of terms each variable could possibly match""" | 133 """the total set of terms each variable could possibly match""" |
123 | 134 |
124 candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set) | 135 candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set) |
125 for lhsStmt in self.graph: | 136 for lhsStmt in self.graph: |
126 log.debug(f'{INDENT*4} possibles for this lhs stmt: {lhsStmt}') | 137 log.debug(f'{INDENT*4} possibles for this lhs stmt: {lhsStmt}') |
127 for i, trueStmt in enumerate(workingSet): | 138 for trueStmt in workingSet: |
128 # log.debug(f'{INDENT*5} consider this true stmt ({i}): {trueStmt}') | 139 # log.debug(f'{INDENT*5} consider this true stmt ({i}): {trueStmt}') |
129 | 140 |
130 for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt): | 141 for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt): |
131 candidateTermMatches[v].update(vals) | 142 candidateTermMatches[v].update(vals) |
132 | 143 |
133 for trueStmt in itertools.chain(workingSet, self.graph): | 144 # for trueStmt in itertools.chain(workingSet, self.graph): |
134 for b in self.lhsBnodes: | 145 # for b in self.lhsBnodes: |
135 for t in [trueStmt[0], trueStmt[2]]: | 146 # for t in [trueStmt[0], trueStmt[2]]: |
136 if isinstance(t, (URIRef, BNode)): | 147 # if isinstance(t, (URIRef, BNode)): |
137 candidateTermMatches[b].add(t) | 148 # candidateTermMatches[b].add(t) |
138 return candidateTermMatches | 149 return candidateTermMatches |
139 | 150 |
140 def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]: | 151 def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]: |
141 """if these stmts match otherwise, what BNode or Variable mappings do we learn? | 152 """if these stmts match otherwise, what BNode or Variable mappings do we learn? |
142 | 153 |
164 for val in vals: | 175 for val in vals: |
165 log.debug(f'{INDENT*5}{val!r}') | 176 log.debug(f'{INDENT*5}{val!r}') |
166 | 177 |
167 | 178 |
168 @dataclass | 179 @dataclass |
180 class CandidateTermMatches: | |
181 """lazily find the possible matches for this term""" | |
182 term: BindableTerm | |
183 lhs: Lhs | |
184 workingSet: Graph | |
185 boundSoFar: CandidateBinding | |
186 | |
187 def __post_init__(self): | |
188 self.results: List[Node] = [] # we have to be able to repeat the results | |
189 | |
190 res: Set[Node] = set() | |
191 for trueStmt in self.workingSet: # all bound | |
192 lStmts = list(self.lhsStmtsContainingTerm()) | |
193 log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}') | |
194 for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False): | |
195 log.debug(f'{INDENT*4} {pat=}') | |
196 implied = self._stmtImplies(pat, trueStmt) | |
197 if implied is not None: | |
198 res.add(implied) | |
199 self.results = list(res) | |
200 # self.results.sort() | |
201 | |
202 log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}') | |
203 | |
204 def _stmtImplies(self, pat: Triple, trueStmt: Triple) -> Optional[Node]: | |
205 """what value, if any, do we learn for our term from this LHS pattern statement and this known-true stmt""" | |
206 r = None | |
207 for p, t in zip(pat, trueStmt): | |
208 if isinstance(p, (Variable, BNode)): | |
209 if p != self.term: | |
210 # stmt is unbound in more than just our term | |
211 continue # unsure what to do - err on the side of too many bindings, since they get rechecked later | |
212 if r is None: | |
213 r = t | |
214 log.debug(f'{INDENT*4} implied term value {p=} {t=}') | |
215 elif r != t: | |
216 # (?x c ?x) matched with (a b c) doesn't work | |
217 return None | |
218 return r | |
219 | |
220 def lhsStmtsContainingTerm(self): | |
221 # lhs could precompute this | |
222 for lhsStmt in self.lhs.graph: | |
223 if self.term in lhsStmt: | |
224 yield lhsStmt | |
225 | |
226 def __iter__(self): | |
227 return iter(self.results) | |
228 | |
229 | |
230 @dataclass | |
169 class BoundLhs: | 231 class BoundLhs: |
170 lhs: Lhs | 232 lhs: Lhs |
171 binding: CandidateBinding | 233 binding: CandidateBinding |
172 | 234 |
173 def __post_init__(self): | 235 def __post_init__(self): |
238 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): | 300 def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): |
239 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): | 301 for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): |
240 log.debug(f'{INDENT*3} rule has a working binding:') | 302 log.debug(f'{INDENT*3} rule has a working binding:') |
241 | 303 |
242 for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()): | 304 for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()): |
243 log.debug(f'{INDENT*5} adding {lhsBoundStmt=}') | 305 log.debug(f'{INDENT*4} adding {lhsBoundStmt=}') |
244 workingSet.add(lhsBoundStmt) | 306 workingSet.add(lhsBoundStmt) |
245 for newStmt in bound.binding.apply(self.rhsGraph): | 307 for newStmt in bound.binding.apply(self.rhsGraph): |
246 log.debug(f'{INDENT*5} adding {newStmt=}') | 308 log.debug(f'{INDENT*4} adding {newStmt=}') |
247 workingSet.add(newStmt) | 309 workingSet.add(newStmt) |
248 implied.add(newStmt) | 310 implied.add(newStmt) |
249 | 311 |
250 | 312 |
251 class Inference: | 313 class Inference: |
277 | 339 |
278 bailout_iterations = 100 | 340 bailout_iterations = 100 |
279 delta = 1 | 341 delta = 1 |
280 stats['initWorkingSet'] = cast(int, workingSet.__len__()) | 342 stats['initWorkingSet'] = cast(int, workingSet.__len__()) |
281 while delta > 0 and bailout_iterations > 0: | 343 while delta > 0 and bailout_iterations > 0: |
344 log.debug('') | |
282 log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)') | 345 log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)') |
283 bailout_iterations -= 1 | 346 bailout_iterations -= 1 |
284 delta = -len(implied) | 347 delta = -len(implied) |
285 self._iterateAllRules(workingSet, implied, stats) | 348 self._iterateAllRules(workingSet, implied, stats) |
286 delta += len(implied) | 349 delta += len(implied) |