changeset 1605:449746d1598f

WIP move evaluation to new file
author drewp@bigasterisk.com
date Mon, 06 Sep 2021 01:13:55 -0700
parents e78464befd24
children 6cf39d43fd40
files service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py service/mqtt_to_rdf/lhs_evaluation.py service/mqtt_to_rdf/lhs_evaluation_test.py
diffstat 4 files changed, 163 insertions(+), 119 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Mon Sep 06 00:57:28 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Mon Sep 06 01:13:55 2021 -0700
@@ -15,6 +15,8 @@
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
 from rdflib.term import Node, Variable
 
+from lhs_evaluation import EvaluationFailed, Evaluation
+
 log = logging.getLogger('infer')
 INDENT = '    '
 
@@ -33,9 +35,6 @@
 # identifier, which can be a bottleneck.
 GRAPH_ID = URIRef('dont/care')
 
-class EvaluationFailed(ValueError):
-    """e.g. we were given (5 math:greaterThan 6)"""
-
 
 class BindingUnknown(ValueError):
     """e.g. we were asked to make the bound version 
@@ -156,7 +155,6 @@
 
         self.evaluations = list(Evaluation.findEvals(self.graph))
 
-
     def findCandidateBindings(self, workingSet: ReadOnlyWorkingSet) -> Iterator[CandidateBinding]:
         """bindings that fit the LHS of a rule, using statements from workingSet and functions
         from LHS"""
@@ -175,7 +173,6 @@
 
         log.debug(f'{INDENT*3} trying all permutations:')
 
-
         for perm in itertools.product(*orderedValueSets):
             binding = CandidateBinding(dict(zip(orderedVars, perm)))
             log.debug('')
@@ -258,80 +255,6 @@
                 log.debug(f'{INDENT*5}{val!r}')
 
 
-class Evaluation:
-    """some lhs statements need to be evaluated with a special function 
-    (e.g. math) and then not considered for the rest of the rule-firing 
-    process. It's like they already 'matched' something, so they don't need
-    to match a statement from the known-true working set.
-    
-    One Evaluation instance is for one function call.
-    """
-
-    @staticmethod
-    def findEvals(graph: Graph) -> Iterator['Evaluation']:
-        for stmt in graph.triples((None, MATH['sum'], None)):
-            operands, operandsStmts = _parseList(graph, stmt[0])
-            yield Evaluation(operands, stmt, operandsStmts)
-
-        for stmt in graph.triples((None, MATH['greaterThan'], None)):
-            yield Evaluation([stmt[0], stmt[2]], stmt, [])
-
-        for stmt in graph.triples((None, ROOM['asFarenheit'], None)):
-            yield Evaluation([stmt[0]], stmt, [])
-
-    # internal, use findEvals
-    def __init__(self, operands: List[Node], mainStmt: Triple, otherStmts: Iterable[Triple]) -> None:
-        self.operands = operands
-        self.operandsStmts = Graph(identifier=GRAPH_ID)
-        self.operandsStmts += otherStmts  # may grow
-        self.operandsStmts.add(mainStmt)
-        self.stmt = mainStmt
-
-    def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]:
-        """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?"""
-        pred = self.stmt[1]
-        objVar: Node = self.stmt[2]
-        boundOperands = []
-        for op in self.operands:
-            if isinstance(op, Variable):
-                try:
-                    op = inputBindings[op]
-                except KeyError:
-                    return {}, self.operandsStmts
-
-            boundOperands.append(op)
-
-        if pred == MATH['sum']:
-            obj = Literal(sum(map(numericNode, boundOperands)))
-            if not isinstance(objVar, Variable):
-                raise TypeError(f'expected Variable, got {objVar!r}')
-            res: Dict[BindableTerm, Node] = {objVar: obj}
-        elif pred == ROOM['asFarenheit']:
-            if len(boundOperands) != 1:
-                raise ValueError(":asFarenheit takes 1 subject operand")
-            f = Literal(Decimal(numericNode(boundOperands[0])) * 9 / 5 + 32)
-            if not isinstance(objVar, Variable):
-                raise TypeError(f'expected Variable, got {objVar!r}')
-            res: Dict[BindableTerm, Node] = {objVar: f}
-        elif pred == MATH['greaterThan']:
-            if not (numericNode(boundOperands[0]) > numericNode(boundOperands[1])):
-                raise EvaluationFailed()
-            res: Dict[BindableTerm, Node] = {}
-        else:
-            raise NotImplementedError(repr(pred))
-
-        return res, self.operandsStmts
-
-
-def numericNode(n: Node):
-    if not isinstance(n, Literal):
-        raise TypeError(f'expected Literal, got {n=}')
-    val = n.toPython()
-    if not isinstance(val, (int, float, Decimal)):
-        raise TypeError(f'expected number, got {val=}')
-    return val
-
-
 class Inference:
 
     def __init__(self) -> None:
@@ -351,7 +274,7 @@
         """
         log.info(f'{INDENT*0} Begin inference of graph len={graph.__len__()} with rules len={len(self.rules)}:')
         startTime = time.time()
