Mercurial > code > home > repos > homeauto
comparison service/mqtt_to_rdf/lhs_evaluation.py @ 1640:4bb6f593ebf3
speedups: abort some rules faster
author | drewp@bigasterisk.com |
---|---|
date | Wed, 15 Sep 2021 23:56:02 -0700 |
parents | ec3f98d0c1d8 |
children | 3059f31b2dfa |
comparison
equal
deleted
inserted
replaced
1639:ae5ca4ba8954 | 1640:4bb6f593ebf3 |
---|---|
1 from dataclasses import dataclass | 1 from dataclasses import dataclass |
2 import logging | 2 import logging |
3 from decimal import Decimal | 3 from decimal import Decimal |
4 from candidate_binding import CandidateBinding | 4 from candidate_binding import CandidateBinding |
5 from typing import Iterator, List, Optional, Set, Tuple, Type, Union, cast | 5 from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast |
6 | 6 |
7 from prometheus_client import Summary | 7 from prometheus_client import Summary |
8 from rdflib import RDF, Literal, Namespace, URIRef | 8 from rdflib import RDF, Literal, Namespace, URIRef |
9 from rdflib.graph import Graph | 9 from rdflib.graph import Graph |
10 from rdflib.term import Node, Variable | 10 from rdflib.term import Node, Variable |
59 return cls | 59 return cls |
60 | 60 |
61 | 61 |
62 class Function: | 62 class Function: |
63 """any rule stmt that runs a function (not just a statement match)""" | 63 """any rule stmt that runs a function (not just a statement match)""" |
64 pred: Node | 64 pred: URIRef |
65 | 65 |
66 def __init__(self, stmt: Triple, ruleGraph: Graph): | 66 def __init__(self, stmt: Triple, ruleGraph: Graph): |
67 self.stmt = stmt | 67 self.stmt = stmt |
68 if stmt[1] != self.pred: | 68 if stmt[1] != self.pred: |
69 raise TypeError | 69 raise TypeError |
142 | 142 |
143 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: | 143 def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: |
144 f = Literal(sum(self.getNumericOperands(existingBinding))) | 144 f = Literal(sum(self.getNumericOperands(existingBinding))) |
145 return self.valueInObjectTerm(f) | 145 return self.valueInObjectTerm(f) |
146 | 146 |
147 ### registeration is done | |
147 | 148 |
148 def functionsFor(pred: Node) -> Iterator[Type[Function]]: | 149 _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes) |
149 for cls in registeredFunctionTypes: | 150 def functionsFor(pred: URIRef) -> Iterator[Type[Function]]: |
150 if cls.pred == pred: | 151 try: |
151 yield cls | 152 yield _byPred[pred] |
153 except KeyError: | |
154 return | |
155 | |
152 | 156 |
153 def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]: | 157 def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]: |
154 usedByFuncs: Set[Triple] = set() # don't worry about matching these | 158 usedByFuncs: Set[Triple] = set() # don't worry about matching these |
155 for s in graph: | 159 for s in graph: |
156 for cls in functionsFor(pred=s[1]): | 160 for cls in functionsFor(pred=s[1]): |
157 if issubclass(cls, ListFunction): | 161 if issubclass(cls, ListFunction): |
158 usedByFuncs.update(cls(s, graph).usedStatements()) | 162 usedByFuncs.update(cls(s, graph).usedStatements()) |
159 return usedByFuncs | 163 return usedByFuncs |
160 | 164 |
165 | |
166 def rulePredicates() -> Set[URIRef]: | |
167 return set(c.pred for c in registeredFunctionTypes) |