Mercurial > code > home > repos > homeauto
diff service/mqtt_to_rdf/inference.py @ 1637:ec3f98d0c1d8
refactor rules eval
author | drewp@bigasterisk.com |
---|---|
date | Mon, 13 Sep 2021 01:36:06 -0700 |
parents | 3252bdc284bc |
children | 0ba1625037ae |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Mon Sep 13 00:18:47 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Mon Sep 13 01:36:06 2021 -0700 @@ -12,11 +12,11 @@ from prometheus_client import Histogram, Summary from rdflib import RDF, BNode, Graph, Namespace from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate -from rdflib.term import Literal, Node, Variable +from rdflib.term import Node, Variable from candidate_binding import CandidateBinding -from inference_types import (BindableTerm, BindingUnknown, ReadOnlyWorkingSet, Triple) -from lhs_evaluation import Decimal, numericNode, parseList +from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple +from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs log = logging.getLogger('infer') INDENT = ' ' @@ -104,9 +104,6 @@ if self._advanceWithPlainMatches(augmentedWorkingSet): return - if self._advanceWithBoolRules(): - return - curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) @@ -125,7 +122,7 @@ for s in augmentedWorkingSet: log.debug(f'{INDENT*7} {s}') - for i, stmt in enumerate(augmentedWorkingSet): + for stmt in augmentedWorkingSet: try: outBinding = self._totalBindingIfThisStmtWereTrue(stmt) except Inconsistent: @@ -140,71 +137,24 @@ return True return False - def _advanceWithBoolRules(self) -> bool: - log.debug(f'{INDENT*7} {self} mines bool rules') - if self.lhsStmt[1] == MATH['greaterThan']: - operands = [self.lhsStmt[0], self.lhsStmt[2]] - try: - boundOperands = self._boundOperands(operands) - except BindingUnknown: - return False - if numericNode(boundOperands[0]) > numericNode(boundOperands[1]): - binding: CandidateBinding = self._prevBindings().copy() # no new values; just allow matching to keep going - if binding not in self._seenBindings: - self._seenBindings.append(binding) - self._current = binding - log.debug(f'{INDENT*7} new binding from {self} -> {binding}') - return True - return False - def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool: - log.debug(f'{INDENT*7} {self} mines rules') - - if self.lhsStmt[1] == ROOM['asFarenheit']: - pb: CandidateBinding = self._prevBindings() - log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') + pred: Node = self.lhsStmt[1] - if isinstance(self.lhsStmt[0], (Variable, BNode)) and pb.contains(self.lhsStmt[0]): - operands = [pb.applyTerm(self.lhsStmt[0])] - f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) - objVar = self.lhsStmt[2] - if not isinstance(objVar, Variable): - raise TypeError(f'expected Variable, got {objVar!r}') - newBindings = CandidateBinding({cast(BindableTerm, objVar): cast(Node, f)}) - self._current.addNewBindings(newBindings) - if newBindings not in self._seenBindings: - self._seenBindings.append(newBindings) - self._current = newBindings - return True - elif self.lhsStmt[1] == MATH['sum']: - - g = Graph() - for s in boundFullWorkingSet: - g.add(s) - log.debug(f' boundWorkingSet graph: {s}') - log.debug(f'_parseList subj = {lhsStmtBound[0]}') - operands, _ = parseList(g, lhsStmtBound[0]) - log.debug(f'********* {INDENT*7} {self} found list {operands=}') + for functionType in functionsFor(pred): + fn = functionType(self.lhsStmt, self.parent.graph) try: - obj = Literal(sum(map(numericNode, operands))) - except TypeError: - log.debug('typeerr in operands') + out = fn.bind(self._prevBindings()) + except BindingUnknown: pass else: - objVar = lhsStmtBound[2] - log.debug(f'{objVar=}') - - if not isinstance(objVar, Variable): - raise TypeError(f'expected Variable, got {objVar!r}') - newBindings = CandidateBinding({objVar: obj}) - log.debug(f'{newBindings=}') - - self._current.addNewBindings(newBindings) - log.debug(f'{self._seenBindings=}') - if newBindings not in self._seenBindings: - self._seenBindings.append(newBindings) - self._current = newBindings - return True + if out is not None: + binding: CandidateBinding = self._prevBindings().copy() + binding.addNewBindings(out) + if binding not in self._seenBindings: + self._seenBindings.append(binding) + self._current = binding + log.debug(f'{INDENT*7} new binding from {self} -> {binding}') + return True return False @@ -302,14 +252,9 @@ """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all start out valid (or else raise NoOptions)""" - usedByFuncs: Set[Triple] = set() # don't worry about matching these - stmtsToResolve = list(self.graph) - for i, s in enumerate(stmtsToResolve): - if s[1] == MATH['sum']: - _, used = parseList(self.graph, s[0]) - usedByFuncs.update(used) + usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) - stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs] + stmtsToAdd = list(self.graph - usedByFuncs) # sort them by variable dependencies; don't just try all perms! def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. @@ -478,6 +423,9 @@ def graphDump(g: Union[Graph, List[Triple]]): + # this is very slow- debug only! + if not log.isEnabledFor(logging.DEBUG): + return "(skipped dump)" if not isinstance(g, Graph): g2 = Graph() g2 += g