Mercurial > code > home > repos > homeauto
changeset 1590:327202020892
WIP inference- getting into more degenerate test cases
author | drewp@bigasterisk.com |
---|---|
date | Thu, 02 Sep 2021 23:20:55 -0700 |
parents | 5c1055be3c36 |
children | 668958454ae2 |
files | service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py |
diffstat | 2 files changed, 204 insertions(+), 57 deletions(-) [+] |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Thu Sep 02 13:39:27 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Thu Sep 02 23:20:55 2021 -0700 @@ -2,6 +2,7 @@ copied from reasoning 2021-08-29. probably same api. should be able to lib/ this out """ +from collections import defaultdict import itertools import logging from dataclasses import dataclass @@ -10,7 +11,7 @@ from urllib.request import OpenerDirector from prometheus_client import Summary -from rdflib import BNode, Graph, Literal, Namespace +from rdflib import BNode, Graph, Literal, Namespace, URIRef, RDF from rdflib.collection import Collection from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate from rdflib.term import Node, Variable @@ -92,7 +93,12 @@ return out -def graphDump(g: Graph): +def graphDump(g: Union[Graph, List[Triple]]): + if not isinstance(g, Graph): + g2 = Graph() + for stmt in g: + g2.add(stmt) + g = g2 g.bind('', ROOM) g.bind('ex', Namespace('http://example.com/')) lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines() @@ -110,27 +116,27 @@ implied.add(newStmt) -def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]: +def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[BindableTerm, Node]]: """bindings that fit the LHS of a rule, using statements from workingSet and functions from LHS""" - varsToBind: Set[Variable] = set() + varsToBind: Set[BindableTerm] = set() staticRuleStmts = Graph() for ruleStmt in lhs: - varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)] + varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))] varsToBind.update(varsInStmt) if (not varsInStmt # ok #and not any(isinstance(t, BNode) for t in ruleStmt) # approx ): staticRuleStmts.add(ruleStmt) - log.debug(f' {varsToBind=}') + log.debug(f' varsToBind: {sorted(varsToBind)}') if someStaticStmtDoesntMatch(staticRuleStmts, workingSet): log.debug(f' someStaticStmtDoesntMatch: {graphDump(staticRuleStmts)}') return # the total set of terms each variable could possibly match - candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet) + candidateTermMatches: Dict[BindableTerm, Set[Node]] = findCandidateTermMatches(lhs, workingSet) orderedVars, orderedValueSets = organize(candidateTermMatches) @@ -138,76 +144,113 @@ log.debug(f' {orderedVars=}') log.debug(f' {orderedValueSets=}') - for perm in itertools.product(*orderedValueSets): - binding: Dict[Variable, Node] = dict(zip(orderedVars, perm)) - log.debug(f' {binding=} but lets look for funcs') - for v, val in inferredFuncBindings(lhs, binding): # loop this until it's done - log.debug(f' ifb tells us {v}={val}') + for i, perm in enumerate(itertools.product(*orderedValueSets)): + binding: Dict[BindableTerm, Node] = dict(zip(orderedVars, perm)) + log.debug('') + log.debug(f' ** trying {binding=}') + usedByFuncs = Graph() + for v, val, used in inferredFuncBindings(lhs, binding): # loop this until it's done + log.debug(f' inferredFuncBindings tells us {v}={val}') binding[v] = val - if not verifyBinding(lhs, binding, workingSet): # fix this - log.debug(f' verify culls') + usedByFuncs += used + if len(binding) != len(varsToBind): + log.debug(f' binding is incomplete, needs {varsToBind}') + + continue + if not verifyBinding(lhs, binding, workingSet, usedByFuncs): # fix this + log.debug(f' this binding did not verify') continue yield binding -def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]: +def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node, Graph]]: for stmt in lhs: if stmt[1] not in inferredFuncs: continue - if not isinstance(stmt[2], Variable): + var = stmt[2] + if not isinstance(var, Variable): continue x = stmt[0] if isinstance(x, Variable): x = bindingsBefore[x] - yield stmt[2], inferredFuncObject(x, stmt[1], lhs, bindingsBefore) + + resultObject, usedByFunc = inferredFuncObject(x, stmt[1], lhs, bindingsBefore) + + yield var, resultObject, usedByFunc -def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]: - candidateTermMatches: Dict[Variable, Set[Node]] = {} - +def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[BindableTerm, Set[Node]]: + candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set) + lhsBnodes: Set[BNode] = set() for lhsStmt in lhs: for trueStmt in workingSet: - log.debug(f'{lhsStmt=} {trueStmt=}') + log.debug(f' lhsStmt={graphDump([lhsStmt])} trueStmt={graphDump([trueStmt])}') bindingsFromStatement: Dict[Variable, Set[Node]] = {} for lhsTerm, trueTerm in zip(lhsStmt, trueStmt): - log.debug(f' test {lhsTerm=} {trueTerm=}') - if isinstance(lhsTerm, Variable): + # log.debug(f' test {lhsTerm=} {trueTerm=}') + if isinstance(lhsTerm, BNode): + lhsBnodes.add(lhsTerm) + elif isinstance(lhsTerm, Variable): bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm) elif lhsTerm != trueTerm: break else: for v, vals in bindingsFromStatement.items(): - candidateTermMatches.setdefault(v, set()).update(vals) + candidateTermMatches[v].update(vals) + + for trueStmt in itertools.chain(workingSet, lhs): + for b in lhsBnodes: + for t in [trueStmt[0], trueStmt[2]]: + if isinstance(t, (URIRef, BNode)): + candidateTermMatches[b].add(t) return candidateTermMatches -def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]: - for stmt in rhs: +def withBinding(toBind: Graph, bindings: Dict[BindableTerm, Node], includeStaticStmts=True) -> Iterator[Triple]: + for stmt in toBind: stmt = list(stmt) - for i, t in enumerate(stmt): - if isinstance(t, Variable): - try: - stmt[i] = bindings[t] - except KeyError: - # stmt is from another rule that we're not applying right now - break + static = True + for i, term in enumerate(stmt): + if isinstance(term, (Variable, BNode)): + stmt[i] = bindings[term] + static = False else: - yield cast(Triple, stmt) + if includeStaticStmts or not static: + yield cast(Triple, stmt) -def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool: - """can this lhs be true all at once?""" - for stmt in withBinding(lhs, binding): - log.debug(f' lhs verify {stmt}') +def verifyBinding(lhs: Graph, binding: Dict[BindableTerm, Node], workingSet: Graph, usedByFuncs: Graph) -> bool: + """Can this lhs be true all at once in workingSet? Does it match with these bindings?""" + log.debug(f' verify all bindings against this lhs:') + boundLhs = list(withBinding(lhs, binding)) + for stmt in boundLhs: + log.debug(f' {stmt}') + + log.debug(f' and against this workingSet:') + for stmt in workingSet: + log.debug(f' {stmt}') + + log.debug(f' ignoring these usedByFuncs:') + boundUsedByFuncs = list(withBinding(usedByFuncs, binding)) + for stmt in boundUsedByFuncs: + log.debug(f' {stmt}') + # The static stmts in lhs are obviously going + # to match- we only need to verify the ones + # that needed bindings. + for stmt in boundLhs: #withBinding(lhs, binding, includeStaticStmts=False): + log.debug(f' check for {stmt}') + if stmt[1] in filterFuncs: if not mathTest(*stmt): + log.debug(f' binding was invalid because {stmt}) is not true') return False - elif (stmt not in workingSet # not previously true - and stmt not in lhs # not from the bindings in this rule - and stmt[1] not in inferredFuncs # not a function stmt (maybe this is wrong) - ): - log.debug(f' ver culls here') + elif stmt in boundUsedByFuncs: + pass + elif stmt in workingSet: + pass + else: + log.debug(f' binding was invalid because {stmt}) cannot be true') return False return True @@ -221,19 +264,49 @@ } -def inferredFuncObject(subj, pred, graph, bindings): +def isStatic(spo: Triple): + for t in spo: + if isinstance(t, (Variable, BNode)): + return False + return True + + +def inferredFuncObject(subj, pred, graph, bindings) -> Tuple[Literal, Graph]: + """return result from like `(1 2) math:sum ?out .` plus a graph of all the + statements involved in that function rule (including the bound answer""" + used = Graph() if pred == ROOM['asFarenheit']: - return Literal(Decimal(subj.toPython()) * 9 / 5 + 32) + obj = Literal(Decimal(subj.toPython()) * 9 / 5 + 32) elif pred == MATH['sum']: - operands = Collection(graph, subj) + operands, operandsStmts = parseList(graph, subj) # shouldn't be redoing this here operands = [bindings[o] if isinstance(o, Variable) else o for o in operands] - log.debug(f' sum {list(operands)}') - return Literal(sum(op.toPython() for op in operands)) - + log.debug(f' sum {[op.toPython() for op in operands]}') + used += operandsStmts + obj = Literal(sum(op.toPython() for op in operands)) else: raise NotImplementedError(pred) + used.add((subj, pred, obj)) + return obj, used + + +def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]: + out = [] + used = set() + cur = subj + while True: + # bug: mishandles empty list + out.append(graph.value(cur, RDF.first)) + used.add((cur, RDF.first, out[-1])) + + next = graph.value(cur, RDF.rest) + used.add((cur, RDF.rest, next)) + cur = next + if cur == RDF.nil: + break + return out, used + def mathTest(subj, pred, obj): x = subj.toPython() @@ -244,10 +317,10 @@ raise NotImplementedError(pred) -def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]: +def organize(candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Tuple[List[BindableTerm], List[List[Node]]]: items = list(candidateTermMatches.items()) items.sort() - orderedVars: List[Variable] = [] + orderedVars: List[BindableTerm] = [] orderedValueSets: List[List[Node]] = [] for v, vals in items: orderedVars.append(v)
--- a/service/mqtt_to_rdf/inference_test.py Thu Sep 02 13:39:27 2021 -0700 +++ b/service/mqtt_to_rdf/inference_test.py Thu Sep 02 23:20:55 2021 -0700 @@ -2,12 +2,64 @@ also see https://github.com/w3c/N3/tree/master/tests/N3Tests """ import unittest - -from rdflib import ConjunctiveGraph, Namespace, Graph +import itertools +from rdflib import ConjunctiveGraph, Namespace, Graph, BNode from rdflib.parser import StringInputSource from inference import Inference + +def patchSlimReprs(): + import rdflib.term + + def ur(self): + clsName = "U" if self.__class__ is rdflib.term.URIRef else self.__class__.__name__ + return """%s(%s)""" % (clsName, super(rdflib.term.URIRef, self).__repr__()) + + rdflib.term.URIRef.__repr__ = ur + + def br(self): + clsName = "BNode" if self.__class__ is rdflib.term.BNode else self.__class__.__name__ + return """%s(%s)""" % (clsName, super(rdflib.term.BNode, self).__repr__()) + + rdflib.term.BNode.__repr__ = br + + def vr(self): + clsName = "V" if self.__class__ is rdflib.term.Variable else self.__class__.__name__ + return """%s(%s)""" % (clsName, super(rdflib.term.Variable, self).__repr__()) + + rdflib.term.Variable.__repr__ = vr + + +patchSlimReprs() + + +def patchBnodeCounter(): + import rdflib.term + serial = itertools.count() + + def n(cls, value=None, _sn_gen='', _prefix='') -> BNode: + if value is None: + value = 'N-%s' % next(serial) + return rdflib.term.Identifier.__new__(cls, value) + + rdflib.term.BNode.__new__ = n + + import rdflib.plugins.parsers.notation3 + + def newBlankNode(self, uri=None, why=None): + if uri is None: + self.counter += 1 + bn = BNode('f-%s-%s' % (self.number, self.counter)) + else: + bn = BNode(uri.split('#').pop().replace('_', 'b')) + return bn + + rdflib.plugins.parsers.notation3.Formula.newBlankNode = newBlankNode + + +patchBnodeCounter() + ROOM = Namespace('http://projects.bigasterisk.com/room/') @@ -117,11 +169,29 @@ class TestBnodeMatching(WithGraphEqual): - def test1(self): + + def testRuleBnodeBindsToInputBnode(self): inf = makeInferenceWithRules("{ [ :a :b ] . } => { :new :stmt :here } .") implied = inf.infer(N3("[ :a :b ] .")) self.assertGraphEqual(implied, N3(":new :stmt :here .")) + def testRuleVarBindsToInputBNode(self): + inf = makeInferenceWithRules("{ ?z :a :b . } => { :new :stmt :here } .") + implied = inf.infer(N3("[] :a :b .")) + self.assertGraphEqual(implied, N3(":new :stmt :here .")) + + +class TestSelfFulfillingRule(WithGraphEqual): + + def test1(self): + inf = makeInferenceWithRules("{ } => { :new :stmt :x } .") + self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt :x .")) + self.assertGraphEqual(inf.infer(N3(":any :any :any .")), N3(":new :stmt :x .")) + + def test2(self): + inf = makeInferenceWithRules("{ (2) math:sum ?x } => { :new :stmt ?x } .") + self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt 2 .")) + class TestInferenceWithMathFunctions(WithGraphEqual): @@ -131,9 +201,13 @@ self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3("")) self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 .")) - # def testStatementGeneratingRule(self): - # inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") - # self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 .")) + def testStatementGeneratingRule(self): + inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 .")) + + def test3Operands(self): + inf = makeInferenceWithRules("{ :a :b ?x . (2 ?x 2) math:sum ?y } => { :new :stmt ?y } .") + self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 6 .")) class TestInferenceWithCustomFunctions(WithGraphEqual):