Mercurial > code > home > repos > homeauto
diff service/mqtt_to_rdf/inference.py @ 1607:b21885181e35
more modules, types. Maybe less repeated computation on BoundLhs
author | drewp@bigasterisk.com |
---|---|
date | Mon, 06 Sep 2021 15:38:48 -0700 |
parents | 449746d1598f |
children | f928eb06a4f6 |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Mon Sep 06 01:15:14 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Mon Sep 06 15:38:48 2021 -0700 @@ -7,24 +7,20 @@ import time from collections import defaultdict from dataclasses import dataclass, field -from decimal import Decimal -from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union, cast +from typing import Dict, Iterator, List, Set, Tuple, Union, cast from prometheus_client import Summary -from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef +from rdflib import BNode, Graph, Namespace, URIRef from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate from rdflib.term import Node, Variable -from lhs_evaluation import EvaluationFailed, Evaluation +from candidate_binding import CandidateBinding +from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) +from lhs_evaluation import Evaluation log = logging.getLogger('infer') INDENT = ' ' -Triple = Tuple[Node, Node, Node] -Rule = Tuple[Graph, Node, Graph] -BindableTerm = Union[Variable, BNode] -ReadOnlyWorkingSet = ReadOnlyGraphAggregate - INFER_CALLS = Summary('read_rules_calls', 'calls') ROOM = Namespace("http://projects.bigasterisk.com/room/") @@ -36,110 +32,9 @@ GRAPH_ID = URIRef('dont/care') -class BindingUnknown(ValueError): - """e.g. we were asked to make the bound version - of (A B ?c) and we don't have a binding for ?c - """ - - -@dataclass -class CandidateBinding: - binding: Dict[BindableTerm, Node] - - def __repr__(self): - b = " ".join("%s=%s" % (k, v) for k, v in sorted(self.binding.items())) - return f'CandidateBinding({b})' - - def apply(self, g: Graph) -> Iterator[Triple]: - for stmt in g: - try: - bound = (self._applyTerm(stmt[0]), self._applyTerm(stmt[1]), self._applyTerm(stmt[2])) - except BindingUnknown: - continue - yield bound - - def _applyTerm(self, term: Node): - if isinstance(term, (Variable, BNode)): - if term in self.binding: - return self.binding[term] - else: - raise BindingUnknown() - return term - - def applyFunctions(self, lhs) -> Graph: - """may grow the binding with some results""" - usedByFuncs = Graph(identifier=GRAPH_ID) - while True: - delta = self._applyFunctionsIteration(lhs, usedByFuncs) - if delta == 0: - break - return usedByFuncs - - def _applyFunctionsIteration(self, lhs, usedByFuncs: Graph): - before = len(self.binding) - delta = 0 - for ev in lhs.evaluations: - log.debug(f'{INDENT*3} found Evaluation') - - newBindings, usedGraph = ev.resultBindings(self.binding) - usedByFuncs += usedGraph - self._addNewBindings(newBindings) - delta = len(self.binding) - before - if log.isEnabledFor(logging.DEBUG): - dump = "(...)" - if cast(int, usedGraph.__len__()) < 20: - dump = graphDump(usedGraph) - log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings') - return delta - - def _addNewBindings(self, newBindings): - for k, v in newBindings.items(): - if k in self.binding and self.binding[k] != v: - raise ValueError(f'conflict- thought {k} would be {self.binding[k]} but another Evaluation said it should be {v}') - self.binding[k] = v - - def verify(self, lhs: 'Lhs', workingSet: ReadOnlyWorkingSet, usedByFuncs: Graph) -> bool: - """Can this lhs be true all at once in workingSet? Does it match with these bindings?""" - boundLhs = list(self.apply(lhs.graph)) - boundUsedByFuncs = list(self.apply(usedByFuncs)) - - self._logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs) - - for stmt in boundLhs: - log.debug(f'{INDENT*4} check for {stmt}') - - if stmt in boundUsedByFuncs: - pass - elif stmt in workingSet: - pass - else: - log.debug(f'{INDENT*5} stmt not known to be true') - return False - return True - - def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs): - if not log.isEnabledFor(logging.DEBUG): - return - log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:') - for stmt in sorted(boundLhs): - log.debug(f'{INDENT*4}|{INDENT} {stmt}') - - # log.debug(f'{INDENT*4}| and against this workingSet:') - # for stmt in sorted(workingSet): - # log.debug(f'{INDENT*4}|{INDENT} {stmt}') - - stmts = sorted(boundUsedByFuncs) - if stmts: - log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:') - for stmt in stmts: - log.debug(f'{INDENT*4}|{INDENT} {stmt}') - log.debug(f'{INDENT*4}\\') - - @dataclass class Lhs: graph: Graph - stats: Dict staticRuleStmts: Graph = field(default_factory=Graph) lhsBindables: Set[BindableTerm] = field(default_factory=set) @@ -155,42 +50,41 @@ self.evaluations = list(Evaluation.findEvals(self.graph)) - def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]: + def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: """bindings that fit the LHS of a rule, using statements from workingSet and functions from LHS""" log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') - self.stats['findCandidateBindingsCalls'] += 1 + stats['findCandidateBindingsCalls'] += 1 if not self._allStaticStatementsMatch(workingSet): - self.stats['findCandidateBindingEarlyExits'] += 1 + stats['findCandidateBindingEarlyExits'] += 1 return + for binding in self._possibleBindings(workingSet, stats): + log.debug('') + log.debug(f'{INDENT*4}*trying {binding.binding}') + + if not binding.verify(workingSet): + log.debug(f'{INDENT*4} this binding did not verify') + stats['permCountFailingVerify'] += 1 + continue + + stats['permCountSucceeding'] += 1 + yield binding + + def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']: + """this yields at least the working bindings, and possibly others""" candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet) orderedVars, orderedValueSets = _organize(candidateTermMatches) - self._logCandidates(orderedVars, orderedValueSets) log.debug(f'{INDENT*3} trying all permutations:') - for perm in itertools.product(*orderedValueSets): - binding = CandidateBinding(dict(zip(orderedVars, perm))) - log.debug('') - log.debug(f'{INDENT*4}*trying {binding}') - try: - usedByFuncs = binding.applyFunctions(self) + yield BoundLhs(self, CandidateBinding(dict(zip(orderedVars, perm)))) except EvaluationFailed: - self.stats['permCountFailingEval'] += 1 - continue - - if not binding.verify(self, workingSet, usedByFuncs): - log.debug(f'{INDENT*4} this binding did not verify') - self.stats['permCountFailingVerify'] += 1 - continue - - self.stats['permCountSucceeding'] += 1 - yield binding + stats['permCountFailingEval'] += 1 def _allStaticStatementsMatch(self, workingSet: ReadOnlyWorkingSet) -> bool: for ruleStmt in self.staticRuleStmts: @@ -236,15 +130,6 @@ log.debug(f'{INDENT*5} {v=} {vals=}') yield v, vals - def graphWithoutEvals(self, binding: CandidateBinding) -> Graph: - g = Graph(identifier=GRAPH_ID) - usedByFuncs = binding.applyFunctions(self) - - for stmt in self.graph: - if stmt not in usedByFuncs: - g.add(stmt) - return g - def _logCandidates(self, orderedVars, orderedValueSets): if not log.isEnabledFor(logging.DEBUG): return @@ -255,16 +140,106 @@ log.debug(f'{INDENT*5}{val!r}') +@dataclass +class BoundLhs: + lhs: Lhs + binding: CandidateBinding + + def __post_init__(self): + self.usedByFuncs = Graph(identifier=GRAPH_ID) + self.graphWithoutEvals = self._graphWithoutEvals() + + def _graphWithoutEvals(self) -> Graph: + g = Graph(identifier=GRAPH_ID) + self._applyFunctions() + + for stmt in self.lhs.graph: + if stmt not in self.usedByFuncs: + g.add(stmt) + return g + + def _applyFunctions(self): + """may grow the binding with some results""" + while True: + delta = self._applyFunctionsIteration() + if delta == 0: + break + + def _applyFunctionsIteration(self): + before = len(self.binding.binding) + delta = 0 + for ev in self.lhs.evaluations: + log.debug(f'{INDENT*3} found Evaluation') + + newBindings, usedGraph = ev.resultBindings(self.binding) + self.usedByFuncs += usedGraph + self.binding.addNewBindings(newBindings) + delta = len(self.binding.binding) - before + if log.isEnabledFor(logging.DEBUG): + dump = "(...)" + if cast(int, usedGraph.__len__()) < 20: + dump = graphDump(usedGraph) + log.debug(f'{INDENT*4} rule {dump} made {delta} new bindings') + return delta + + + def verify(self, workingSet: ReadOnlyWorkingSet) -> bool: + """Can this bound lhs be true all at once in workingSet?""" + boundLhs = list(self.binding.apply(self.lhs.graph)) + boundUsedByFuncs = list(self.binding.apply(self.usedByFuncs)) + + self._logVerifyBanner(boundLhs, workingSet, boundUsedByFuncs) + + for stmt in boundLhs: + log.debug(f'{INDENT*4} check for {stmt}') + + if stmt in boundUsedByFuncs: + pass + elif stmt in workingSet: + pass + else: + log.debug(f'{INDENT*5} stmt not known to be true') + return False + return True + + def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet, boundUsedByFuncs): + if not log.isEnabledFor(logging.DEBUG): + return + log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:') + for stmt in sorted(boundLhs): + log.debug(f'{INDENT*4}|{INDENT} {stmt}') + + # log.debug(f'{INDENT*4}| and against this workingSet:') + # for stmt in sorted(workingSet): + # log.debug(f'{INDENT*4}|{INDENT} {stmt}') + + stmts = sorted(boundUsedByFuncs) + if stmts: + log.debug(f'{INDENT*4}| while ignoring these usedByFuncs:') + for stmt in stmts: + log.debug(f'{INDENT*4}|{INDENT} {stmt}') + log.debug(f'{INDENT*4}\\') + + +@dataclass +class Rule: + lhsGraph: Graph + rhsGraph: Graph + + def __post_init__(self): + self.lhs = Lhs(self.lhsGraph) + + class Inference: def __init__(self) -> None: - self.rules = ConjunctiveGraph() + self.rules = [] def setRules(self, g: ConjunctiveGraph): - self.rules = ConjunctiveGraph() + self.rules: List[Rule] = [] for stmt in g: if stmt[1] == LOG['implies']: - self.rules.add(stmt) + self.rules.append(Rule(stmt[0], stmt[2])) # others should go to a default working set? @INFER_CALLS.time() @@ -274,7 +249,7 @@ """ log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:') startTime = time.time() - self.stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0) + stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0) # everything that is true: the input graph, plus every rule conclusion we can make workingSet = Graph() workingSet += graph @@ -284,28 +259,28 @@ bailout_iterations = 100 delta = 1 - self.stats['initWorkingSet'] = cast(int, workingSet.__len__()) + stats['initWorkingSet'] = cast(int, workingSet.__len__()) while delta > 0 and bailout_iterations > 0: log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)') bailout_iterations -= 1 delta = -len(implied) - self._iterateAllRules(workingSet, implied) + self._iterateAllRules(workingSet, implied, stats) delta += len(implied) - self.stats['iterations'] += 1 + stats['iterations'] += 1 log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts') - self.stats['timeSpent'] = round(time.time() - startTime, 3) - self.stats['impliedStmts'] = len(implied) - log.info(f'{INDENT*0} Inference done {dict(self.stats)}. Implied:') + stats['timeSpent'] = round(time.time() - startTime, 3) + stats['impliedStmts'] = len(implied) + log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:') for st in implied: log.info(f'{INDENT*1} {st}') return implied - def _iterateAllRules(self, workingSet: Graph, implied: Graph): + def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats): for i, r in enumerate(self.rules): self._logRuleApplicationHeader(workingSet, i, r) - _applyRule(Lhs(r[0], self.stats), r[2], workingSet, implied, self.stats) + _applyRule(r.lhs, r.rhsGraph, workingSet, implied, stats) - def _logRuleApplicationHeader(self, workingSet, i, r): + def _logRuleApplicationHeader(self, workingSet, i, r: Rule): if not log.isEnabledFor(logging.DEBUG): return @@ -316,18 +291,18 @@ log.debug('') log.debug(f'{INDENT*2}-applying rule {i}') - log.debug(f'{INDENT*3} rule def lhs: {graphDump(r[0])}') - log.debug(f'{INDENT*3} rule def rhs: {graphDump(r[2])}') + log.debug(f'{INDENT*3} rule def lhs: {graphDump(r.lhsGraph)}') + log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}') def _applyRule(lhs: Lhs, rhs: Graph, workingSet: Graph, implied: Graph, stats: Dict): - for binding in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet])): + for bound in lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): log.debug(f'{INDENT*3} rule has a working binding:') - for lhsBoundStmt in binding.apply(lhs.graphWithoutEvals(binding)): + for lhsBoundStmt in bound.binding.apply(bound.graphWithoutEvals): log.debug(f'{INDENT*5} adding {lhsBoundStmt=}') workingSet.add(lhsBoundStmt) - for newStmt in binding.apply(rhs): + for newStmt in bound.binding.apply(rhs): log.debug(f'{INDENT*5} adding {newStmt=}') workingSet.add(newStmt) implied.add(newStmt) @@ -335,6 +310,7 @@ def graphDump(g: Union[Graph, List[Triple]]): if not isinstance(g, Graph): + log.warning(f"it's a {type(g)}") g2 = Graph() g2 += g g = g2