Mercurial > code > home > repos > homeauto
changeset 1593:b0df43d5494c
big rewrite- more classes, smaller methods, more typesafe, all current tests passing
author | drewp@bigasterisk.com |
---|---|
date | Sat, 04 Sep 2021 23:23:55 -0700 |
parents | d7b66234064b |
children | e58bcfa66093 |
files | service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py |
diffstat | 2 files changed, 277 insertions(+), 182 deletions(-) [+] |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Sat Sep 04 23:18:44 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Sat Sep 04 23:23:55 2021 -0700 @@ -17,6 +17,7 @@ from rdflib.term import Node, Variable log = logging.getLogger('infer') +INDENT = ' ' Triple = Tuple[Node, Node, Node] Rule = Tuple[Graph, Node, Graph] @@ -35,167 +36,277 @@ vars: Dict[Variable, Node] -inferredFuncs = { - ROOM['asFarenheit'], - MATH['sum'], -} +ReadOnlyWorkingSet = ReadOnlyGraphAggregate + filterFuncs = { MATH['greaterThan'], } -def withBinding(toBind: Graph, bindings: Dict[BindableTerm, Node], includeStaticStmts=True) -> Iterator[Triple]: - for stmt in toBind: - stmt = list(stmt) - static = True - for i, term in enumerate(stmt): - if isinstance(term, (Variable, BNode)): - stmt[i] = bindings[term] - static = False - else: - if includeStaticStmts or not static: +class CandidateBinding: + + def __init__(self, binding: Dict[BindableTerm, Node]): + self.binding = binding # mutable! + + def __repr__(self): + b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items())) + return f'CandidateBinding({b})' + + def apply(self, g: Graph) -> Iterator[Triple]: + for stmt in g: + stmt = list(stmt) + for i, term in enumerate(stmt): + if isinstance(term, (Variable, BNode)): + if term in self.binding: + stmt[i] = self.binding[term] + else: yield cast(Triple, stmt) + def applyFunctions(self, lhs): + """may grow the binding with some results""" + usedByFuncs = Graph() + while True: + before = len(self.binding) + delta = 0 + for ev in Evaluation.findEvals(lhs): + log.debug(f'{INDENT*3} found Evaluation') -def verifyBinding(lhs: Graph, binding: Dict[BindableTerm, Node], workingSet: Graph, usedByFuncs: Graph) -> bool: - """Can this lhs be true all at once in workingSet? Does it match with these bindings?""" - boundLhs = list(withBinding(lhs, binding)) - boundUsedByFuncs = list(withBinding(usedByFuncs, binding)) - if log.isEnabledFor(logging.DEBUG): - log.debug(f' verify all bindings against this lhs:') - for stmt in boundLhs: - log.debug(f' {stmt}') + newBindings, usedGraph = ev.resultBindings(self.binding) + usedByFuncs += usedGraph + for k, v in newBindings.items(): + if k in self.binding and self.binding[k] != v: + raise ValueError( + f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}') + self.binding[k] = v + delta = len(self.binding) - before + log.debug(f'{INDENT*4} rule {graphDump(usedGraph)} made {delta} new bindings') + if delta == 0: + break + return usedByFuncs - log.debug(f' and against this workingSet:') - for stmt in workingSet: - log.debug(f' {stmt}') + def verify(self, lhs: 'Lhs', workingSet: ReadOnlyWorkingSet, usedByFuncs: Graph) -> bool: + """Can this lhs be true all at once in workingSet? Does it match with these bindings?""" + boundLhs = list(self.apply(lhs._g)) + boundUsedByFuncs = list(self.apply(usedByFuncs)) + + self.logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs) - log.debug(f' ignoring these usedByFuncs:') - for stmt in boundUsedByFuncs: - log.debug(f' {stmt}') - # The static stmts in lhs are obviously going - # to match- we only need to verify the ones - # that needed bindings. - for stmt in boundLhs: #withBinding(lhs, binding, includeStaticStmts=False): - log.debug(f' check for {stmt}') + for stmt in boundLhs: + log.debug(f'{INDENT*4} check for {stmt}') - if stmt[1] in filterFuncs: - if not mathTest(*stmt): - log.debug(f' binding was invalid because {stmt}) is not true') + if stmt[1] in filterFuncs: + if not mathTest(*stmt): + log.debug(f'{INDENT*5} binding was invalid because {stmt}) is not true') + return False + elif stmt in boundUsedByFuncs: + pass + elif stmt in workingSet: + pass + else: + log.debug(f'{INDENT*5} binding was invalid because {stmt}) is not known to be true') return False - elif stmt in boundUsedByFuncs: - pass - elif stmt in workingSet: - pass - else: - log.debug(f' binding was invalid because {stmt}) cannot be true') - return False - return True + log.debug(f"{INDENT*5} this rule's lhs can work under this binding") + return True + + def logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs): + if not log.isEnabledFor(logging.DEBUG): + return + log.debug(f'{INDENT*4}/ verify all bindings against this lhs:') + for stmt in sorted(boundLhs): + log.debug(f'{INDENT*4}|{INDENT} {stmt}') + + log.debug(f'{INDENT*4}| and against this workingSet:') + for stmt in sorted(workingSet): + log.debug(f'{INDENT*4}|{INDENT} {stmt}') + + log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:') + for stmt in sorted(boundUsedByFuncs): + log.debug(f'{INDENT*4}|{INDENT} {stmt}') + log.debug(f'{INDENT*4}\\') -def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[BindableTerm, Node]]: - """bindings that fit the LHS of a rule, using statements from workingSet and functions - from LHS""" - varsToBind: Set[BindableTerm] = set() - staticRuleStmts = Graph() - for ruleStmt in lhs: - varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))] - varsToBind.update(varsInStmt) - if (not varsInStmt # ok - #and not any(isinstance(t, BNode) for t in ruleStmt) # approx - ): - staticRuleStmts.add(ruleStmt) +class Lhs: + + def __init__(self, existingGraph): + self._g = existingGraph + + def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]: + """bindings that fit the LHS of a rule, using statements from workingSet and functions + from LHS""" + nodesToBind = self.nodesToBind() + log.debug(f'{INDENT*2} nodesToBind: {nodesToBind}') + + if not self.allStaticStatementsMatch(workingSet): + return + + candidateTermMatches: Dict[BindableTerm, Set[Node]] = self.allCandidateTermMatches(workingSet) + + # for n in nodesToBind: + # if n not in candidateTermMatches: + # candidateTermMatches[n] = set() + + orderedVars, orderedValueSets = organize(candidateTermMatches) + + self.logCandidates(orderedVars, orderedValueSets) - log.debug(f' varsToBind: {sorted(varsToBind)}') + log.debug(f'{INDENT*2} trying all permutations:') + + for perm in itertools.product(*orderedValueSets): + binding = CandidateBinding(dict(zip(orderedVars, perm))) + log.debug('') + log.debug(f'{INDENT*3}*trying {binding}') + + usedByFuncs = binding.applyFunctions(self) - if someStaticStmtDoesntMatch(staticRuleStmts, workingSet): - log.debug(f' someStaticStmtDoesntMatch: {graphDump(staticRuleStmts)}') - return + if not binding.verify(self, workingSet, usedByFuncs): + log.debug(f'{INDENT*3} this binding did not verify') + continue + yield binding - # the total set of terms each variable could possibly match - candidateTermMatches: Dict[BindableTerm, Set[Node]] = findCandidateTermMatches(lhs, workingSet) + def nodesToBind(self) -> List[BindableTerm]: + nodes: Set[BindableTerm] = set() + staticRuleStmts = Graph() + for ruleStmt in self._g: + varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))] + nodes.update(varsInStmt) + if (not varsInStmt # ok + #and not any(isinstance(t, BNode) for t in ruleStmt) # approx + ): + staticRuleStmts.add(ruleStmt) + return sorted(nodes) - orderedVars, orderedValueSets = organize(candidateTermMatches) + def allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool: + staticRuleStmts = Graph() + for ruleStmt in self._g: + varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))] + if (not varsInStmt # ok + #and not any(isinstance(t, BNode) for t in ruleStmt) # approx + ): + staticRuleStmts.add(ruleStmt) - log.debug(f' candidate terms:') - log.debug(f' {orderedVars=}') - log.debug(f' {orderedValueSets=}') + for ruleStmt in staticRuleStmts: + if ruleStmt not in workingSet: + log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule') + return False + return True + + def allCandidateTermMatches(self, workingSet: ReadOnlyWorkingSet) -> Dict[BindableTerm, Set[Node]]: + """the total set of terms each variable could possibly match""" - for i, perm in enumerate(itertools.product(*orderedValueSets)): - binding: Dict[BindableTerm, Node] = dict(zip(orderedVars, perm)) - log.debug('') - log.debug(f' ** trying {binding=}') - usedByFuncs = Graph() - for v, val, used in inferredFuncBindings(lhs, binding): # loop this until it's done - log.debug(f' inferredFuncBindings tells us {v}={val}') - binding[v] = val - usedByFuncs += used - if len(binding) != len(varsToBind): - log.debug(f' binding is incomplete, needs {varsToBind}') + candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set) + lhsBnodes: Set[BNode] = set() + for lhsStmt in self._g: + log.debug(f'{INDENT*3} possibles for this lhs stmt: {lhsStmt}') + for i, trueStmt in enumerate(sorted(workingSet)): + log.debug(f'{INDENT*4} consider this true stmt ({i}): {trueStmt}') + bindingsFromStatement: Dict[Variable, Set[Node]] = {} + for lhsTerm, trueTerm in zip(lhsStmt, trueStmt): + if isinstance(lhsTerm, BNode): + lhsBnodes.add(lhsTerm) + elif isinstance(lhsTerm, Variable): + bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm) + elif lhsTerm != trueTerm: + break + else: + for v, vals in bindingsFromStatement.items(): + candidateTermMatches[v].update(vals) - continue - if not verifyBinding(lhs, binding, workingSet, usedByFuncs): # fix this - log.debug(f' this binding did not verify') - continue - yield binding + for trueStmt in itertools.chain(workingSet, self._g): + for b in lhsBnodes: + for t in [trueStmt[0], trueStmt[2]]: + if isinstance(t, (URIRef, BNode)): + candidateTermMatches[b].add(t) + return candidateTermMatches + + def graphWithoutEvals(self, binding: CandidateBinding) -> Graph: + g = Graph() + usedByFuncs = binding.applyFunctions(self) + + for stmt in self._g: + if stmt not in usedByFuncs: + g.add(stmt) + return g + + def logCandidates(self, orderedVars, orderedValueSets): + if not log.isEnabledFor(logging.DEBUG): + return + log.debug(f'{INDENT*2} resulting candidate terms:') + for v, vals in zip(orderedVars, orderedValueSets): + log.debug(f'{INDENT*3} {v} could be:') + for val in vals: + log.debug(f'{INDENT*4}{val}') -def someStaticStmtDoesntMatch(staticRuleStmts, workingSet): - for ruleStmt in staticRuleStmts: - if ruleStmt not in workingSet: - log.debug(f' {ruleStmt} not in working set- skip rule') +class Evaluation: + """some lhs statements need to be evaluated with a special function + (e.g. math) and then not considered for the rest of the rule-firing + process. It's like they already 'matched' something, so they don't need + to match a statement from the known-true working set.""" + + @staticmethod + def findEvals(lhs: Lhs) -> Iterator['Evaluation']: + for stmt in lhs._g.triples((None, MATH['sum'], None)): + # shouldn't be redoing this here + operands, operandsStmts = parseList(lhs._g, stmt[0]) + g = Graph() + g += operandsStmts + yield Evaluation(operands, g, stmt) + + for stmt in lhs._g.triples((None, ROOM['asFarenheit'], None)): + g = Graph() + g.add(stmt) + yield Evaluation([stmt[0]], g, stmt) + + # internal, use findEvals + def __init__(self, operands: List[Node], operandsStmts: Graph, stmt: Triple) -> None: + self.operands = operands + self.operandsStmts = operandsStmts + self.stmt = stmt - return True - return False + def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]: + """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?""" + pred = self.stmt[1] + objVar = self.stmt[2] + boundOperands = [] + for o in self.operands: + if isinstance(o, Variable): + try: + o = inputBindings[o] + except KeyError: + return {}, self.operandsStmts + + boundOperands.append(o) + + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') + + if pred == MATH['sum']: + log.debug(f'{INDENT*4} sum {list(map(self.numericNode, boundOperands))}') + obj = cast(Literal, Literal(sum(map(self.numericNode, boundOperands)))) + self.operandsStmts.add(self.stmt) + return {objVar: obj}, self.operandsStmts + elif pred == ROOM['asFarenheit']: + if len(boundOperands) != 1: + raise ValueError(":asFarenheit takes 1 subject operand") + f = Literal(Decimal(self.numericNode(boundOperands[0])) * 9 / 5 + 32) + g = Graph() + g.add(self.stmt) + + log.debug('made 1 st graph') + return {objVar: f}, g + else: + raise NotImplementedError() + + def numericNode(self, n: Node): + if not isinstance(n, Literal): + raise TypeError(f'expected Literal, got {n=}') + val = n.toPython() + if not isinstance(val, (int, float, Decimal)): + raise TypeError(f'expected number, got {val=}') + return val -def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[BindableTerm, Set[Node]]: - candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set) - lhsBnodes: Set[BNode] = set() - for lhsStmt in lhs: - for trueStmt in workingSet: - log.debug(f' lhsStmt={graphDump([lhsStmt])} trueStmt={graphDump([trueStmt])}') - bindingsFromStatement: Dict[Variable, Set[Node]] = {} - for lhsTerm, trueTerm in zip(lhsStmt, trueStmt): - # log.debug(f' test {lhsTerm=} {trueTerm=}') - if isinstance(lhsTerm, BNode): - lhsBnodes.add(lhsTerm) - elif isinstance(lhsTerm, Variable): - bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm) - elif lhsTerm != trueTerm: - break - else: - for v, vals in bindingsFromStatement.items(): - candidateTermMatches[v].update(vals) - - for trueStmt in itertools.chain(workingSet, lhs): - for b in lhsBnodes: - for t in [trueStmt[0], trueStmt[2]]: - if isinstance(t, (URIRef, BNode)): - candidateTermMatches[b].add(t) - return candidateTermMatches - - -def inferredFuncObject(subj, pred, graph, bindings) -> Tuple[Literal, Graph]: - """return result from like `(1 2) math:sum ?out .` plus a graph of all the - statements involved in that function rule (including the bound answer""" - used = Graph() - if pred == ROOM['asFarenheit']: - obj = Literal(Decimal(subj.toPython()) * 9 / 5 + 32) - elif pred == MATH['sum']: - operands, operandsStmts = parseList(graph, subj) - # shouldn't be redoing this here - operands = [bindings[o] if isinstance(o, Variable) else o for o in operands] - log.debug(f' sum {[op.toPython() for op in operands]}') - used += operandsStmts - obj = Literal(sum(op.toPython() for op in operands)) - else: - raise NotImplementedError(pred) - - used.add((subj, pred, obj)) - return obj, used - - +# merge into evaluation, raising a Invalid for impossible stmts def mathTest(subj, pred, obj): x = subj.toPython() y = obj.toPython() @@ -217,10 +328,11 @@ """ returns new graph of inferred statements. """ - log.info(f'Begin inference of graph len={len(graph)} with rules len={len(self.rules)}:') + log.debug(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:') # everything that is true: the input graph, plus every rule conclusion we can make - workingSet = graphCopy(graph) + workingSet = Graph() + workingSet += graph # just the statements that came from rule RHS's. implied = ConjunctiveGraph() @@ -228,52 +340,47 @@ bailout_iterations = 100 delta = 1 while delta > 0 and bailout_iterations > 0: - log.debug(f' * iteration ({bailout_iterations} left)') + log.debug(f'{INDENT*1}*iteration ({bailout_iterations} left)') bailout_iterations -= 1 delta = -len(implied) self._iterateAllRules(workingSet, implied) delta += len(implied) - log.info(f' this inference round added {delta} more implied stmts') - log.info(f' {len(implied)} stmts implied:') + log.info(f'{INDENT*1} this inference round added {delta} more implied stmts') + log.info(f'{INDENT*0} {len(implied)} stmts implied:') for st in implied: - log.info(f' {st}') + log.info(f'{INDENT*2} {st}') return implied - def _iterateAllRules(self, workingSet, implied): + def _iterateAllRules(self, workingSet: Graph, implied: Graph): for i, r in enumerate(self.rules): - log.debug(f' workingSet: {graphDump(workingSet)}') - log.debug(f' - applying rule {i}') - log.debug(f' lhs: {graphDump(r[0])}') - log.debug(f' rhs: {graphDump(r[2])}') + log.debug('') + log.debug(f'{INDENT*2} workingSet:') + for i, stmt in enumerate(sorted(workingSet)): + log.debug(f'{INDENT*3} ({i}) {stmt}') + + log.debug('') + log.debug(f'{INDENT*2}-applying rule {i}') + log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}') + log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}') if r[1] == LOG['implies']: - applyRule(r[0], r[2], workingSet, implied) + applyRule(Lhs(r[0]), r[2], workingSet, implied) else: - log.info(f' {r} not a rule?') + log.info(f'{INDENT*2} {r} not a rule?') -def graphCopy(src: Graph) -> Graph: - if isinstance(src, ConjunctiveGraph): - out = ConjunctiveGraph() - out.addN(src.quads()) - return out - else: - out = Graph() - for triple in src: - out.add(triple) - return out - - -def applyRule(lhs: Graph, rhs: Graph, workingSet: Graph, implied: Graph): - for bindings in findCandidateBindings(lhs, workingSet): - log.debug(f' rule gave {bindings=}') - for lhsBoundStmt in withBinding(lhs, bindings): +def applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph): + for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])): + # log.debug(f' rule gave {binding=}') + for lhsBoundStmt in binding.apply(lhs.graphWithoutEvals(binding)): workingSet.add(lhsBoundStmt) - for newStmt in withBinding(rhs, bindings): + for newStmt in binding.apply(rhs): workingSet.add(newStmt) implied.add(newStmt) def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]: + """"Do like Collection(g, subj) but also return all the + triples that are involved in the list""" out = [] used = set() cur = subj @@ -284,6 +391,7 @@ next = graph.value(cur, RDF.rest) used.add((cur, RDF.rest, next)) + cur = next if cur == RDF.nil: break @@ -317,23 +425,6 @@ return orderedVars, orderedValueSets -def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node, Graph]]: - for stmt in lhs: - if stmt[1] not in inferredFuncs: - continue - var = stmt[2] - if not isinstance(var, Variable): - continue - - x = stmt[0] - if isinstance(x, Variable): - x = bindingsBefore[x] - - resultObject, usedByFunc = inferredFuncObject(x, stmt[1], lhs, bindingsBefore) - - yield var, resultObject, usedByFunc - - def isStatic(spo: Triple): for t in spo: if isinstance(t, (Variable, BNode)):
--- a/service/mqtt_to_rdf/inference_test.py Sat Sep 04 23:18:44 2021 -0700 +++ b/service/mqtt_to_rdf/inference_test.py Sat Sep 04 23:23:55 2021 -0700 @@ -201,6 +201,10 @@ self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3("")) self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 .")) + def testNonFiringMathRule(self): + inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3("")), N3("")) + def testStatementGeneratingRule(self): inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 ."))