Mercurial > code > home > repos > homeauto
diff service/mqtt_to_rdf/lhs_evaluation.py @ 1637:ec3f98d0c1d8
refactor rules eval
author | drewp@bigasterisk.com |
---|---|
date | Mon, 13 Sep 2021 01:36:06 -0700 |
parents | 3252bdc284bc |
children | 4bb6f593ebf3 |
line wrap: on
line diff
--- a/service/mqtt_to_rdf/lhs_evaluation.py Mon Sep 13 00:18:47 2021 -0700 +++ b/service/mqtt_to_rdf/lhs_evaluation.py Mon Sep 13 01:36:06 2021 -0700 @@ -1,12 +1,15 @@ +from dataclasses import dataclass import logging from decimal import Decimal -from typing import List, Set, Tuple +from candidate_binding import CandidateBinding +from typing import Iterator, List, Optional, Set, Tuple, Type, Union, cast from prometheus_client import Summary from rdflib import RDF, Literal, Namespace, URIRef -from rdflib.term import Node +from rdflib.graph import Graph +from rdflib.term import Node, Variable -from inference_types import Triple +from inference_types import BindableTerm, Triple log = logging.getLogger('infer') @@ -46,3 +49,112 @@ cur = next return out, used + + +registeredFunctionTypes: List[Type['Function']] = [] + + +def register(cls: Type['Function']): + registeredFunctionTypes.append(cls) + return cls + + +class Function: + """any rule stmt that runs a function (not just a statement match)""" + pred: Node + + def __init__(self, stmt: Triple, ruleGraph: Graph): + self.stmt = stmt + if stmt[1] != self.pred: + raise TypeError + self.ruleGraph = ruleGraph + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + raise NotImplementedError + + def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]: + out = [] + for op in self.getOperandNodes(existingBinding): + out.append(numericNode(op)) + + return out + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + """either any new bindings this function makes (could be 0), or None if it doesn't match""" + raise NotImplementedError + + def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]: + objVar = self.stmt[2] + if not isinstance(objVar, Variable): + raise TypeError(f'expected Variable, got {objVar!r}') + return CandidateBinding({cast(BindableTerm, objVar): value}) + + +class SubjectFunction(Function): + """function that depends only on the subject term""" + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + return [existingBinding.applyTerm(self.stmt[0])] + + +class SubjectObjectFunction(Function): + """a filter function that depends on the subject and object terms""" + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + return [existingBinding.applyTerm(self.stmt[0]), existingBinding.applyTerm(self.stmt[2])] + + +class ListFunction(Function): + """function that takes an rdf list as input""" + + def usedStatements(self) -> Set[Triple]: + _, used = parseList(self.ruleGraph, self.stmt[0]) + return used + + def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: + operands, _ = parseList(self.ruleGraph, self.stmt[0]) + return [existingBinding.applyTerm(x) for x in operands] + + +@register +class Gt(SubjectObjectFunction): + pred = MATH['greaterThan'] + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + [x, y] = self.getNumericOperands(existingBinding) + if x > y: + return CandidateBinding({}) # no new values; just allow matching to keep going + + +@register +class AsFarenheit(SubjectFunction): + pred = ROOM['asFarenheit'] + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + [x] = self.getNumericOperands(existingBinding) + f = cast(Literal, Literal(Decimal(x) * 9 / 5 + 32)) + return self.valueInObjectTerm(f) + + +@register +class Sum(ListFunction): + pred = MATH['sum'] + + def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: + f = Literal(sum(self.getNumericOperands(existingBinding))) + return self.valueInObjectTerm(f) + + +def functionsFor(pred: Node) -> Iterator[Type[Function]]: + for cls in registeredFunctionTypes: + if cls.pred == pred: + yield cls + +def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]: + usedByFuncs: Set[Triple] = set() # don't worry about matching these + for s in graph: + for cls in functionsFor(pred=s[1]): + if issubclass(cls, ListFunction): + usedByFuncs.update(cls(s, graph).usedStatements()) + return usedByFuncs +