diff service/mqtt_to_rdf/lhs_evaluation.py @ 1637:ec3f98d0c1d8

refactor rules eval
author drewp@bigasterisk.com
date Mon, 13 Sep 2021 01:36:06 -0700
parents 3252bdc284bc
children 4bb6f593ebf3
line wrap: on
line diff
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 13 00:18:47 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 13 01:36:06 2021 -0700
@@ -1,12 +1,15 @@
+from dataclasses import dataclass
 import logging
 from decimal import Decimal
-from typing import List, Set, Tuple
+from candidate_binding import CandidateBinding
+from typing import Iterator, List, Optional, Set, Tuple, Type, Union, cast
 
 from prometheus_client import Summary
 from rdflib import RDF, Literal, Namespace, URIRef
-from rdflib.term import Node
+from rdflib.graph import Graph
+from rdflib.term import Node, Variable
 
-from inference_types import Triple
+from inference_types import BindableTerm, Triple
 
 log = logging.getLogger('infer')
 
@@ -46,3 +49,112 @@
 
         cur = next
     return out, used
+
+
+registeredFunctionTypes: List[Type['Function']] = []
+
+
+def register(cls: Type['Function']):
+    registeredFunctionTypes.append(cls)
+    return cls
+
+
+class Function:
+    """any rule stmt that runs a function (not just a statement match)"""
+    pred: Node
+
+    def __init__(self, stmt: Triple, ruleGraph: Graph):
+        self.stmt = stmt
+        if stmt[1] != self.pred:
+            raise TypeError
+        self.ruleGraph = ruleGraph
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        raise NotImplementedError
+
+    def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]:
+        out = []
+        for op in self.getOperandNodes(existingBinding):
+            out.append(numericNode(op))
+
+        return out
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        """either any new bindings this function makes (could be 0), or None if it doesn't match"""
+        raise NotImplementedError
+
+    def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]:
+        objVar = self.stmt[2]
+        if not isinstance(objVar, Variable):
+            raise TypeError(f'expected Variable, got {objVar!r}')
+        return CandidateBinding({cast(BindableTerm, objVar): value})
+
+
+class SubjectFunction(Function):
+    """function that depends only on the subject term"""
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        return [existingBinding.applyTerm(self.stmt[0])]
+
+
+class SubjectObjectFunction(Function):
+    """a filter function that depends on the subject and object terms"""
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        return [existingBinding.applyTerm(self.stmt[0]), existingBinding.applyTerm(self.stmt[2])]
+
+
+class ListFunction(Function):
+    """function that takes an rdf list as input"""
+
+    def usedStatements(self) -> Set[Triple]:
+        _, used = parseList(self.ruleGraph, self.stmt[0])
+        return used
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        operands, _ = parseList(self.ruleGraph, self.stmt[0])
+        return [existingBinding.applyTerm(x) for x in operands]
+
+
+@register
+class Gt(SubjectObjectFunction):
+    pred = MATH['greaterThan']
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        [x, y] = self.getNumericOperands(existingBinding)
+        if x > y:
+            return CandidateBinding({})  # no new values; just allow matching to keep going
+
+
+@register
+class AsFarenheit(SubjectFunction):
+    pred = ROOM['asFarenheit']
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        [x] = self.getNumericOperands(existingBinding)
+        f = cast(Literal, Literal(Decimal(x) * 9 / 5 + 32))
+        return self.valueInObjectTerm(f)
+
+
+@register
+class Sum(ListFunction):
+    pred = MATH['sum']
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        f = Literal(sum(self.getNumericOperands(existingBinding)))
+        return self.valueInObjectTerm(f)
+
+
+def functionsFor(pred: Node) -> Iterator[Type[Function]]:
+    for cls in registeredFunctionTypes:
+        if cls.pred == pred:
+            yield cls
+
+def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]:
+    usedByFuncs: Set[Triple] = set()  # don't worry about matching these
+    for s in graph:
+        for cls in functionsFor(pred=s[1]):
+            if issubclass(cls, ListFunction):
+                usedByFuncs.update(cls(s, graph).usedStatements())
+    return usedByFuncs
+