changeset 1634:ba59cfc3c747

Hack math:sum support into the inference engine. The test suite passes, except for some slow-performing tests that remain disabled.
author drewp@bigasterisk.com
date Sun, 12 Sep 2021 23:48:43 -0700
parents 6107603ed455
children 22d481f0a924
files service/mqtt_to_rdf/candidate_binding.py service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py service/mqtt_to_rdf/lhs_evaluation.py
diffstat 4 files changed, 184 insertions(+), 79 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/candidate_binding.py	Sun Sep 12 21:48:36 2021 -0700
+++ b/service/mqtt_to_rdf/candidate_binding.py	Sun Sep 12 23:48:43 2021 -0700
@@ -7,7 +7,7 @@
 from rdflib.term import Node, Variable
 
 from inference_types import BindableTerm, BindingUnknown, Triple
-log = logging.getLogger()
+log = logging.getLogger('cbind')
 INDENT = '    '
 
 @dataclass
--- a/service/mqtt_to_rdf/inference.py	Sun Sep 12 21:48:36 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Sun Sep 12 23:48:43 2021 -0700
@@ -7,16 +7,16 @@
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast
+from typing import (Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast)
 
 from prometheus_client import Histogram, Summary
-from rdflib import BNode, Graph, Namespace
+from rdflib import RDF, BNode, Graph, Namespace
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
 from rdflib.term import Literal, Node, Variable
 
 from candidate_binding import CandidateBinding
-from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
-from lhs_evaluation import Decimal, Evaluation, numericNode
+from inference_types import (BindableTerm, BindingUnknown, EvaluationFailed, ReadOnlyWorkingSet, Triple)
+from lhs_evaluation import Decimal, Evaluation, numericNode, parseList
 
 log = logging.getLogger('infer')
 INDENT = '    '
@@ -58,6 +58,7 @@
     lhsStmt: Triple
     prev: Optional['StmtLooper']
     workingSet: ReadOnlyWorkingSet
+    parent: 'Lhs'  # just for lhs.graph, really
 
     def __repr__(self):
         return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'
@@ -98,9 +99,29 @@
             augmentedWorkingSet = list(self.prev.currentBinding().apply(self._myWorkingSetMatches,
                                                                         returnBoundStatementsOnly=False))
 
-        log.debug(f'{INDENT*6} {self} has {self._myWorkingSetMatches=}')
+        log.debug(f'{INDENT*6} {self}.advance has {augmentedWorkingSet=}')
+
+        if self._advanceWithPlainMatches(augmentedWorkingSet):
+            return
+
+        if self._advanceWithBoolRules():
+            return
+
+        curBind = self.prev.currentBinding() if self.prev else CandidateBinding({})
+        [lhsStmtBound] = curBind.apply([self.lhsStmt], returnBoundStatementsOnly=False)
 
