view service/mqtt_to_rdf/lhs_evaluation.py @ 1648:3059f31b2dfa

more performance work
author drewp@bigasterisk.com
date Fri, 17 Sep 2021 11:10:18 -0700
parents 4bb6f593ebf3
children 20474ad4968e
line wrap: on
line source

from dataclasses import dataclass
import logging
from decimal import Decimal
from candidate_binding import CandidateBinding
from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, Union, cast

from prometheus_client import Summary
from rdflib import RDF, Literal, Namespace, URIRef
from rdflib.graph import Graph
from rdflib.term import BNode, Node, Variable

from inference_types import BindableTerm, Triple

log = logging.getLogger('infer')

INDENT = '    '

ROOM = Namespace("http://projects.bigasterisk.com/room/")
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')


def numericNode(n: Node):
    if not isinstance(n, Literal):
        raise TypeError(f'expected Literal, got {n=}')
    val = n.toPython()
    if not isinstance(val, (int, float, Decimal)):
        raise TypeError(f'expected number, got {val=}')
    return val


def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
    """"Do like Collection(g, subj) but also return all the 
    triples that are involved in the list"""
    out = []
    used = set()
    cur = subj
    while cur != RDF.nil:
        elem = graph.value(cur, RDF.first)
        if elem is None:
            raise ValueError('bad list')
        out.append(elem)
        used.add((cur, RDF.first, out[-1]))

        next = graph.value(cur, RDF.rest)
        if next is None:
            raise ValueError('bad list')
        used.add((cur, RDF.rest, next))

        cur = next
    return out, used


registeredFunctionTypes: List[Type['Function']] = []


def register(cls: Type['Function']):
    registeredFunctionTypes.append(cls)
    return cls


class Function:
    """any rule stmt that runs a function (not just a statement match)"""
    pred: URIRef

    def __init__(self, stmt: Triple, ruleGraph: Graph):
        self.stmt = stmt
        if stmt[1] != self.pred:
            raise TypeError
        self.ruleGraph = ruleGraph

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        raise NotImplementedError

    def getNumericOperands(self, existingBinding: CandidateBinding) -> List[Union[int, float, Decimal]]:
        out = []
        for op in self.getOperandNodes(existingBinding):
            out.append(numericNode(op))

        return out

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        """either any new bindings this function makes (could be 0), or None if it doesn't match"""
        raise NotImplementedError

    def valueInObjectTerm(self, value: Node) -> Optional[CandidateBinding]:
        objVar = self.stmt[2]
        if not isinstance(objVar, Variable):
            raise TypeError(f'expected Variable, got {objVar!r}')
        return CandidateBinding({cast(BindableTerm, objVar): value})

    def usedStatements(self) -> Set[Triple]:
        '''stmts in self.graph (not including self.stmt, oddly) that are part of
        this function setup and aren't to be matched literally'''
        return set()
    

class SubjectFunction(Function):
    """function that depends only on the subject term"""

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        return [existingBinding.applyTerm(self.stmt[0])]


class SubjectObjectFunction(Function):
    """a filter function that depends on the subject and object terms"""

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        return [existingBinding.applyTerm(self.stmt[0]), existingBinding.applyTerm(self.stmt[2])]


class ListFunction(Function):
    """function that takes an rdf list as input"""

    def usedStatements(self) -> Set[Triple]:
        _, used = parseList(self.ruleGraph, self.stmt[0])
        return used

    def getOperandNodes(self, existingBinding: CandidateBinding) -> List[Node]:
        operands, _ = parseList(self.ruleGraph, self.stmt[0])
        return [existingBinding.applyTerm(x) for x in operands]


@register
class Gt(SubjectObjectFunction):
    pred = MATH['greaterThan']

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        [x, y] = self.getNumericOperands(existingBinding)
        if x > y:
            return CandidateBinding({})  # no new values; just allow matching to keep going


@register
class AsFarenheit(SubjectFunction):
    pred = ROOM['asFarenheit']

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        [x] = self.getNumericOperands(existingBinding)
        f = cast(Literal, Literal(Decimal(x) * 9 / 5 + 32))
        return self.valueInObjectTerm(f)


@register
class Sum(ListFunction):
    pred = MATH['sum']

    def bind(self, existingBinding: CandidateBinding) -> Optional[CandidateBinding]:
        f = Literal(sum(self.getNumericOperands(existingBinding)))
        return self.valueInObjectTerm(f)

### registration is done

_byPred: Dict[URIRef, Type[Function]] = dict((cls.pred, cls) for cls in registeredFunctionTypes)
def functionsFor(pred: URIRef) -> Iterator[Type[Function]]:
    try:
        yield _byPred[pred]
    except KeyError:
        return


def lhsStmtsUsedByFuncs(graph: Graph) -> Set[Triple]:
    usedByFuncs: Set[Triple] = set()  # don't worry about matching these
    for s in graph:
        for cls in functionsFor(pred=s[1]):
            usedByFuncs.update(cls(s, graph).usedStatements())
    return usedByFuncs


def rulePredicates() -> Set[URIRef]:
    return set(c.pred for c in registeredFunctionTypes)