Mercurial > code > home > repos > homeauto
changeset 1589:5c1055be3c36
WIP more debugging, working towards bnode-matching support
author | drewp@bigasterisk.com |
---|---|
date | Thu, 02 Sep 2021 13:39:27 -0700 |
parents | 0757fafbfdab |
children | 327202020892 |
files | service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py |
diffstat | 2 files changed, 80 insertions(+), 35 deletions(-) [+] |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Thu Sep 02 01:58:31 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Thu Sep 02 13:39:27 2021 -0700 @@ -6,19 +6,20 @@ import logging from dataclasses import dataclass from decimal import Decimal -from typing import Dict, Iterator, List, Set, Tuple, cast +from typing import Dict, Iterator, List, Set, Tuple, Union, cast from urllib.request import OpenerDirector from prometheus_client import Summary from rdflib import BNode, Graph, Literal, Namespace from rdflib.collection import Collection -from rdflib.graph import ConjunctiveGraph +from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate from rdflib.term import Node, Variable log = logging.getLogger('infer') Triple = Tuple[Node, Node, Node] Rule = Tuple[Graph, Node, Graph] +BindableTerm = Union[Variable, BNode] READ_RULES_CALLS = Summary('read_rules_calls', 'calls') @@ -47,57 +48,85 @@ """ log.info(f'Begin inference of graph len={len(graph)} with rules len={len(self.rules)}:') - workingSet = ConjunctiveGraph() - if isinstance(graph, ConjunctiveGraph): - workingSet.addN(graph.quads()) - else: - for triple in graph: - workingSet.add(triple) + # everything that is true: the input graph, plus every rule conclusion we can make + workingSet = graphCopy(graph) + # just the statements that came from rule RHS's. implied = ConjunctiveGraph() bailout_iterations = 100 delta = 1 while delta > 0 and bailout_iterations > 0: + log.debug(f' * iteration ({bailout_iterations} left)') bailout_iterations -= 1 delta = -len(implied) self._iterateAllRules(workingSet, implied) delta += len(implied) log.info(f' this inference round added {delta} more implied stmts') - log.info(f'{len(implied)} stmts implied:') + log.info(f' {len(implied)} stmts implied:') for st in implied: - log.info(f' {st}') + log.info(f' {st}') return implied def _iterateAllRules(self, workingSet, implied): - for r in self.rules: + for i, r in enumerate(self.rules): + log.debug(f' workingSet: {graphDump(workingSet)}') + log.debug(f' - applying rule {i}') + log.debug(f' lhs: {graphDump(r[0])}') + log.debug(f' rhs: {graphDump(r[2])}') if r[1] == LOG['implies']: applyRule(r[0], r[2], workingSet, implied) else: - log.info(f' {r} not a rule?') + log.info(f' {r} not a rule?') -def applyRule(lhs: Graph, rhs: Graph, workingSet, implied): +def graphCopy(src: Graph) -> Graph: + if isinstance(src, ConjunctiveGraph): + out = ConjunctiveGraph() + out.addN(src.quads()) + return out + else: + out = Graph() + for triple in src: + out.add(triple) + return out + + +def graphDump(g: Graph): + g.bind('', ROOM) + g.bind('ex', Namespace('http://example.com/')) + lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() + lines = [line for line in lines if not line.startswith('@prefix')] + return ' '.join(lines) + + +def applyRule(lhs: Graph, rhs: Graph, workingSet: Graph, implied: Graph): for bindings in findCandidateBindings(lhs, workingSet): - log.debug(f' - rule gave {bindings=}') + log.debug(f' rule gave {bindings=}') + for lhsBoundStmt in withBinding(lhs, bindings): + workingSet.add(lhsBoundStmt) for newStmt in withBinding(rhs, bindings): workingSet.add(newStmt) implied.add(newStmt) def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]: + """bindings that fit the LHS of a rule, using statements from workingSet and functions + from LHS""" varsToBind: Set[Variable] = set() - staticRuleStmts = [] + staticRuleStmts = Graph() for ruleStmt in lhs: varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)] varsToBind.update(varsInStmt) if (not varsInStmt # ok - and not any(isinstance(t, BNode) for t in ruleStmt) # approx + #and not any(isinstance(t, BNode) for t in ruleStmt) # approx ): - staticRuleStmts.append(ruleStmt) + staticRuleStmts.add(ruleStmt) + + log.debug(f' {varsToBind=}') if someStaticStmtDoesntMatch(staticRuleStmts, workingSet): - log.debug('static shortcircuit') + log.debug(f' someStaticStmtDoesntMatch: {graphDump(staticRuleStmts)}') return # the total set of terms each variable could possibly match @@ -105,17 +134,18 @@ orderedVars, orderedValueSets = organize(candidateTermMatches) - log.debug(f' {orderedVars=}') - log.debug(f'{orderedValueSets=}') + log.debug(f' candidate terms:') + log.debug(f' {orderedVars=}') + log.debug(f' {orderedValueSets=}') for perm in itertools.product(*orderedValueSets): binding: Dict[Variable, Node] = dict(zip(orderedVars, perm)) - log.debug(f'{binding=} but lets look for funcs') + log.debug(f' {binding=} but lets look for funcs') for v, val in inferredFuncBindings(lhs, binding): # loop this until it's done - log.debug(f'ifb tells us {v}={val}') + log.debug(f' ifb tells us {v}={val}') binding[v] = val if not verifyBinding(lhs, binding, workingSet): # fix this - log.debug(f'verify culls') + log.debug(f' verify culls') continue yield binding @@ -136,13 +166,15 @@ def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]: candidateTermMatches: Dict[Variable, Set[Node]] = {} - for r in lhs: - for w in workingSet: + for lhsStmt in lhs: + for trueStmt in workingSet: + log.debug(f'{lhsStmt=} {trueStmt=}') bindingsFromStatement: Dict[Variable, Set[Node]] = {} - for rterm, wterm in zip(r, w): - if isinstance(rterm, Variable): - bindingsFromStatement.setdefault(rterm, set()).add(wterm) - elif rterm != wterm: + for lhsTerm, trueTerm in zip(lhsStmt, trueStmt): + log.debug(f' test {lhsTerm=} {trueTerm=}') + if isinstance(lhsTerm, Variable): + bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm) + elif lhsTerm != trueTerm: break else: for v, vals in bindingsFromStatement.items(): @@ -165,13 +197,17 @@ def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool: + """can this lhs be true all at once?""" for stmt in withBinding(lhs, binding): - log.debug(f'lhs verify {stmt}') + log.debug(f' lhs verify {stmt}') if stmt[1] in filterFuncs: if not mathTest(*stmt): return False - elif stmt not in workingSet and stmt[1] not in inferredFuncs: - log.debug(f' ver culls here') + elif (stmt not in workingSet # not previously true + and stmt not in lhs # not from the bindings in this rule + and stmt[1] not in inferredFuncs # not a function stmt (maybe this is wrong) + ): + log.debug(f' ver culls here') return False return True @@ -225,5 +261,7 @@ def someStaticStmtDoesntMatch(staticRuleStmts, workingSet): for ruleStmt in staticRuleStmts: if ruleStmt not in workingSet: + log.debug(f' {ruleStmt} not in working set- skip rule') + return True return False
--- a/service/mqtt_to_rdf/inference_test.py Thu Sep 02 01:58:31 2021 -0700 +++ b/service/mqtt_to_rdf/inference_test.py Thu Sep 02 13:39:27 2021 -0700 @@ -116,6 +116,13 @@ self.assertGraphEqual(implied, N3(":new :stmt :c .")) +class TestBnodeMatching(WithGraphEqual): + def test1(self): + inf = makeInferenceWithRules("{ [ :a :b ] . } => { :new :stmt :here } .") + implied = inf.infer(N3("[ :a :b ] .")) + self.assertGraphEqual(implied, N3(":new :stmt :here .")) + + class TestInferenceWithMathFunctions(WithGraphEqual): def testBoolFilter(self): @@ -124,9 +131,9 @@ self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3("")) self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 .")) - def testStatementGeneratingRule(self): - inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") - self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 .")) + # def testStatementGeneratingRule(self): + # inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") + # self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 .")) class TestInferenceWithCustomFunctions(WithGraphEqual):