changeset 1637:ec3f98d0c1d8

refactor rules eval
author drewp@bigasterisk.com
date Mon, 13 Sep 2021 01:36:06 -0700
parents 3252bdc284bc
children 0ba1625037ae
files service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/lhs_evaluation.py
diffstat 2 files changed, 137 insertions(+), 77 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Mon Sep 13 00:18:47 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Mon Sep 13 01:36:06 2021 -0700
@@ -12,11 +12,11 @@
 from prometheus_client import Histogram, Summary
 from rdflib import RDF, BNode, Graph, Namespace
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
-from rdflib.term import Literal, Node, Variable
+from rdflib.term import Node, Variable
 
 from candidate_binding import CandidateBinding
-from inference_types import (BindableTerm, BindingUnknown, ReadOnlyWorkingSet, Triple)
-from lhs_evaluation import Decimal, numericNode, parseList
+from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple
+from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs
 
 log = logging.getLogger('infer')
 INDENT = '    '
@@ -104,9 +104,6 @@
         if self._advanceWithPlainMatches(augmentedWorkingSet):
             return
 
-        if self._advanceWithBoolRules():
-            return
-
         curBind = self.prev.currentBinding() if self.prev else CandidateBinding({})
         [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False)
 
@@ -125,7 +122,7 @@
         for s in augmentedWorkingSet:
             log.debug(f'{INDENT*7} {s}')
 
-        for i, stmt in enumerate(augmentedWorkingSet):
+        for stmt in augmentedWorkingSet:
             try:
                 outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
             except Inconsistent:
@@ -140,71 +137,24 @@
                 return True
         return False
 
-    def _advanceWithBoolRules(self) -> bool:
-        log.debug(f'{INDENT*7} {self} mines bool rules')
-        if self.lhsStmt[1] == MATH['greaterThan']:
-            operands = [self.lhsStmt[0], self.lhsStmt[2]]
-            try:
-                boundOperands = self._boundOperands(operands)
-            except BindingUnknown:
-                return False
-            if numericNode(boundOperands[0]) > numericNode(boundOperands[1]):
-                binding: CandidateBinding = self._prevBindings().copy()  # no new values; just allow matching to keep going
-                if binding not in self._seenBindings:
-                    self._seenBindings.append(binding)
-                    self._current = binding
-                    log.debug(f'{INDENT*7} new binding from {self} -> {binding}')
-                    return True
-        return False
-
     def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool:
-        log.debug(f'{INDENT*7} {self} mines rules')
-
-        if self.lhsStmt[1] == ROOM['asFarenheit']:
-            pb: CandidateBinding = self._prevBindings()
-            log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')
+        pred: Node = self.lhsStmt[1]
 
-            if isinstance(self.lhsStmt[0], (Variable, BNode)) and pb.contains(self.lhsStmt[0]):
-                operands = [pb.applyTerm(self.lhsStmt[0])]
-                f = cast(Literal, Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32))
-                objVar = self.lhsStmt[2]
-                if not isinstance(objVar, Variable):
-                    raise TypeError(f'expected Variable, got {objVar!r}')
-                newBindings = CandidateBinding({cast(BindableTerm, objVar): cast(Node, f)})
-                self._current.addNewBindings(newBindings)
-                if newBindings not in self._seenBindings:
-                    self._seenBindings.append(newBindings)
-                    self._current = newBindings
-                    return True
-        elif self.lhsStmt[1] == MATH['sum']:
-
-            g = Graph()
-            for s in boundFullWorkingSet:
-                g.add(s)
-                log.debug(f' boundWorkingSet graph: {s}')
-            log.debug(f'_parseList subj = {lhsStmtBound[0]}')
-            operands, _ = parseList(g, lhsStmtBound[0])
-            log.debug(f'********* {INDENT*7} {self} found list {operands=}')
+        for functionType in functionsFor(pred):
+            fn = functionType(self.lhsStmt, self.parent.graph)
             try:
-                obj = Literal(sum(map(numericNode, operands)))
-            except TypeError:
-                log.debug('typeerr in operands')
+                out = fn.bind(self._prevBindings())
+            except BindingUnknown:
                 pass
             else:
-                objVar = lhsStmtBound[2]
-                log.debug(f'{objVar=}')
-
-                if not isinstance(objVar, Variable):
-                    raise TypeError(f'expected Variable, got {objVar!r}')
-                newBindings = CandidateBinding({objVar: obj})
-                log.debug(f'{newBindings=}')
-
-                self._current.addNewBindings(newBindings)
-                log.debug(f'{self._seenBindings=}')
-                if newBindings not in self._seenBindings:
-                    self._seenBindings.append(newBindings)
-                    self._current = newBindings
-                    return True
+                if out is not None:
+                    binding: CandidateBinding = self._prevBindings().copy()
+                    binding.addNewBindings(out)
+                    if binding not in self._seenBindings:
+                        self._seenBindings.append(binding)
+                        self._current = binding
+                        log.debug(f'{INDENT*7} new binding from {self} -> {binding}')
+                        return True
 
         return False
 