-        self.stats: Dict[str, Union[int,float]] = defaultdict(lambda: 0)
+        self.stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0)
         # everything that is true: the input graph, plus every rule conclusion we can make
         workingSet = Graph()
         workingSet += graph
@@ -410,23 +333,6 @@
             implied.add(newStmt)
 
 
-def _parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
-    """"Do like Collection(g, subj) but also return all the 
-    triples that are involved in the list"""
-    out = []
-    used = set()
-    cur = subj
-    while cur != RDF.nil:
-        out.append(graph.value(cur, RDF.first))
-        used.add((cur, RDF.first, out[-1]))
-
-        next = graph.value(cur, RDF.rest)
-        used.add((cur, RDF.rest, next))
-
-        cur = next
-    return out, used
-
-
 def graphDump(g: Union[Graph, List[Triple]]):
     if not isinstance(g, Graph):
         g2 = Graph()
--- a/service/mqtt_to_rdf/inference_test.py	Mon Sep 06 00:57:28 2021 -0700
+++ b/service/mqtt_to_rdf/inference_test.py	Mon Sep 06 01:13:55 2021 -0700
@@ -6,7 +6,7 @@
 from rdflib import RDF, BNode, ConjunctiveGraph, Graph, Literal, Namespace
 from rdflib.parser import StringInputSource
 
-from inference import Inference, _parseList
+from inference import Inference
 from rdflib_debug_patches import patchBnodeCounter, patchSlimReprs
 
 patchSlimReprs()
@@ -178,27 +178,6 @@
         self.assertGraphEqual(inf.infer(N3(":a :b 12 .")), N3(":new :stmt 53.6 ."))
 
 
