changeset 1640:4bb6f593ebf3

speedups: abort some rules faster
author drewp@bigasterisk.com
date Wed, 15 Sep 2021 23:56:02 -0700
parents ae5ca4ba8954
children 5403c6343fa4
files service/mqtt_to_rdf/inference.py service/mqtt_to_rdf/inference_test.py service/mqtt_to_rdf/lhs_evaluation.py
diffstat 3 files changed, 72 insertions(+), 47 deletions(-) [+]
line wrap: on
line diff
--- a/service/mqtt_to_rdf/inference.py	Mon Sep 13 01:54:49 2021 -0700
+++ b/service/mqtt_to_rdf/inference.py	Wed Sep 15 23:56:02 2021 -0700
@@ -12,11 +12,11 @@
 from prometheus_client import Histogram, Summary
 from rdflib import RDF, BNode, Graph, Namespace
 from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
-from rdflib.term import Node, Variable
+from rdflib.term import Node, URIRef, Variable
 
 from candidate_binding import BindingConflict, CandidateBinding
 from inference_types import BindingUnknown, ReadOnlyWorkingSet, Triple
-from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs
+from lhs_evaluation import functionsFor, lhsStmtsUsedByFuncs, rulePredicates
 
 log = logging.getLogger('infer')
 INDENT = '    '
@@ -132,6 +132,8 @@
 
     def _advanceWithFunctions(self) -> bool:
         pred: Node = self.lhsStmt[1]
+        if not isinstance(pred, URIRef):
+            raise NotImplementedError
 
         for functionType in functionsFor(pred):
             fn = functionType(self.lhsStmt, self.parent.graph)
@@ -205,6 +207,10 @@
             yield BoundLhs(self, CandidateBinding({}))
             return
 
+        if self._checkPredicateCounts(knownTrue):
+            stats['_checkPredicateCountsCulls'] += 1
+            return
+
         log.debug(f'{INDENT*4} build new StmtLooper stack')
 
         try:
@@ -241,6 +247,15 @@
         for l in stmtStack:
             log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')
 
+    def _checkPredicateCounts(self, knownTrue):
+        """True if some URIRef predicate in self.graph (other than rulePredicates() and rdf:first/rdf:rest) matches nothing in knownTrue"""
+        myPreds = set(p for s, p, o in self.graph if isinstance(p, URIRef))
+        myPreds -= rulePredicates()
+        myPreds -= {RDF.first, RDF.rest}
+        if any((None, p, None) not in knownTrue for p in set(myPreds)):
+            return True
+        return False
+
     def _assembleRings(self, knownTrue: ReadOnlyWorkingSet) -> List[StmtLooper]:
         """make StmtLooper for each stmt in our LHS graph, but do it in a way that they all
         start out valid (or else raise NoOptions)"""
--- a/service/mqtt_to_rdf/inference_test.py	Mon Sep 13 01:54:49 2021 -0700
+++ b/service/mqtt_to_rdf/inference_test.py	Wed Sep 15 23:56:02 2021 -0700
@@ -229,21 +229,21 @@
         out = inf.infer(N3('[] a :MqttMessage ; :body "online" ; :topic :foo .'))
         self.assertIn((EX['frontDoorLockStatus'], EX['connectedStatus'], EX['Online']), out)
 
-    # def testTopicIsList(self):
-    #     inf = makeInferenceWithRules('''
-    #         { ?msg :body "online" . } => { ?msg :onlineTerm :Online . } .
-    #         { ?msg :body "offline" . } => { ?msg :onlineTerm :Offline . } .
+    def testTopicIsList(self):
+        inf = makeInferenceWithRules('''
+            { ?msg :body "online" . } => { ?msg :onlineTerm :Online . } .
+            { ?msg :body "offline" . } => { ?msg :onlineTerm :Offline . } .
 
-    #         {
-    #         ?msg a :MqttMessage ;
-    #             :topic ( "frontdoorlock" "status" );
-    #             :onlineTerm ?onlineness . } => {
-    #         :frontDoorLockStatus :connectedStatus ?onlineness .
-    #         } .
-    #     ''')
+            {
+            ?msg a :MqttMessage ;
+                :topic ( "frontdoorlock" "status" );
+                :onlineTerm ?onlineness . } => {
+            :frontDoorLockStatus :connectedStatus ?onlineness .
+            } .
+        ''')
 
-    #     out = inf.infer(N3('[] a :MqttMessage ; :body "online" ; :topic ( "frontdoorlock" "status" ) .'))
-    #     self.assertIn((EX['frontDoorLockStatus'], EX['connectedStatus'], EX['Online']), out)
+        out = inf.infer(N3('[] a :MqttMessage ; :body "online" ; :topic ( "frontdoorlock" "status" ) .'))
+        self.assertIn((EX['frontDoorLockStatus'], EX['connectedStatus'], EX['Online']), out)
 
     def testPerformance0(self):
         inf = makeInferenceWithRules('''