@@ -302,14 +252,9 @@
         """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
         start out valid (or else raise NoOptions)"""
 
-        usedByFuncs: Set[Triple] = set()  # don't worry about matching these
-        stmtsToResolve = list(self.graph)
-        for i, s in enumerate(stmtsToResolve):
-            if s[1] == MATH['sum']:
-                _, used = parseList(self.graph, s[0])
-                usedByFuncs.update(used)
+        usedByFuncs: Set[Triple] = lhsStmtsUsedByFuncs(self.graph)
 
-        stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs]
+        stmtsToAdd = list(self.graph - usedByFuncs)
 
         # sort them by variable dependencies; don't just try all perms!
         def lightSortKey(stmt):  # Not this. Though it helps performance on the big rdf list cases.
@@ -478,6 +423,9 @@
 
 
 def graphDump(g: Union[Graph, List[Triple]]):
+    # this is very slow- debug only!
+    if not log.isEnabledFor(logging.DEBUG):
+        return "(skipped dump)"
     if not isinstance(g, Graph):
         g2 = Graph()
         g2 += g
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 13 00:18:47 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 13 01:36:06 2021 -0700
@@ -1,12 +1,15 @@
+from dataclasses import dataclass
 import logging
 from decimal import Decimal
-from typing import List, Set, Tuple
+from candidate_binding import CandidateBinding
+from typing import Iterator, List, Optional, Set, Tuple, Type, Union, cast
 
 from prometheus_client import Summary
 from rdflib import RDF, Literal, Namespace, URIRef
-from rdflib.term import Node
+from rdflib.graph import Graph
+from rdflib.term import Node, Variable
 
-from inference_types import Triple
+from inference_types import BindableTerm, Triple
 
 log = logging.getLogger('infer')
 
@@ -46,3 +49,112 @@
 
         cur = next
     return out, used
+
+
+registeredFunctionTypes: List[Type['Function']] = []
+
+
+def register(cls: Type['Function']):
+    registeredFunctionTypes.append(cls)
+    return cls
+
+
+class Function:
+    """any rule stmt that runs a function (not just a statement match)"""
+    pred: Node
+
+    def __init__(self, stmt: Triple, ruleGraph: Graph):
+        self.stmt = stmt
+        if stmt[1] != self.pred:
+            raise TypeError
+        self.ruleGraph = ruleGraph
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        raise NotImplementedError
+
+    def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]:
+        out = []
+        for op in self.getOperandNodes(existingBinding):
+            out.append(numericNode(op))
+
+        return out
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        """either any new bindings this function makes (could be 0), or None if it doesn't match"""
+        raise NotImplementedError
+
+    def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]:
+        objVar = self.stmt[2]
+        if not isinstance(objVar, Variable):
+            raise TypeError(f'expected Variable, got {objVar!r}')
+        return CandidateBinding({cast(BindableTerm, objVar): value})
+
+
+class SubjectFunction(Function):
+    """function that depends only on the subject term"""
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        return [existingBinding.applyTerm(self.stmt[0])]
+
+
+class SubjectObjectFunction(Function):
+    """a filter function that depends on the subject and object terms"""
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        return [existingBinding.applyTerm(self.stmt[0]), existingBinding.applyTerm(self.stmt[2])]
+
+
+class ListFunction(Function):
+    """function that takes an rdf list as input"""
+
+    def usedStatements(self) -> Set[Triple]:
+        _, used = parseList(self.ruleGraph, self.stmt[0])
+        return used
+
+    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
+        operands, _ = parseList(self.ruleGraph, self.stmt[0])
+        return [existingBinding.applyTerm(x) for x in operands]
+
+
+@register
+class Gt(SubjectObjectFunction):
+    pred = MATH['greaterThan']
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        [x, y] = self.getNumericOperands(existingBinding)
+        if x > y:
+            return CandidateBinding({})  # no new values; just allow matching to keep going
+
+
+@register
+class AsFarenheit(SubjectFunction):
+    pred = ROOM['asFarenheit']
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        [x] = self.getNumericOperands(existingBinding)
+        f = cast(Literal, Literal(Decimal(x) * 9 / 5 + 32))
+        return self.valueInObjectTerm(f)
+
+
+@register
+class Sum(ListFunction):
+    pred = MATH['sum']
+
+    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
+        f = Literal(sum(self.getNumericOperands(existingBinding)))
+        return self.valueInObjectTerm(f)
+
+
+def functionsFor(pred: Node) -> Iterator[Type[Function]]:
+    for cls in registeredFunctionTypes:
+        if cls.pred == pred:
+            yield cls
+
+def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]:
+    usedByFuncs: Set[Triple] = set()  # don't worry about matching these
+    for s in graph:
+        for cls in functionsFor(pred=s[1]):
+            if issubclass(cls, ListFunction):
+                usedByFuncs.update(cls(s, graph).usedStatements())
+    return usedByFuncs
+