-class TestParseList(unittest.TestCase):
-
-    def test0Elements(self):
-        g = N3(":a :b () .")
-        bn = g.value(EX['a'], EX['b'])
-        elems, used = _parseList(g, bn)
-        self.assertEqual(elems, [])
-        self.assertFalse(used)
-
-    def test1Element(self):
-        g = N3(":a :b (0) .")
-        bn = g.value(EX['a'], EX['b'])
-        elems, used = _parseList(g, bn)
-        self.assertEqual(elems, [Literal(0)])
-        used = sorted(used)
-        self.assertEqual(used, [
-            (bn, RDF.first, Literal(0)),
-            (bn, RDF.rest, RDF.nil),
-        ])
-
-
 class TestUseCases(WithGraphEqual):
 
     def testSimpleTopic(self):
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 06 01:13:55 2021 -0700
@@ -0,0 +1,121 @@
+import logging
+from dataclasses import dataclass, field
+from decimal import Decimal
+from typing import Dict, Iterable, Iterator, List, Set, Tuple, Union, cast
+
+from prometheus_client import Summary
+from rdflib import RDF, BNode, Graph, Literal, Namespace, URIRef
+from rdflib.graph import ReadOnlyGraphAggregate
+from rdflib.term import Node, Variable
+
+log = logging.getLogger('infer')
+
+INDENT = '    '
+
+Triple = Tuple[Node, Node, Node]
+Rule = Tuple[Graph, Node, Graph]
+BindableTerm = Union[Variable, BNode]
+ReadOnlyWorkingSet = ReadOnlyGraphAggregate
+
+ROOM = Namespace("http://projects.bigasterisk.com/room/")
+LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
+MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
+
+# Graph() makes a BNode if you don't pass
+# identifier, which can be a bottleneck.
+GRAPH_ID = URIRef('dont/care')
+
+
+class EvaluationFailed(ValueError):
+    """e.g. we were given (5 math:greaterThan 6)"""
+
+
+class Evaluation:
+    """some lhs statements need to be evaluated with a special function 
+    (e.g. math) and then not considered for the rest of the rule-firing 
+    process. It's like they already 'matched' something, so they don't need
+    to match a statement from the known-true working set.
+    
+    One Evaluation instance is for one function call.
+    """
+
+    @staticmethod
+    def findEvals(graph: Graph) -> Iterator['Evaluation']:
+        for stmt in graph.triples((None, MATH['sum'], None)):
+            operands, operandsStmts = _parseList(graph, stmt[0])
+            yield Evaluation(operands, stmt, operandsStmts)
+
+        for stmt in graph.triples((None, MATH['greaterThan'], None)):
+            yield Evaluation([stmt[0], stmt[2]], stmt, [])
+
+        for stmt in graph.triples((None, ROOM['asFarenheit'], None)):
+            yield Evaluation([stmt[0]], stmt, [])
+
+    # internal, use findEvals
+    def __init__(self, operands: List[Node], mainStmt: Triple, otherStmts: Iterable[Triple]) -> None:
+        self.operands = operands
+        self.operandsStmts = Graph(identifier=GRAPH_ID)
+        self.operandsStmts += otherStmts  # may grow
+        self.operandsStmts.add(mainStmt)
+        self.stmt = mainStmt
+
+    def resultBindings(self, inputBindings) -> Tuple[Dict[BindableTerm, Node], Graph]:
+        """under the bindings so far, what would this evaluation tell us, and which stmts would be consumed from doing so?"""
+        pred = self.stmt[1]
+        objVar: Node = self.stmt[2]
+        boundOperands = []
+        for op in self.operands:
+            if isinstance(op, Variable):
+                try:
+                    op = inputBindings[op]
+                except KeyError:
+                    return {}, self.operandsStmts
+
+            boundOperands.append(op)
+
+        if pred == MATH['sum']:
+            obj = Literal(sum(map(numericNode, boundOperands)))
+            if not isinstance(objVar, Variable):
+                raise TypeError(f'expected Variable, got {objVar!r}')
+            res: Dict[BindableTerm, Node] = {objVar: obj}
+        elif pred == ROOM['asFarenheit']:
+            if len(boundOperands) != 1:
+                raise ValueError(":asFarenheit takes 1 subject operand")
+            f = Literal(Decimal(numericNode(boundOperands[0])) * 9 / 5 + 32)
+            if not isinstance(objVar, Variable):
+                raise TypeError(f'expected Variable, got {objVar!r}')
+            res: Dict[BindableTerm, Node] = {objVar: f}
+        elif pred == MATH['greaterThan']:
+            if not (numericNode(boundOperands[0]) > numericNode(boundOperands[1])):
+                raise EvaluationFailed()
+            res: Dict[BindableTerm, Node] = {}
+        else:
+            raise NotImplementedError(repr(pred))
+
+        return res, self.operandsStmts
+
+
+def numericNode(n: Node):
+    if not isinstance(n, Literal):
+        raise TypeError(f'expected Literal, got {n=}')
+    val = n.toPython()
+    if not isinstance(val, (int, float, Decimal)):
+        raise TypeError(f'expected number, got {val=}')
+    return val
+
+
+def _parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
+    """"Do like Collection(g, subj) but also return all the 
+    triples that are involved in the list"""
+    out = []
+    used = set()
+    cur = subj
+    while cur != RDF.nil:
+        out.append(graph.value(cur, RDF.first))
+        used.add((cur, RDF.first, out[-1]))
+
+        next = graph.value(cur, RDF.rest)
+        used.add((cur, RDF.rest, next))
+
+        cur = next
+    return out, used
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/service/mqtt_to_rdf/lhs_evaluation_test.py	Mon Sep 06 01:13:55 2021 -0700
@@ -0,0 +1,38 @@
+import unittest
+
+from rdflib import RDF, ConjunctiveGraph, Literal, Namespace
+from rdflib.parser import StringInputSource
+
+from lhs_evaluation import _parseList
+
+EX = Namespace('http://example.com/')
+
+
+def N3(txt: str):
+    g = ConjunctiveGraph()
+    prefix = """
+@prefix : <http://example.com/> .
+"""
+    g.parse(StringInputSource((prefix + txt).encode('utf8')), format='n3')
+    return g
+
+
+class TestParseList(unittest.TestCase):
+
+    def test0Elements(self):
+        g = N3(":a :b () .")
+        bn = g.value(EX['a'], EX['b'])
+        elems, used = _parseList(g, bn)
+        self.assertEqual(elems, [])
+        self.assertFalse(used)
+
+    def test1Element(self):
+        g = N3(":a :b (0) .")
+        bn = g.value(EX['a'], EX['b'])
+        elems, used = _parseList(g, bn)
+        self.assertEqual(elems, [Literal(0)])
+        used = sorted(used)
+        self.assertEqual(used, [
+            (bn, RDF.first, Literal(0)),
+            (bn, RDF.rest, RDF.nil),
+        ])