Mercurial > code > home > repos > homeauto
diff service/mqtt_to_rdf/inference.py @ 1634:ba59cfc3c747
Hack math:sum in there. The test suite is passing except for some slow performers.
author | drewp@bigasterisk.com |
---|---|
date | Sun, 12 Sep 2021 23:48:43 -0700 |
parents | 6107603ed455 |
children | 22d481f0a924 |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Sun Sep 12 21:48:36 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Sun Sep 12 23:48:43 2021 -0700 @@ -7,16 +7,16 @@ import time from collections import defaultdict from dataclasses import dataclass -from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast) from prometheus_client import Histogram, Summary -from rdflib import BNode, Graph, Namespace +from rdflib import RDF, BNode, Graph, Namespace from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate from rdflib.term import Literal, Node, Variable from candidate_binding import CandidateBinding -from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple) -from lhs_evaluation import Decimal, Evaluation, numericNode +from inference_types import (BindableTerm, BindingUnknown, EvaluationFailed, ReadOnlyWorkingSet, Triple) +from lhs_evaluation import Decimal, Evaluation, numericNode, parseList log = logging.getLogger('infer') INDENT = ' ' @@ -58,6 +58,7 @@ lhsStmt: Triple prev: Optional['StmtLooper'] workingSet: ReadOnlyWorkingSet + parent: 'Lhs' # just for lhs.graph, really def __repr__(self): return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})' @@ -98,9 +99,29 @@ augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches, returnBoundStatementsOnly=False)) - log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}') + log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}') + + if self._advanceWithPlainMatches(augmentedWorkingSet): + return + + if self._advanceWithBoolRules(): + return + + curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) + [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) - log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented 
statements') + fullWorkingSet = self.workingSet + self.parent.graph + boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False)) + log.debug(f'{fullWorkingSet.__len__()=} {len(boundFullWorkingSet)=}') + + if self._advanceWithFunctions(augmentedWorkingSet, boundFullWorkingSet, lhsStmtBound): + return + + log.debug(f'{INDENT*6} {self} is past end') + self._pastEnd = True + + def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool: + log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements') for s in augmentedWorkingSet: log.debug(f'{INDENT*7} {s}') @@ -111,19 +132,38 @@ log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings') continue - log.debug(f'{INDENT*6} {outBinding=} {self._seenBindings=}') + log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}') if outBinding.binding not in self._seenBindings: self._seenBindings.append(outBinding.binding.copy()) self._current = outBinding log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}') - return - log.debug(f'yes we saw') + return True + return False - log.debug(f'{INDENT*6} {self} mines rules') + def _advanceWithBoolRules(self) -> bool: + log.debug(f'{INDENT*7} {self} mines bool rules') + if self.lhsStmt[1] == MATH['greaterThan']: + operands = [self.lhsStmt[0], self.lhsStmt[2]] + try: + boundOperands = self._boundOperands(operands) + except BindingUnknown: + return False + if numericNode(boundOperands[0]) > numericNode(boundOperands[1]): + bindingDict: Dict[BindableTerm, + Node] = self._prevBindings().copy() # no new values; just allow matching to keep going + if bindingDict not in self._seenBindings: + self._seenBindings.append(bindingDict) + self._current = CandidateBinding(bindingDict) + log.debug(f'{INDENT*7} new binding from {self} -> {bindingDict}') + return True + return False + + def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, 
lhsStmtBound) -> bool: + log.debug(f'{INDENT*7} {self} mines rules') if self.lhsStmt[1] == ROOM['asFarenheit']: pb: Dict[BindableTerm, Node] = self._prevBindings() - log.debug(f'{INDENT*6} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') + log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') if self.lhsStmt[0] in pb: operands = [pb[cast(BindableTerm, self.lhsStmt[0])]] @@ -136,17 +176,59 @@ if newBindings not in self._seenBindings: self._seenBindings.append(newBindings) self._current = CandidateBinding(newBindings) - return + return True + elif self.lhsStmt[1] == MATH['sum']: + + g = Graph() + for s in boundFullWorkingSet: + g.add(s) + log.debug(f' boundWorkingSet graph: {s}') + log.debug(f'_parseList subj = {lhsStmtBound[0]}') + operands, _ = parseList(g, lhsStmtBound[0]) + log.debug(f'********* {INDENT*7} {self} found list {operands=}') + try: + obj = Literal(sum(map(numericNode, operands))) + except TypeError: + log.debug('typeerr in operands') + pass + else: + objVar = lhsStmtBound[2] + log.debug(f'{objVar=}') - log.debug(f'{INDENT*6} {self} is past end') - self._pastEnd = True + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') + newBindings: Dict[BindableTerm, Node] = {objVar: obj} + log.debug(f'{newBindings=}') + + self._current.addNewBindings(CandidateBinding(newBindings)) + log.debug(f'{self._seenBindings=}') + if newBindings not in self._seenBindings: + self._seenBindings.append(newBindings) + self._current = CandidateBinding(newBindings) + return True + + return False + + def _boundOperands(self, operands) -> List[Node]: + pb: Dict[BindableTerm, Node] = self._prevBindings() + + boundOperands: List[Node] = [] + for op in operands: + if isinstance(op, (Variable, BNode)): + if op in pb: + boundOperands.append(pb[op]) + else: + raise BindingUnknown() + else: + boundOperands.append(op) + return boundOperands def _totalBindingIfThisStmtWereTrue(self, newStmt: 
Triple) -> CandidateBinding: outBinding = self._prevBindings().copy() for rt, ct in zip(self.lhsStmt, newStmt): if isinstance(rt, (Variable, BNode)): if rt in outBinding and outBinding[rt] != ct: - raise Inconsistent() + raise Inconsistent(f'{rt=} {ct=} {outBinding=}') outBinding[rt] = ct return CandidateBinding(outBinding) @@ -245,7 +327,21 @@ """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all start out valid (or else raise NoOptions)""" - stmtsToAdd = list(self.graph) + usedByFuncs: Set[Triple] = set() # don't worry about matching these + stmtsToResolve = list(self.graph) + for i, s in enumerate(stmtsToResolve): + if s[1] == MATH['sum']: + _, used = parseList(self.graph, s[0]) + usedByFuncs.update(used) + + stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs] + + # sort them by variable dependencies; don't just try all perms! + def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. + (s, p, o) = stmt + return p == MATH['sum'], p, s, o + + stmtsToAdd.sort(key=lightSortKey) for perm in itertools.permutations(stmtsToAdd): stmtStack: List[StmtLooper] = [] @@ -254,7 +350,7 @@ for s in perm: try: - elem = StmtLooper(s, prev, knownTrue) + elem = StmtLooper(s, prev, knownTrue, parent=self) except NoOptions: log.debug(f'{INDENT*6} permutation didnt work, try another') break @@ -540,7 +636,7 @@ log.debug('') log.debug(f'{INDENT*2}-applying rule {i}') log.debug(f'{INDENT*3} rule def lhs:') - for stmt in r.lhsGraph: + for stmt in sorted(r.lhsGraph, reverse=True): log.debug(f'{INDENT*4} {stmt}') log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')