-        log.debug(f'{INDENT*6} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
+        fullWorkingSet = self.workingSet + self.parent.graph
+        boundFullWorkingSet = list(curBind.apply(fullWorkingSet, returnBoundStatementsOnly=False))
+        log.debug(f'{fullWorkingSet.__len__()=} {len(boundFullWorkingSet)=}')
+
+        if self._advanceWithFunctions(augmentedWorkingSet, boundFullWorkingSet, lhsStmtBound):
+            return
+
+        log.debug(f'{INDENT*6} {self} is past end')
+        self._pastEnd = True
+
+    def _advanceWithPlainMatches(self, augmentedWorkingSet: Sequence[Triple]) -> bool:
+        log.debug(f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
         for s in augmentedWorkingSet:
             log.debug(f'{INDENT*7} {s}')
 
@@ -111,19 +132,38 @@
                 log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
                 continue
 
-            log.debug(f'{INDENT*6} {outBinding=} {self._seenBindings=}')
+            log.debug(f'{INDENT*7} {outBinding=} {self._seenBindings=}')
             if outBinding.binding not in self._seenBindings:
                 self._seenBindings.append(outBinding.binding.copy())
                 self._current = outBinding
                 log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}')
-                return
-            log.debug(f'yes we saw')
+                return True
+        return False
 
-        log.debug(f'{INDENT*6} {self} mines rules')
+    def _advanceWithBoolRules(self) -> bool:
+        log.debug(f'{INDENT*7} {self} mines bool rules')
+        if self.lhsStmt[1] == MATH['greaterThan']:
+            operands = [self.lhsStmt[0], self.lhsStmt[2]]
+            try:
+                boundOperands = self._boundOperands(operands)
+            except BindingUnknown:
+                return False
+            if numericNode(boundOperands[0]) > numericNode(boundOperands[1]):
+                bindingDict: Dict[BindableTerm,
+                                  Node] = self._prevBindings().copy()  # no new values; just allow matching to keep going
+                if bindingDict not in self._seenBindings:
+                    self._seenBindings.append(bindingDict)
+                    self._current = CandidateBinding(bindingDict)
+                    log.debug(f'{INDENT*7} new binding from {self} -> {bindingDict}')
+                    return True
+        return False
+
+    def _advanceWithFunctions(self, augmentedWorkingSet: Sequence[Triple], boundFullWorkingSet, lhsStmtBound) -> bool:
+        log.debug(f'{INDENT*7} {self} mines rules')
 
         if self.lhsStmt[1] == ROOM['asFarenheit']:
             pb: Dict[BindableTerm, Node] = self._prevBindings()
-            log.debug(f'{INDENT*6} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')
+            log.debug(f'{INDENT*7} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')
 
             if self.lhsStmt[0] in pb:
                 operands = [pb[cast(BindableTerm, self.lhsStmt[0])]]
@@ -136,17 +176,59 @@
                 if newBindings not in self._seenBindings:
                     self._seenBindings.append(newBindings)
                     self._current = CandidateBinding(newBindings)
-                    return
+                    return True
+        elif self.lhsStmt[1] == MATH['sum']:
+
+            g = Graph()
+            for s in boundFullWorkingSet:
+                g.add(s)
+                log.debug(f' boundWorkingSet graph: {s}')
+            log.debug(f'_parseList subj = {lhsStmtBound[0]}')
+            operands, _ = parseList(g, lhsStmtBound[0])
+            log.debug(f'********* {INDENT*7} {self} found list {operands=}')
+            try:
+                obj = Literal(sum(map(numericNode, operands)))
+            except TypeError:
+                log.debug('typeerr in operands')
+                pass
+            else:
+                objVar = lhsStmtBound[2]
+                log.debug(f'{objVar=}')
 
-        log.debug(f'{INDENT*6} {self} is past end')
-        self._pastEnd = True
+                if not isinstance(objVar, Variable):
+                    raise TypeError(f'expected Variable, got {objVar!r}')
+                newBindings: Dict[BindableTerm, Node] = {objVar: obj}
+                log.debug(f'{newBindings=}')
+
+                self._current.addNewBindings(CandidateBinding(newBindings))
+                log.debug(f'{self._seenBindings=}')
+                if newBindings not in self._seenBindings:
+                    self._seenBindings.append(newBindings)
+                    self._current = CandidateBinding(newBindings)
+                    return True
+
+        return False
+
+    def _boundOperands(self, operands) -> List[Node]:
+        pb: Dict[BindableTerm, Node] = self._prevBindings()
+
+        boundOperands: List[Node] = []
+        for op in operands:
+            if isinstance(op, (Variable, BNode)):
+                if op in pb:
+                    boundOperands.append(pb[op])
+                else:
+                    raise BindingUnknown()
+            else:
+                boundOperands.append(op)
+        return boundOperands
 
     def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
         outBinding = self._prevBindings().copy()
         for rt, ct in zip(self.lhsStmt, newStmt):
             if isinstance(rt, (Variable, BNode)):
                 if rt in outBinding and outBinding[rt] != ct:
-                    raise Inconsistent()
+                    raise Inconsistent(f'{rt=} {ct=} {outBinding=}')
                 outBinding[rt] = ct
         return CandidateBinding(outBinding)
 
@@ -245,7 +327,21 @@
         """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
         start out valid (or else raise NoOptions)"""
 
-        stmtsToAdd = list(self.graph)
+        usedByFuncs: Set[Triple] = set()  # don't worry about matching these
+        stmtsToResolve = list(self.graph)
+        for i, s in enumerate(stmtsToResolve):
+            if s[1] == MATH['sum']:
+                _, used = parseList(self.graph, s[0])
+                usedByFuncs.update(used)
+
+        stmtsToAdd = [stmt for stmt in stmtsToResolve if not stmt in usedByFuncs]
+
+        # TODO: sort these by variable dependencies instead of trying every permutation below.
+        def lightSortKey(stmt):  # Not this. Though it helps performance on the big rdf list cases.
+            (s, p, o) = stmt
+            return p == MATH['sum'], p, s, o
+
+        stmtsToAdd.sort(key=lightSortKey)
 
         for perm in itertools.permutations(stmtsToAdd):
             stmtStack: List[StmtLooper] = []
@@ -254,7 +350,7 @@
 
             for s in perm:
                 try:
-                    elem = StmtLooper(s, prev, knownTrue)
+                    elem = StmtLooper(s, prev, knownTrue, parent=self)
                 except NoOptions:
                     log.debug(f'{INDENT*6} permutation didnt work, try another')
                     break
@@ -540,7 +636,7 @@
         log.debug('')
         log.debug(f'{INDENT*2}-applying rule {i}')
         log.debug(f'{INDENT*3} rule def lhs:')
-        for stmt in r.lhsGraph:
+        for stmt in sorted(r.lhsGraph, reverse=True):
             log.debug(f'{INDENT*4} {stmt}')
         log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')
 
--- a/service/mqtt_to_rdf/inference_test.py	Sun Sep 12 21:48:36 2021 -0700
+++ b/service/mqtt_to_rdf/inference_test.py	Sun Sep 12 23:48:43 2021 -0700
@@ -165,9 +165,9 @@
         self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt :x ."))
         self.assertGraphEqual(inf.infer(N3(":any :any :any .")), N3(":new :stmt :x ."))
 
-#     def test2(self):
-#         inf = makeInferenceWithRules("{ (2) math:sum ?x } => { :new :stmt ?x } .")
-#         self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt 2 ."))
+    def test2(self):
+        inf = makeInferenceWithRules("{ (2) math:sum ?x } => { :new :stmt ?x } .")
+        self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt 2 ."))
 
 #     @unittest.skip("too hard for now")
     # def test3(self):
@@ -175,29 +175,33 @@
     #     self.assertGraphEqual(inf.infer(N3("")), N3(":new :stmt :c ."))
 
 
-# class TestInferenceWithMathFunctions(WithGraphEqual):
+class TestInferenceWithMathFunctions(WithGraphEqual):
 
-#     def testBoolFilter(self):
-#         inf = makeInferenceWithRules("{ :a :b ?x . ?x math:greaterThan 5 } => { :new :stmt ?x } .")
-#         self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(""))
-#         self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3(""))
-#         self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 ."))
+    def testBoolFilter(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . ?x math:greaterThan 5 } => { :new :stmt ?x } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(""))
+        self.assertGraphEqual(inf.infer(N3(":a :b 5 .")), N3(""))
+        self.assertGraphEqual(inf.infer(N3(":a :b 6 .")), N3(":new :stmt 6 ."))
 
