Mercurial > code > home > repos > homeauto
view service/mqtt_to_rdf/lhs_evaluation.py @ 1660:31f7dab6a60b
function evaluation uses Chunk lists now and runs fast. Only a few edge cases still broken
author | drewp@bigasterisk.com |
---|---|
date | Sun, 19 Sep 2021 15:39:37 -0700 |
parents | 7ec2483d61b5 |
children | 00a5624d1d14 |
line wrap: on
line source
import logging from decimal import Decimal from typing import (Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast) from prometheus_client import Summary from rdflib import RDF, Literal, Namespace, URIRef from rdflib.term import Node, Variable from candidate_binding import CandidateBinding from inference_types import BindableTerm, Triple from stmt_chunk import Chunk, ChunkedGraph log = logging.getLogger('infer') INDENT = ' ' ROOM = Namespace("http://projects.bigasterisk.com/room/") LOG = Namespace('http://www.w3.org/2000/10/swap/log#') MATH = Namespace('http://www.w3.org/2000/10/swap/math#') def _numericNode(n: Node): if not isinstance(n, Literal): raise TypeError(f'expected Literal, got {n=}') val = n.toPython() if not isinstance(val, (int, float, Decimal)): raise TypeError(f'expected number, got {val=}') return val class Function: """any rule stmt that runs a function (not just a statement match)""" pred: URIRef def __init__(self, chunk: Chunk, ruleGraph: ChunkedGraph): self.chunk = chunk if chunk.predicate != self.pred: raise TypeError self.ruleGraph = ruleGraph def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: raise NotImplementedError def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]: out = [] for op in self.getOperandNodes(existingBinding): out.append(_numericNode(op)) return out def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]: """either any new bindings this function makes (could be 0), or None if it doesn't match""" raise NotImplementedError def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]: objVar = self.chunk.primary[2] if not isinstance(objVar, Variable): raise TypeError(f'expected Variable, got {objVar!r}') return CandidateBinding({cast(BindableTerm, objVar): value}) class SubjectFunction(Function): """function that depends only on the subject term""" def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: if self.chunk.primary[0] is None: raise ValueError(f'expected one operand on {self.chunk}') return [existingBinding.applyTerm(self.chunk.primary[0])] class SubjectObjectFunction(Function): """a filter function that depends on the subject and object terms""" def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: if self.chunk.primary[0] is None or self.chunk.primary[2] is None: raise ValueError(f'expected one operand on each side of {self.chunk}') return [existingBinding.applyTerm(self.chunk.primary[0]), existingBinding.applyTerm(self.chunk.primary[2])] class ListFunction(Function): """function that takes an rdf list as input""" def usedStatements(self) -> Set[Triple]: raise NotImplementedError if self.chunk.subjist is None: raise ValueError(f'expected subject list on {self.chunk}') _, used = _parseList(self.ruleGraph, self.chunk.primary[0]) return used def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]: if self.chunk.subjList is None: raise ValueError(f'expected subject list on {self.chunk}') return [existingBinding.applyTerm(x) for x in self.chunk.subjList] _registeredFunctionTypes: List[Type['Function']] = [] def register(cls: Type['Function']): _registeredFunctionTypes.append(cls) return cls import inference_functions # calls register() on some classes _byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in _registeredFunctionTypes) def functionsFor(pred: URIRef) -> Iterator[Type[Function]]: try: yield _byPred[pred] except KeyError: return # def lhsStmtsUsedByFuncs(graph: ChunkedGraph) -> Set[Chunk]: # usedByFuncs: Set[Triple] = set() # don't worry about matching these # for s in graph: # for cls in functionsFor(pred=s[1]): # usedByFuncs.update(cls(s, graph).usedStatements()) # return usedByFuncs def rulePredicates() -> Set[URIRef]: return set(c.pred for c in _registeredFunctionTypes)