diff service/mqtt_to_rdf/lhs_evaluation.py @ 1640:4bb6f593ebf3

speedups: abort some rules faster
author drewp@bigasterisk.com
date Wed, 15 Sep 2021 23:56:02 -0700
parents ec3f98d0c1d8
children 3059f31b2dfa
line wrap: on
line diff
--- a/service/mqtt_to_rdf/lhs_evaluation.py	Mon Sep 13 01:54:49 2021 -0700
+++ b/service/mqtt_to_rdf/lhs_evaluation.py	Wed Sep 15 23:56:02 2021 -0700
@@ -2,7 +2,7 @@
 import logging
 from decimal import Decimal
 from candidate_binding import CandidateBinding
-from typing import Iterator, List, Optional, Set, Tuple, Type, Union, cast
+from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast
 
 from prometheus_client import Summary
 from rdflib import RDF, Literal, Namespace, URIRef
@@ -61,7 +61,7 @@
 
 class Function:
     """any rule stmt that runs a function (not just a statement match)"""
-    pred: Node
+    pred: URIRef
 
     def __init__(self, stmt: Triple, ruleGraph: Graph):
         self.stmt = stmt
@@ -144,11 +144,15 @@
         f = Literal(sum(self.getNumericOperands(existingBinding)))
         return self.valueInObjectTerm(f)
 
+### registeration is done
 
-def functionsFor(pred: Node) -> Iterator[Type[Function]]:
-    for cls in registeredFunctionTypes:
-        if cls.pred == pred:
-            yield cls
+_byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes)
+def functionsFor(pred: URIRef) -> Iterator[Type[Function]]:
+    try:
+        yield _byPred[pred]
+    except KeyError:
+        return
+
 
 def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]:
     usedByFuncs: Set[Triple] = set()  # don't worry about matching these
@@ -157,4 +161,7 @@
             if issubclass(cls, ListFunction):
                 usedByFuncs.update(cls(s, graph).usedStatements())
     return usedByFuncs
-    
+
+
+def rulePredicates() -> Set[URIRef]:
+    return set(c.pred for c in registeredFunctionTypes)
\ No newline at end of file