Mercurial > code > home > repos > homeauto
changeset 1637:ec3f98d0c1d8
refactor rules eval
author | drewp@bigasterisk.com |
---|---|
date | Mon, 13 Sep 2021 01:36:06 -0700 |
parents | 3252bdc284bc |
children | 0ba1625037ae |
files | service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/lhs_evaluation.py |
diffstat | 2 files changed, 137 insertions(+), 77 deletions(-) [+] |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py Mon Sep 13 00:18:47 2021 -0700 +++ b/service/mqtt_to_rdf/inference.py Mon Sep 13 01:36:06 2021 -0700 @@ -12,11 +12,11 @@ from prometheus_client import Histogram, Summary from rdflib import RDF, BNode, Graph, Namespace from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate -from rdflib.term import Literal, Node, Variable +from rdflib.term import Node, Variable from candidate_binding import CandidateBinding -from inference_types import (BindableTerm, BindingUnknown, ReadOnlyWorkingSet, Triple) -from lhs_evaluation import Decimal, numericNode, parseList +from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple +from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs log = logging.getLogger('infer') INDENT = ' ' @@ -104,9 +104,6 @@ if self._advanceWithPlainMatches(augmentedWorkingSet): return - if self._advanceWithBoolRules(): - return - curBind = self.prev.currentBinding() if self.prev else CandidateBinding({}) [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False) @@ -125,7 +122,7 @@ for s in augmentedWorkingSet: log.debug(f'{INDENT*7} {s}') - for i, stmt in enumerate(augmentedWorkingSet): + for stmt in augmentedWorkingSet: try: outBinding = self._totalBindingIfThisStmtWereTrue(stmt) except Inconsistent: @@ -140,71 +137,24 @@ return True return False - def _advanceWithBoolRules(self) -> bool: - log.debug(f'{INDENT*7} {self} mines bool rules') - if self.lhsStmt[1] == MATH['greaterThan']: - operands = [self.lhsStmt[0], self.lhsStmt[2]] - try: - boundOperands = self._boundOperands(operands) - except BindingUnknown: - return False - if numericNode(boundOperands[0]) > numericNode(boundOperands[1]): - binding: CandidateBinding = self._prevBindings().copy() # no new values; just allow matching to keep going - if binding not in self._seenBindings: - self._seenBindings.append(binding) - self._current = binding - log.debug(f'{INDENT*7} new binding from {self} -> {binding}') - return True - return False - def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool: - log.debug(f'{INDENT*7} {self} mines rules') - - if self.lhsStmt[1] == ROOM['asFarenheit']: - pb: CandidateBinding = self._prevBindings() - log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}') + pred: Node = self.lhsStmt[1] - if isinstance(self.lhsStmt[0], (Variable, BNode)) and pb.contains(self.lhsStmt[0]): - operands = [pb.applyTerm(self.lhsStmt[0])] - f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)) - objVar = self.lhsStmt[2] - if not isinstance(objVar, Variable): - raise TypeError(f'expected Variable, got {objVar!r}') - newBindings = CandidateBinding({cast(BindableTerm, objVar): cast(Node, f)}) - self._current.addNewBindings(newBindings) - if newBindings not in self._seenBindings: - self._seenBindings.append(newBindings) - self._current = newBindings - return True - elif self.lhsStmt[1] == MATH['sum']: - - g = Graph() - for s in boundFullWorkingSet: - g.add(s) - log.debug(f' boundWorkingSet graph: {s}') - log.debug(f'_parseList subj = {lhsStmtBound[0]}') - operands, _ = parseList(g, lhsStmtBound[0]) - log.debug(f'********* {INDENT*7} {self} found list {operands=}') + for functionType in functionsFor(pred): + fn = functionType(self.lhsStmt, self.parent.graph) try: - obj = Literal(sum(map(numericNode, operands))) - except TypeError: - log.debug('typeerr in operands') + out = fn.bind(self._prevBindings()) + except BindingUnknown: pass else: - objVar = lhsStmtBound[2] - log.debug(f'{objVar=}') - - if not isinstance(objVar, Variable): - raise TypeError(f'expected Variable, got {objVar!r}') - newBindings = CandidateBinding({objVar: obj}) - log.debug(f'{newBindings=}') - - self._current.addNewBindings(newBindings) - log.debug(f'{self._seenBindings=}') - if newBindings not in self._seenBindings: - self._seenBindings.append(newBindings) - self._current = newBindings - return True + if out is not None: + binding: CandidateBinding = self._prevBindings().copy() + binding.addNewBindings(out) + if binding not in self._seenBindings: + self._seenBindings.append(binding) + self._current = binding + log.debug(f'{INDENT*7} new binding from {self} -> {binding}') + return True return False @@ -302,14 +252,9 @@ """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all start out valid (or else raise NoOptions)""" - usedByFuncs: Set[Triple] = set() # don't worry about matching these - stmtsToResolve = list(self.graph) - for i, s in enumerate(stmtsToResolve): - if s[1] == MATH['sum']: - _, used = parseList(self.graph, s[0]) - usedByFuncs.update(used) + usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph) - stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs] + stmtsToAdd = list(self.graph - usedByFuncs) # sort them by variable dependencies; don't just try all perms! def lightSortKey(stmt): # Not this. Though it helps performance on the big rdf list cases. @@ -478,6 +423,9 @@ def graphDump(g: Union[Graph, List[Triple]]): + # this is very slow- debug only! + if not log.isEnabledFor(logging.DEBUG): + return "(skipped dump)" if not isinstance(g, Graph): g2 = Graph() g2 += g
--- a/service/mqtt_to_rdf/lhs_evaluation.py Mon Sep 13 00:18:47 2021 -0700 +++ b/service/mqtt_to_rdf/lhs_evaluation.py Mon Sep 13 01:36:06 2021 -0700 @@ -1,12 +1,15 @@ +from dataclasses import dataclass import logging from decimal import Decimal -from typing import List, Set, Tuple +from candidate_binding import CandidateBinding +from typing import Iterator, List, Optional, Set, Tuple, Type, Union, cast from prometheus_client import Summary from rdflib import RDF, Literal, Namespace, URIRef -from rdflib.term import Node +from rdflib.graph import Graph +from rdflib.term import Node, Variable -from inference_types import Triple +from inference_types import BindableTerm, Triple log = logging.getLogger('infer') @@ -46,3 +49,112 @@ cur = next return out, used + + +registeredFunctionTypes: List[Type['Function']] = [] + + +def register(cls: Type['Function']): + registeredFunctionTypes.append(cls) + return cls + + +class Function: + """any rule stmt that runs a function (not just a statement match)""" + pred: Node + + def __init__(self, stmt: Triple, ruleGraph: Graph): + self.stmt = stmt + if stmt[1] != self.pred: + raise TypeError + self.ruleGraph = ruleGraph + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + raise NotImplementedError + + def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]: + out = [] + for op in self.getOperandNodes(existingBinding): + out.append(numericNode(op)) + + return out + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + """either any new bindings this function makes (could be 0), or None if it doesn't match""" + raise NotImplementedError + + def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]: + objVar = self.stmt[2] + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') + return CandidateBinding({cast(BindableTerm, objVar): value}) + + +class SubjectFunction(Function): + """function that depends only on the subject term""" + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + return [existingBinding.applyTerm(self.stmt[0])] + + +class SubjectObjectFunction(Function): + """a filter function that depends on the subject and object terms""" + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + return [existingBinding.applyTerm(self.stmt[0]), existingBinding.applyTerm(self.stmt[2])] + + +class ListFunction(Function): + """function that takes an rdf list as input""" + + def usedStatements(self) -> Set[Triple]: + _, used = parseList(self.ruleGraph, self.stmt[0]) + return used + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + operands, _ = parseList(self.ruleGraph, self.stmt[0]) + return [existingBinding.applyTerm(x) for x in operands] + + +@register +class Gt(SubjectObjectFunction): + pred = MATH['greaterThan'] + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + [x, y] = self.getNumericOperands(existingBinding) + if x > y: + return CandidateBinding({}) # no new values; just allow matching to keep going + + +@register +class AsFarenheit(SubjectFunction): + pred = ROOM['asFarenheit'] + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + [x] = self.getNumericOperands(existingBinding) + f = cast(Literal, Literal(Decimal(x) * 9 / 5 + 32)) + return self.valueInObjectTerm(f) + + +@register +class Sum(ListFunction): + pred = MATH['sum'] + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + f = Literal(sum(self.getNumericOperands(existingBinding))) + return self.valueInObjectTerm(f) + + +def functionsFor(pred: Node) -> Iterator[Type[Function]]: + for cls in registeredFunctionTypes: + if cls.pred == pred: + yield cls + +def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]: + usedByFuncs: Set[Triple] = set() # don't worry about matching these + for s in graph: + for cls in functionsFor(pred=s[1]): + if issubclass(cls, ListFunction): + usedByFuncs.update(cls(s, graph).usedStatements()) + return usedByFuncs +