@@ -269,28 +269,28 @@
         valueF = cast(Decimal, vlit.toPython())
         self.assertAlmostEqual(float(valueF), 75.02)
 
-    # def testPerformance1(self):
-    #     inf = makeInferenceWithRules('''
-    #         {
-    #           ?msg a :MqttMessage;
-    #             :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" );
-    #             :bodyFloat ?valueC .
-    #           ?valueC math:greaterThan -999 .
-    #           ?valueC room:asFarenheit ?valueF .
-    #         } => {
-    #           :airQualityIndoorTemperature :temperatureF ?valueF .
-    #         } .
-    #     ''')
-    #     out = inf.infer(
-    #         N3('''
-    #         <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ;
-    #             :body "23.9" ;
-    #             :bodyFloat 2.39e+01 ;
-    #             :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ) .
-    #     '''))
-    #     vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF']))
-    #     valueF = cast(Decimal, vlit.toPython())
-    #     self.assertAlmostEqual(float(valueF), 75.02)
+    def testPerformance1(self):
+        inf = makeInferenceWithRules('''
+            {
+              ?msg a :MqttMessage;
+                :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" );
+                :bodyFloat ?valueC .
+              ?valueC math:greaterThan -999 .
+              ?valueC room:asFarenheit ?valueF .
+            } => {
+              :airQualityIndoorTemperature :temperatureF ?valueF .
+            } .
+        ''')
+        out = inf.infer(
+            N3('''
+            <urn:uuid:c6e1d92c-0ee1-11ec-bdbd-2a42c4691e9a> a :MqttMessage ;
+                :body "23.9" ;
+                :bodyFloat 2.39e+01 ;
+                :topic ( "air_quality_indoor" "sensor" "bme280_temperature" "state" ) .
+        '''))
+        vlit = cast(Literal, out.value(EX['airQualityIndoorTemperature'], EX['temperatureF']))
+        valueF = cast(Decimal, vlit.toPython())
+        self.assertAlmostEqual(float(valueF), 75.02)
 
     def testEmitBnodes(self):
         inf = makeInferenceWithRules('''
@@ -302,14 +302,17 @@
         out = inf.infer(N3('''
             :airQualityOutdoor a :AirQualitySensor; :label "air_quality_outdoor" .
         '''))
-        self.assertEqual(out.serialize(format='n3'), b'''@prefix ns1: <http://example.com/> .
+        out.bind('', ROOM)
+        out.bind('ex', EX)
+        self.assertEqual(out.serialize(format='n3'), b'''@prefix : <http://projects.bigasterisk.com/room/> .
+@prefix ex: <http://example.com/> .
 @prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
 @prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
 @prefix xml: <http://www.w3.org/XML/1998/namespace> .
 @prefix xsd: <http://www.w3.org/2001/XMLSchema#> .
 
-[] a ns1:MqttStatementSource ;
-    ns1:mqttTopic ( "air_quality_outdoor" "sensor" "bme280_temperature" "state" ) .
+[] a ex:MqttStatementSource ;
+    ex:mqttTopic ( "air_quality_outdoor" "sensor" "bme280_temperature" "state" ) .
 
 ''')
 
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 13 01:54:49 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Wed Sep 15 23:56:02 2021 -0700
@@ -2,7 +2,7 @@
 import logging
 from decimal import Decimal
 from candidate_binding import CandidateBinding
-from typing import Iterator, List, Optional, Set, Tuple, Type, Union, cast
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
 
 from prometheus_client import Summary
 from rdflib import RDF, Literal, Namespace, URIRef
@@ -61,7 +61,7 @@
 
 class Function:
     """any rule stmt that runs a function (not just a statement match)"""
-    pred: Node
+    pred: URIRef
 
     def __init__(self, stmt: Triple, ruleGraph: Graph):
         self.stmt = stmt
@@ -144,11 +144,15 @@
         f = Literal(sum(self.getNumericOperands(existingBinding)))
         return self.valueInObjectTerm(f)
 
+### registration is done
 
-def functionsFor(pred: Node) -> Iterator[Type[Function]]:
-    for cls in registeredFunctionTypes:
-        if cls.pred == pred:
-            yield cls
+_byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes)
+def functionsFor(pred: URIRef) -> Iterator[Type[Function]]:
+    try:
+        yield _byPred[pred]
+    except KeyError:
+        return
+
 
 def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]:
     usedByFuncs: Set[Triple] = set()  # don't worry about matching these
@@ -157,4 +161,7 @@
             if issubclass(cls, ListFunction):
                 usedByFuncs.update(cls(s, graph).usedStatements())
     return usedByFuncs
-    
+
+
+def rulePredicates() -> Set[URIRef]:
+    return set(c.pred for c in registeredFunctionTypes)
\ No newline at end of file