-#     def testNonFiringMathRule(self):
-#         inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
-#         self.assertGraphEqual(inf.infer(N3("")), N3(""))
+    def testNonFiringMathRule(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3("")), N3(""))
 
-#     def testStatementGeneratingRule(self):
-#         inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
-#         self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 ."))
+    def testStatementGeneratingRule(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . (?x) math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 3 ."))
+
+    def test2Operands(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . (?x 1) math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 3 .")), N3(":new :stmt 4 ."))
 
-#     def test3Operands(self):
-#         inf = makeInferenceWithRules("{ :a :b ?x . (2 ?x 2) math:sum ?y } => { :new :stmt ?y } .")
-#         self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 6 ."))
+    def test3Operands(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . (2 ?x 2) math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 6 ."))
 
-#     def test0Operands(self):
-#         inf = makeInferenceWithRules("{ :a :b ?x . () math:sum ?y } => { :new :stmt ?y } .")
-#         self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 0 ."))
+    def test0Operands(self):
+        inf = makeInferenceWithRules("{ :a :b ?x . () math:sum ?y } => { :new :stmt ?y } .")
+        self.assertGraphEqual(inf.infer(N3(":a :b 2 .")), N3(":new :stmt 0 ."))
 
 
 class TestInferenceWithCustomFunctions(WithGraphEqual):
@@ -241,11 +245,35 @@
     #     out = inf.infer(N3('[] a :MqttMessage ; :body "online" ; :topic ( "frontdoorlock" "status" ) .'))
     #     self.assertIn((EX['frontDoorLockStatus'], EX['connectedStatus'], EX['Online']), out)
 
-    # def testPerformance0(self):
+    def testPerformance0(self):
+        inf = makeInferenceWithRules('''
+            {
+              ?msg a :MqttMessage;
+                :topic :topic1;
+                :bodyFloat ?valueC .
+              ?valueC math:greaterThan -999 .
+              ?valueC room:asFarenheit ?valueF .
+            } => {
+              :airQualityIndoorTemperature :temperatureF ?valueF .
+            } .
+        ''')
+        out = inf.infer(
+            N3('''
+            <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ;
+                :body "23.9" ;
+                :bodyFloat 2.39e+01 ;
+                :topic :topic1 .
+            '''))
+
+        vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF']))
+        valueF = cast(Decimal, vlit.toPython())
+        self.assertAlmostEqual(float(valueF), 75.02)
+
+    # def testPerformance1(self):
     #     inf = makeInferenceWithRules('''
     #         {
     #           ?msg a :MqttMessage;
-    #             :topic :topic1;
+    #             :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" );
     #             :bodyFloat ?valueC .
     #           ?valueC math:greaterThan -999 .
     #           ?valueC room:asFarenheit ?valueF .
@@ -258,36 +286,12 @@
     #         <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ;
     #             :body "23.9" ;
     #             :bodyFloat 2.39e+01 ;
-    #             :topic :topic1 .
-    #         '''))
-
+    #             :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ) .
+    #     '''))
     #     vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF']))
     #     valueF = cast(Decimal, vlit.toPython())
     #     self.assertAlmostEqual(float(valueF), 75.02)
 
-#     def testPerformance1(self):
-#         inf = makeInferenceWithRules('''
-#             {
-#               ?msg a :MqttMessage;
-#                 :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" );
-#                 :bodyFloat ?valueC .
-#               ?valueC math:greaterThan -999 .
-#               ?valueC room:asFarenheit ?valueF .
-#             } => {
-#               :airQualityIndoorTemperature :temperatureF ?valueF .
-#             } .
-#         ''')
-#         out = inf.infer(
-#             N3('''
-#             <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ;
-#                 :body "23.9" ;
-#                 :bodyFloat 2.39e+01 ;
-#                 :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ) .
-#         '''))
-#         vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF']))
-#         valueF = cast(Decimal, vlit.toPython())
-#         self.assertAlmostEqual(float(valueF), 75.02)
-
     def testEmitBnodes(self):
         inf = makeInferenceWithRules('''
             { ?s a :AirQualitySensor; :label ?name . } => {
@@ -322,15 +326,15 @@
         implied = inf.infer(N3(":a :b (:e0 :e1) ."))
         self.assertGraphEqual(implied, N3(":new :stmt :here ."))
 
-    # def testList3(self):
-    #     inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2) . } => { :new :stmt :here } .")
-    #     implied = inf.infer(N3(":a :b (:e0 :e1 :e2) ."))
-    #     self.assertGraphEqual(implied, N3(":new :stmt :here ."))
+    def testList3(self):
+        inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2) . } => { :new :stmt :here } .")
+        implied = inf.infer(N3(":a :b (:e0 :e1 :e2) ."))
+        self.assertGraphEqual(implied, N3(":new :stmt :here ."))
 
-    # def testList4(self):
-    #     inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2 :e3) . } => { :new :stmt :here } .")
-    #     implied = inf.infer(N3(":a :b (:e0 :e1 :e2 :e3) ."))
-    #     self.assertGraphEqual(implied, N3(":new :stmt :here ."))
+    def testList4(self):
+        inf = makeInferenceWithRules("{ :a :b (:e0 :e1 :e2 :e3) . } => { :new :stmt :here } .")
+        implied = inf.infer(N3(":a :b (:e0 :e1 :e2 :e3) ."))
+        self.assertGraphEqual(implied, N3(":new :stmt :here ."))
 
 
 # def fakeStats():
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Sun Sep 12 21:48:36 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Sun Sep 12 23:48:43 2021 -0700
@@ -36,7 +36,7 @@
     @staticmethod
     def findEvals(graph: Graph) -> Iterator['Evaluation']:
         for stmt in graph.triples((None, MATH['sum'], None)):
-            operands, operandsStmts = _parseList(graph, stmt[0])
+            operands, operandsStmts = parseList(graph, stmt[0])
             yield Evaluation(operands, stmt, operandsStmts)
 
         for stmt in graph.triples((None, MATH['greaterThan'], None)):
@@ -98,17 +98,22 @@
     return val
 
 
-def _parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
+def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
     """"Do like Collection(g, subj) but also return all the 
     triples that are involved in the list"""
     out = []
     used = set()
     cur = subj
     while cur != RDF.nil:
-        out.append(graph.value(cur, RDF.first))
+        elem = graph.value(cur, RDF.first)
+        if elem is None:
+            raise ValueError('bad list')
+        out.append(elem)
         used.add((cur, RDF.first, out[-1]))
 
         next = graph.value(cur, RDF.rest)
+        if next is None:
+            raise ValueError('bad list')
         used.add((cur, RDF.rest, next))
 
         cur = next