Mercurial > code > home > repos > homeauto
diff service/mqtt_to_rdf/inference.py @ 1594:e58bcfa66093
cleanups and a few fixed cases
author | drewp@bigasterisk.com |
---|---|
date | Sun, 05 Sep 2021 01:15:55 -0700 |
parents | b0df43d5494c |
children | 4e795ed3a693 |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Sat Sep 04 23:23:55 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Sun Sep 05 01:15:55 2021 -0700 @@ -2,17 +2,15 @@ copied from reasoning 2021-08-29. probably same api. should be able to lib/ this out """ -from collections import defaultdict import itertools import logging -from dataclasses import dataclass +from collections import defaultdict +from dataclasses import dataclass, field from decimal import Decimal from typing import Dict, Iterator, List, Set, Tuple, Union, cast -from urllib.request import OpenerDirector from prometheus_client import Summary -from rdflib import BNode, Graph, Literal, Namespace, URIRef, RDF -from rdflib.collection import Collection +from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate from rdflib.term import Node, Variable @@ -22,6 +20,7 @@ Triple = Tuple[Node, Node, Node] Rule = Tuple[Graph, Node, Graph] BindableTerm = Union[Variable, BNode] +ReadOnlyWorkingSet = ReadOnlyGraphAggregate READ_RULES_CALLS = Summary('read_rules_calls', 'calls') @@ -30,23 +29,13 @@ MATH = Namespace('http://www.w3.org/2000/10/swap/math#') -@dataclass -class _RuleMatch: - """one way that a rule can match the working set""" - vars: Dict[Variable, Node] +class EvaluationFailed(ValueError): + """e.g. we were given (5 math:greaterThan 6)""" -ReadOnlyWorkingSet = ReadOnlyGraphAggregate - -filterFuncs = { - MATH['greaterThan'], -} - - +@dataclass class CandidateBinding: - - def __init__(self, binding: Dict[BindableTerm, Node]): - self.binding = binding # mutable! + binding: Dict[BindableTerm, Node] def __repr__(self): b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items())) @@ -54,39 +43,48 @@ def apply(self, g: Graph) -> Iterator[Triple]: for stmt in g: - stmt = list(stmt) - for i, term in enumerate(stmt): - if isinstance(term, (Variable, BNode)): - if term in self.binding: - stmt[i] = self.binding[term] - else: - yield cast(Triple, stmt) + yield (self._applyTerm(stmt[0]), self._applyTerm(stmt[1]), self._applyTerm(stmt[2])) - def applyFunctions(self, lhs): + def _applyTerm(self, term: Node): + if isinstance(term, (Variable, BNode)): + if term in self.binding: + return self.binding[term] + return term + + def applyFunctions(self, lhs) -> Graph: """may grow the binding with some results""" usedByFuncs = Graph() while True: - before = len(self.binding) - delta = 0 - for ev in Evaluation.findEvals(lhs): - log.debug(f'{INDENT*3} found Evaluation') - - newBindings, usedGraph = ev.resultBindings(self.binding) - usedByFuncs += usedGraph - for k, v in newBindings.items(): - if k in self.binding and self.binding[k] != v: - raise ValueError( - f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}') - self.binding[k] = v - delta = len(self.binding) - before - log.debug(f'{INDENT*4} rule {graphDump(usedGraph)} made {delta} new bindings') + delta = self._applyFunctionsIteration(lhs, usedByFuncs) if delta == 0: break return usedByFuncs + def _applyFunctionsIteration(self, lhs, usedByFuncs: Graph): + before = len(self.binding) + delta = 0 + for ev in Evaluation.findEvals(lhs): + log.debug(f'{INDENT*3} found Evaluation') + + newBindings, usedGraph = ev.resultBindings(self.binding) + usedByFuncs += usedGraph + self._addNewBindings(newBindings) + delta = len(self.binding) - before + dump = "(...)" + if log.isEnabledFor(logging.DEBUG) and cast(int, usedGraph.__len__()) < 20: + dump = graphDump(usedGraph) + log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings') + return delta + + def _addNewBindings(self, newBindings): + for k, v in newBindings.items(): + if k in self.binding and self.binding[k] != v: + raise ValueError(f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}') + self.binding[k] = v + def verify(self, lhs: 'Lhs', workingSet: ReadOnlyWorkingSet, usedByFuncs: Graph) -> bool: """Can this lhs be true all at once in workingSet? Does it match with these bindings?""" - boundLhs = list(self.apply(lhs._g)) + boundLhs = list(self.apply(lhs.graph)) boundUsedByFuncs = list(self.apply(usedByFuncs)) self.logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs) @@ -94,11 +92,7 @@ for stmt in boundLhs: log.debug(f'{INDENT*4} check for {stmt}') - if stmt[1] in filterFuncs: - if not mathTest(*stmt): - log.debug(f'{INDENT*5} binding was invalid because {stmt}) is not true') - return False - elif stmt in boundUsedByFuncs: + if stmt in boundUsedByFuncs: pass elif stmt in workingSet: pass @@ -125,26 +119,32 @@ log.debug(f'{INDENT*4}\\') +@dataclass class Lhs: + graph: Graph + + staticRuleStmts: Graph = field(default_factory=Graph) + lhsBindables: Set[BindableTerm] = field(default_factory=set) + lhsBnodes: Set[BNode] = field(default_factory=set) - def __init__(self, existingGraph): - self._g = existingGraph + def __post_init__(self): + for ruleStmt in self.graph: + varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))] + self.lhsBindables.update(varsAndBnodesInStmt) + self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode)) + if not varsAndBnodesInStmt: + self.staticRuleStmts.add(ruleStmt) def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]: """bindings that fit the LHS of a rule, using statements from workingSet and functions from LHS""" - nodesToBind = self.nodesToBind() - log.debug(f'{INDENT*2} nodesToBind: {nodesToBind}') + log.debug(f'{INDENT*2} nodesToBind: {self.lhsBindables}') if not self.allStaticStatementsMatch(workingSet): return candidateTermMatches: Dict[BindableTerm, Set[Node]] = self.allCandidateTermMatches(workingSet) - # for n in nodesToBind: - # if n not in candidateTermMatches: - # candidateTermMatches[n] = set() - orderedVars, orderedValueSets = organize(candidateTermMatches) self.logCandidates(orderedVars, orderedValueSets) @@ -156,35 +156,18 @@ log.debug('') log.debug(f'{INDENT*3}*trying {binding}') - usedByFuncs = binding.applyFunctions(self) + try: + usedByFuncs = binding.applyFunctions(self) + except EvaluationFailed: + continue if not binding.verify(self, workingSet, usedByFuncs): log.debug(f'{INDENT*3} this binding did not verify') continue yield binding - def nodesToBind(self) -> List[BindableTerm]: - nodes: Set[BindableTerm] = set() - staticRuleStmts = Graph() - for ruleStmt in self._g: - varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))] - nodes.update(varsInStmt) - if (not varsInStmt # ok - #and not any(isinstance(t, BNode) for t in ruleStmt) # approx - ): - staticRuleStmts.add(ruleStmt) - return sorted(nodes) - def allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool: - staticRuleStmts = Graph() - for ruleStmt in self._g: - varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))] - if (not varsInStmt # ok - #and not any(isinstance(t, BNode) for t in ruleStmt) # approx - ): - staticRuleStmts.add(ruleStmt) - - for ruleStmt in staticRuleStmts: + for ruleStmt in self.staticRuleStmts: if ruleStmt not in workingSet: log.debug(f'{INDENT*3} {ruleStmt} not in working set- skip rule') return False @@ -194,35 +177,43 @@ """the total set of terms each variable could possibly match""" candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set) - lhsBnodes: Set[BNode] = set() - for lhsStmt in self._g: + for lhsStmt in self.graph: log.debug(f'{INDENT*3} possibles for this lhs stmt: {lhsStmt}') for i, trueStmt in enumerate(sorted(workingSet)): log.debug(f'{INDENT*4} consider this true stmt ({i}): {trueStmt}') - bindingsFromStatement: Dict[Variable, Set[Node]] = {} - for lhsTerm, trueTerm in zip(lhsStmt, trueStmt): - if isinstance(lhsTerm, BNode): - lhsBnodes.add(lhsTerm) - elif isinstance(lhsTerm, Variable): - bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm) - elif lhsTerm != trueTerm: - break - else: - for v, vals in bindingsFromStatement.items(): - candidateTermMatches[v].update(vals) - for trueStmt in itertools.chain(workingSet, self._g): - for b in lhsBnodes: + for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt): + candidateTermMatches[v].update(vals) + + for trueStmt in itertools.chain(workingSet, self.graph): + for b in self.lhsBnodes: for t in [trueStmt[0], trueStmt[2]]: if isinstance(t, (URIRef, BNode)): candidateTermMatches[b].add(t) return candidateTermMatches + def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]: + """if these stmts match otherwise, what BNode or Variable mappings do we learn? + + e.g. stmt1=(?x B ?y) and stmt2=(A B C), then we yield (?x, {A}) and (?y, {C}) + or stmt1=(_:x B C) and stmt2=(A B C), then we yield (_:x, {A}) + or stmt1=(?x B C) and stmt2=(A B D), then we yield nothing + """ + bindingsFromStatement = {} + for term1, term2 in zip(stmt1, stmt2): + if isinstance(term1, (BNode, Variable)): + bindingsFromStatement.setdefault(term1, set()).add(term2) + elif term1 != term2: + break + else: + for v, vals in bindingsFromStatement.items(): + yield v, vals + def graphWithoutEvals(self, binding: CandidateBinding) -> Graph: g = Graph() usedByFuncs = binding.applyFunctions(self) - for stmt in self._g: + for stmt in self.graph: if stmt not in usedByFuncs: g.add(stmt) return g @@ -241,18 +232,25 @@ """some lhs statements need to be evaluated with a special function (e.g. math) and then not considered for the rest of the rule-firing process. It's like they already 'matched' something, so they don't need - to match a statement from the known-true working set.""" + to match a statement from the known-true working set. + + One Evaluation instance is for one function call. + """ @staticmethod def findEvals(lhs: Lhs) -> Iterator['Evaluation']: - for stmt in lhs._g.triples((None, MATH['sum'], None)): - # shouldn't be redoing this here - operands, operandsStmts = parseList(lhs._g, stmt[0]) + for stmt in lhs.graph.triples((None, MATH['sum'], None)): + operands, operandsStmts = parseList(lhs.graph, stmt[0]) g = Graph() g += operandsStmts yield Evaluation(operands, g, stmt) - for stmt in lhs._g.triples((None, ROOM['asFarenheit'], None)): + for stmt in lhs.graph.triples((None, MATH['greaterThan'], None)): + g = Graph() + g.add(stmt) + yield Evaluation([stmt[0], stmt[2]], g, stmt) + + for stmt in lhs.graph.triples((None, ROOM['asFarenheit'], None)): g = Graph() g.add(stmt) yield Evaluation([stmt[0]], g, stmt) @@ -260,13 +258,13 @@ # internal, use findEvals def __init__(self, operands: List[Node], operandsStmts: Graph, stmt: Triple) -> None: self.operands = operands - self.operandsStmts = operandsStmts + self.operandsStmts = operandsStmts # may grow self.stmt = stmt def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]: """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?""" pred = self.stmt[1] - objVar = self.stmt[2] + objVar: Node = self.stmt[2] boundOperands = [] for o in self.operands: if isinstance(o, Variable): @@ -277,43 +275,34 @@ boundOperands.append(o) - if not isinstance(objVar, Variable): - raise TypeError(f'expected Variable, got {objVar!r}') - if pred == MATH['sum']: - log.debug(f'{INDENT*4} sum {list(map(self.numericNode, boundOperands))}') - obj = cast(Literal, Literal(sum(map(self.numericNode, boundOperands)))) + obj = Literal(sum(map(numericNode, boundOperands))) self.operandsStmts.add(self.stmt) + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') return {objVar: obj}, self.operandsStmts elif pred == ROOM['asFarenheit']: if len(boundOperands) != 1: raise ValueError(":asFarenheit takes 1 subject operand") - f = Literal(Decimal(self.numericNode(boundOperands[0])) * 9 / 5 + 32) - g = Graph() - g.add(self.stmt) - - log.debug('made 1 st graph') - return {objVar: f}, g + f = Literal(Decimal(numericNode(boundOperands[0])) * 9 / 5 + 32) + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') + return {objVar: f}, self.operandsStmts + elif pred == MATH['greaterThan']: + if not (numericNode(boundOperands[0]) > numericNode(boundOperands[1])): + raise EvaluationFailed() + return {}, self.operandsStmts else: - raise NotImplementedError() - - def numericNode(self, n: Node): - if not isinstance(n, Literal): - raise TypeError(f'expected Literal, got {n=}') - val = n.toPython() - if not isinstance(val, (int, float, Decimal)): - raise TypeError(f'expected number, got {val=}') - return val + raise NotImplementedError(repr(pred)) -# merge into evaluation, raising a Invalid for impossible stmts -def mathTest(subj, pred, obj): - x = subj.toPython() - y = obj.toPython() - if pred == MATH['greaterThan']: - return x > y - else: - raise NotImplementedError(pred) +def numericNode(n: Node): + if not isinstance(n, Literal): + raise TypeError(f'expected Literal, got {n=}') + val = n.toPython() + if not isinstance(val, (int, float, Decimal)): + raise TypeError(f'expected number, got {val=}') + return val class Inference: @@ -334,7 +323,7 @@ workingSet = Graph() workingSet += graph - # just the statements that came from rule RHS's. + # just the statements that came from RHS's of rules that fired. implied = ConjunctiveGraph() bailout_iterations = 100 @@ -353,24 +342,29 @@ def _iterateAllRules(self, workingSet: Graph, implied: Graph): for i, r in enumerate(self.rules): - log.debug('') - log.debug(f'{INDENT*2} workingSet:') - for i, stmt in enumerate(sorted(workingSet)): - log.debug(f'{INDENT*3} ({i}) {stmt}') - - log.debug('') - log.debug(f'{INDENT*2}-applying rule {i}') - log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}') - log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}') + self.logRuleApplicationHeader(workingSet, i, r) if r[1] == LOG['implies']: applyRule(Lhs(r[0]), r[2], workingSet, implied) else: log.info(f'{INDENT*2} {r} not a rule?') + def logRuleApplicationHeader(self, workingSet, i, r): + if not log.isEnabledFor(logging.DEBUG): + return + + log.debug('') + log.debug(f'{INDENT*2} workingSet:') + for i, stmt in enumerate(sorted(workingSet)): + log.debug(f'{INDENT*3} ({i}) {stmt}') + + log.debug('') + log.debug(f'{INDENT*2}-applying rule {i}') + log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}') + log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}') + def applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph): for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])): - # log.debug(f' rule gave {binding=}') for lhsBoundStmt in binding.apply(lhs.graphWithoutEvals(binding)): workingSet.add(lhsBoundStmt) for newStmt in binding.apply(rhs): @@ -384,8 +378,7 @@ out = [] used = set() cur = subj - while True: - # bug: mishandles empty list + while cur != RDF.nil: out.append(graph.value(cur, RDF.first)) used.add((cur, RDF.first, out[-1])) @@ -393,16 +386,13 @@ used.add((cur, RDF.rest, next)) cur = next - if cur == RDF.nil: - break return out, used def graphDump(g: Union[Graph, List[Triple]]): if not isinstance(g, Graph): g2 = Graph() - for stmt in g: - g2.add(stmt) + g2 += g g = g2 g.bind('', ROOM) g.bind('ex', Namespace('http://example.com/'))