Mercurial > code > home > repos > homeauto
diff service/mqtt_to_rdf/inference.py @ 1631:2c85a4f5dd9c
big rewrite of infer() using statements not variables as the things to iterate over
author | drewp@bigasterisk.com |
---|---|
date | Sun, 12 Sep 2021 04:32:52 -0700 |
parents | ea559a846714 |
children | bd79a2941cab |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Sat Sep 11 23:33:55 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Sun Sep 12 04:32:52 2021 -0700 @@ -7,16 +7,16 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import Dict, Iterator, List, Set, Tuple, Union, cast +from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast -from prometheus_client import Summary, Histogram -from rdflib import BNode, Graph, Namespace, URIRef +from prometheus_client import Histogram, Summary +from rdflib import BNode, Graph, Namespace from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate -from rdflib.term import Node, Variable +from rdflib.term import Literal, Node, Variable from candidate_binding import CandidateBinding from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) -from lhs_evaluation import Evaluation +from lhs_evaluation import Decimal, Evaluation, numericNode log = logging.getLogger('infer') INDENT = ' ' @@ -29,27 +29,141 @@ MATH = Namespace('http://www.w3.org/2000/10/swap/math#') +def stmtTemplate(stmt: Triple) -> Tuple[Optional[Node], Optional[Node], Optional[Node]]: + return ( + None if isinstance(stmt[0], (Variable, BNode)) else stmt[0], + None if isinstance(stmt[1], (Variable, BNode)) else stmt[1], + None if isinstance(stmt[2], (Variable, BNode)) else stmt[2], + ) + + +class NoOptions(ValueError): + """stmtlooper has no possibilites to add to the binding; the whole rule must therefore not apply""" + + +class Inconsistent(ValueError): + """adding this stmt would be inconsistent with an existing binding""" + + +@dataclass +class StmtLooper: + lhsStmt: Triple + prev: Optional['StmtLooper'] + workingSet: ReadOnlyWorkingSet + + def __repr__(self): + return f'StmtLooper({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' + + def __post_init__(self): + self._myWorkingSetMatches = self._myMatches(self.workingSet) + + self._current = CandidateBinding({}) + self._pastEnd = False + self._seenBindings: List[Dict[BindableTerm, Node]] = [] + self.restart() + + def _myMatches(self, g: Graph) -> List[Triple]: + template = stmtTemplate(self.lhsStmt) + + stmts = sorted(cast(Iterator[Triple], list(g.triples(template)))) + # plus new lhs possibilties... + # log.debug(f'{INDENT*6} {self} find {len(stmts)=} in {len(self.workingSet)=}') + + return stmts + + def _prevBindings(self) -> Dict[BindableTerm, Node]: + if not self.prev or self.prev.pastEnd(): + return {} + + return self.prev.currentBinding().binding + + def advance(self): + """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode""" + log.debug(f'{INDENT*6} {self} mines {len(self._myWorkingSetMatches)} matching statements') + for i, stmt in enumerate(self._myWorkingSetMatches): + try: + outBinding = self._totalBindingIfThisStmtWereTrue(stmt) + except Inconsistent: + log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') + continue + log.debug(f'seen {outBinding.binding} in {self._seenBindings}') + if outBinding.binding not in self._seenBindings: + self._seenBindings.append(outBinding.binding.copy()) + log.debug(f'no, adding') + self._current = outBinding + log.debug(f'{INDENT*7} {self} - Looper matches {stmt} which tells us {outBinding}') + return + log.debug(f'yes we saw') + + log.debug(f'{INDENT*6} {self} mines rules') + + if self.lhsStmt[1] == ROOM['asFarenheit']: + pb: Dict[BindableTerm, Node] = self._prevBindings() + if self.lhsStmt[0] in pb: + operands = [pb[cast(BindableTerm, self.lhsStmt[0])]] + f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) + objVar = self.lhsStmt[2] + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') + newBindings = {cast(BindableTerm, objVar): cast(Node, f)} + self._current.addNewBindings(CandidateBinding(newBindings)) + if newBindings not in self._seenBindings: + self._seenBindings.append(newBindings) + self._current = CandidateBinding(newBindings) + + log.debug(f'{INDENT*6} {self} is past end') + self._pastEnd = True + + def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding: + outBinding = self._prevBindings().copy() + for rt, ct in zip(self.lhsStmt, newStmt): + if isinstance(rt, (Variable, BNode)): + if rt in outBinding and outBinding[rt] != ct: + raise Inconsistent() + outBinding[rt] = ct + return CandidateBinding(outBinding) + + def currentBinding(self) -> CandidateBinding: + if self.pastEnd(): + raise NotImplementedError() + return self._current + + def newLhsStmts(self) -> List[Triple]: + """under the curent bindings, what new stmts beyond workingSet are also true? includes all `prev`""" + return [] + + def pastEnd(self) -> bool: + return self._pastEnd + + def restart(self): + self._pastEnd = False + self._seenBindings = [] + self.advance() + if self.pastEnd(): + raise NoOptions() + + @dataclass class Lhs: graph: Graph def __post_init__(self): # do precomputation in here that's not specific to the workingSet - self.staticRuleStmts = Graph() - self.nonStaticRuleStmts = Graph() + # self.staticRuleStmts = Graph() + # self.nonStaticRuleStmts = Graph() - self.lhsBindables: Set[BindableTerm] = set() - self.lhsBnodes: Set[BNode] = set() - for ruleStmt in self.graph: - varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))] - self.lhsBindables.update(varsAndBnodesInStmt) - self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode)) - if not varsAndBnodesInStmt: - self.staticRuleStmts.add(ruleStmt) - else: - self.nonStaticRuleStmts.add(ruleStmt) + # self.lhsBindables: Set[BindableTerm] = set() + # self.lhsBnodes: Set[BNode] = set() + # for ruleStmt in self.graph: + # varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))] + # self.lhsBindables.update(varsAndBnodesInStmt) + # self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode)) + # if not varsAndBnodesInStmt: + # self.staticRuleStmts.add(ruleStmt) + # else: + # self.nonStaticRuleStmts.add(ruleStmt) - self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts) + # self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts) self.evaluations = list(Evaluation.findEvals(self.graph)) @@ -59,24 +173,69 @@ def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']: """bindings that fit the LHS of a rule, using statements from workingSet and functions from LHS""" - log.debug(f'{INDENT*3} nodesToBind: {self.lhsBindables}') - stats['findCandidateBindingsCalls'] += 1 + log.debug(f'{INDENT*4} build new StmtLooper stack') - if not self._allStaticStatementsMatch(knownTrue): - stats['findCandidateBindingEarlyExits'] += 1 + stmtStack: List[StmtLooper] = [] + try: + prev: Optional[StmtLooper] = None + for s in sorted(self.graph): # order of this matters! :( + stmtStack.append(StmtLooper(s, prev, knownTrue)) + prev = stmtStack[-1] + except NoOptions: + log.debug(f'{INDENT*5} no options; 0 bindings') return - for binding in self._possibleBindings(knownTrue, stats): - log.debug('') - log.debug(f'{INDENT*4}*trying {binding.binding}') + log.debug(f'{INDENT*5} initial odometer:') + for l in stmtStack: + log.debug(f'{INDENT*6} {l}') + + if any(ring.pastEnd() for ring in stmtStack): + log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}') + + raise NoOptions() + sl = stmtStack[-1] + iterCount = 0 + while True: + iterCount += 1 + if iterCount > 10: + raise ValueError('stuck') + + log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}') + + log.debug(f'{INDENT*5} <<<') + yield BoundLhs(self, sl.currentBinding()) + log.debug(f'{INDENT*5} >>>') + + log.debug(f'{INDENT*5} odometer:') + for l in stmtStack: + log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') - if not binding.verify(knownTrue): - log.debug(f'{INDENT*4} this binding did not verify') - stats['permCountFailingVerify'] += 1 - continue + done = self._advanceAll(stmtStack) + + log.debug(f'{INDENT*5} odometer after ({done=}):') + for l in stmtStack: + log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}') + + log.debug(f'{INDENT*4} ^^ findCandBindings iteration done') + if done: + break - stats['permCountSucceeding'] += 1 - yield binding + def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool: + carry = True # 1st elem always must advance + for i, ring in enumerate(stmtStack): + # unlike normal odometer, advancing any earlier ring could invalidate later ones + if carry: + log.debug(f'{INDENT*5} advanceAll [{i}] {ring} carry/advance') + ring.advance() + carry = False + if ring.pastEnd(): + if ring is stmtStack[-1]: + log.debug(f'{INDENT*5} advanceAll [{i}] {ring} says we done') + return True + log.debug(f'{INDENT*5} advanceAll [{i}] {ring} restart') + ring.restart() + carry = True + return False def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool: # bug: see TestSelfFulfillingRule.test3 for a case where this rule's @@ -162,41 +321,6 @@ log.debug(f'{INDENT*5}{val!r}') -# @dataclass -# class CandidateTermMatches: -# """lazily find the possible matches for this term""" -# terms: List[BindableTerm] -# lhs: Lhs -# knownTrue: Graph -# boundSoFar: CandidateBinding - -# def __post_init__(self): -# self.results: List[Node] = [] # we have to be able to repeat the results - -# res: Set[Node] = set() -# for trueStmt in self.knownTrue: # all bound -# lStmts = list(self.lhsStmtsContainingTerm()) -# log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}') -# for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False): -# log.debug(f'{INDENT*4} {pat=}') -# implied = self._stmtImplies(pat, trueStmt) -# if implied is not None: -# res.add(implied) -# self.results = list(res) -# # self.results.sort() - -# log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}') - -# def lhsStmtsContainingTerm(self): -# # lhs could precompute this -# for lhsStmt in self.lhs.graph: -# if self.term in lhsStmt: -# yield lhsStmt - -# def __iter__(self): -# return iter(self.results) - - @dataclass class BoundLhs: lhs: Lhs @@ -204,7 +328,7 @@ def __post_init__(self): self.usedByFuncs = Graph() - self._applyFunctions() + # self._applyFunctions() def lhsStmtsWithoutEvals(self): for stmt in self.lhs.graph: @@ -263,19 +387,40 @@ class Rule: lhsGraph: Graph rhsGraph: Graph - + def __post_init__(self): self.lhs = Lhs(self.lhsGraph) + # + self.rhsBnodeMap = {} def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict): for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats): - log.debug(f'{INDENT*3} rule has a working binding:') + log.debug(f'{INDENT*5} +rule has a working binding: {bound}') + + # rhs could have more bnodes, and they just need to be distinct per rule-firing that we do + existingRhsBnodes = set() + for stmt in self.rhsGraph: + for t in stmt: + if isinstance(t, BNode): + existingRhsBnodes.add(t) + # if existingRhsBnodes: + # log.debug(f'{INDENT*6} mapping rhs bnodes {existingRhsBnodes} to new ones') + + for b in existingRhsBnodes: - for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()): - log.debug(f'{INDENT*4} adding {lhsBoundStmt=}') - workingSet.add(lhsBoundStmt) + key = tuple(sorted(bound.binding.binding.items())), b + self.rhsBnodeMap.setdefault(key, BNode()) + + + bound.binding.addNewBindings(CandidateBinding({b: self.rhsBnodeMap[key]})) + + # for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()): + # log.debug(f'{INDENT*6} adding to workingSet {lhsBoundStmt=}') + # workingSet.add(lhsBoundStmt) + # log.debug(f'{INDENT*6} rhsGraph is good: {list(self.rhsGraph)}') + for newStmt in bound.binding.apply(self.rhsGraph): - log.debug(f'{INDENT*4} adding {newStmt=}') + # log.debug(f'{INDENT*6} adding {newStmt=}') workingSet.add(newStmt) implied.add(newStmt) @@ -350,7 +495,6 @@ def graphDump(g: Union[Graph, List[Triple]]): if not isinstance(g, Graph): - log.warning(f"it's a {type(g)}") g2 = Graph() g2 += g g = g2