view service/mqtt_to_rdf/inference.py @ 1588:0757fafbfdab

WIP inferencer - partial var and function support
author drewp@bigasterisk.com
date Thu, 02 Sep 2021 01:58:31 -0700
parents 9a3a18c494f9
children 5c1055be3c36
line wrap: on
line source

"""
Copied from `reasoning` on 2021-08-29; the API is probably the same.
This should eventually be factored out into a shared lib/ module.
"""
import itertools
import logging
from dataclasses import dataclass
from decimal import Decimal
from typing import Dict, Iterator, List, Set, Tuple, cast
from urllib.request import OpenerDirector

from prometheus_client import Summary
from rdflib import BNode, Graph, Literal, Namespace
from rdflib.collection import Collection
from rdflib.graph import ConjunctiveGraph
from rdflib.term import Node, Variable

# module logger; messages go to the logger named 'infer'
log = logging.getLogger('infer')

# (subject, predicate, object)
Triple = Tuple[Node, Node, Node]
# (lhs formula, predicate, rhs formula); presumably mirrors the
# `lhs log:implies rhs` statements read in Inference._iterateAllRules
Rule = Tuple[Graph, Node, Graph]

# Prometheus summary metric. NOTE(review): not referenced in this chunk;
# confirm it is used elsewhere.
READ_RULES_CALLS = Summary('read_rules_calls', 'calls')

ROOM = Namespace("http://projects.bigasterisk.com/room/")
# LOG['implies'] marks rule statements; MATH holds the built-in
# functions/filters listed in inferredFuncs and filterFuncs
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')


@dataclass
class _RuleMatch:
    """one way that a rule can match the working set"""
    # NOTE(review): not referenced elsewhere in this chunk;
    # findCandidateBindings yields plain dicts instead.
    vars: Dict[Variable, Node]  # rule variable -> term it is bound to


class Inference:
    """Forward-chaining inference.

    Repeatedly applies `lhs log:implies rhs` rules to a copy of the input
    graph until a round adds no new statements.
    """

    def __init__(self) -> None:
        self.rules = ConjunctiveGraph()

    def setRules(self, g: ConjunctiveGraph):
        """Replace the current rules with `g`.

        Statements in `g` whose predicate is log:implies are treated as rules.
        """
        self.rules = g

    def infer(self, graph: Graph):
        """
        returns new graph of inferred statements.

        `graph` itself is not modified. Its statements are copied into a
        working set, and each rule output is added both to that working set
        (so other rules can build on it) and to the returned graph. Rounds
        repeat until one adds nothing new, or until 100 rounds have run. A
        warning is logged if the last round was still adding statements.
        """
        log.info(f'Begin inference of graph len={len(graph)} with rules len={len(self.rules)}:')

        workingSet = ConjunctiveGraph()
        if isinstance(graph, ConjunctiveGraph):
            workingSet.addN(graph.quads())  # keeps each statement's context
        else:
            for triple in graph:
                workingSet.add(triple)

        implied = ConjunctiveGraph()

        bailout_iterations = 100
        delta = 1
        while delta > 0 and bailout_iterations > 0:
            bailout_iterations -= 1
            delta = -len(implied)
            self._iterateAllRules(workingSet, implied)
            delta += len(implied)
            log.info(f'  this inference round added {delta} more implied stmts')
        if delta > 0:
            # The loop only exits with delta > 0 when the bailout counter ran
            # out. Rules were still producing statements, so the result may be
            # incomplete, or the rules may never converge.
            log.warning(f'inference stopped at the iteration limit; the last round still added {delta} stmts')
        log.info(f'{len(implied)} stmts implied:')
        for st in implied:
            log.info(f'  {st}')
        return implied

    def _iterateAllRules(self, workingSet, implied):
        """Run one round: apply every log:implies statement in self.rules once."""
        for r in self.rules:
            if r[1] == LOG['implies']:
                applyRule(r[0], r[2], workingSet, implied)
            else:
                log.info(f'  {r} not a rule?')


def applyRule(lhs: Graph, rhs: Graph, workingSet, implied):
    """Fire a single rule.

    For every binding under which `lhs` matches `workingSet`, substitute the
    binding into `rhs`. Each resulting statement is added to `workingSet`
    first, so rules applied afterwards can match it, and then to `implied`.
    """
    for bindings in findCandidateBindings(lhs, workingSet):
        log.debug(f' - rule gave {bindings=}')
        produced = withBinding(rhs, bindings)
        for newStmt in produced:
            for target in (workingSet, implied):
                target.add(newStmt)


def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[Variable, Node]]:
    """Yield each variable binding under which `lhs` holds in `workingSet`.

    1. If a variable-free, bnode-free lhs statement is missing from
       workingSet, nothing can match, so return immediately.
    2. Collect candidate terms per variable (findCandidateTermMatches) and
       try every combination of them.
    3. Extend each combination with values computed by inferred functions
       (see inferredFuncBindings).
    4. Keep the combination only if every lhs variable got a value and
       verifyBinding accepts it.
    """
    varsToBind: Set[Variable] = set()
    staticRuleStmts = []
    for ruleStmt in lhs:
        varsInStmt = [v for v in ruleStmt if isinstance(v, Variable)]
        varsToBind.update(varsInStmt)
        if (not varsInStmt  # ok
                and not any(isinstance(t, BNode) for t in ruleStmt)  # approx
           ):
            staticRuleStmts.append(ruleStmt)

    if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
        log.debug('static shortcircuit')
        return

    # the total set of terms each variable could possibly match
    candidateTermMatches: Dict[Variable, Set[Node]] = findCandidateTermMatches(lhs, workingSet)

    orderedVars, orderedValueSets = organize(candidateTermMatches)

    log.debug(f'     {orderedVars=}')
    log.debug(f'{orderedValueSets=}')

    for perm in itertools.product(*orderedValueSets):
        binding: Dict[Variable, Node] = dict(zip(orderedVars, perm))
        log.debug(f'{binding=} but lets look for funcs')
        try:
            for v, val in inferredFuncBindings(lhs, binding):  # loop this until it's done
                log.debug(f'ifb tells us {v}={val}')
                binding[v] = val
        except KeyError as err:
            # A function's input variable has no value because nothing in
            # workingSet matched it. The rule can't fire with this binding.
            # Previously this KeyError escaped and aborted the whole inference.
            log.debug(f'function input {err} unbound; skipping')
            continue
        unbound = varsToBind.difference(binding)
        if unbound:
            # verifyBinding (via withBinding) silently skips statements that
            # contain unbound variables. Without this check the rule would
            # fire on a partial match.
            log.debug(f'unbound vars {unbound}; skipping')
            continue
        if not verifyBinding(lhs, binding, workingSet):  # fix this
            log.debug(f'verify culls')
            continue
        yield binding


def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node]]:
    """Yield (objectVar, value) for each lhs statement `subj func ?objectVar`
    whose predicate is in inferredFuncs.

    A variable subject is resolved through `bindingsBefore`. That dict is
    read lazily, so values the caller stores into it between yields are
    visible to later statements. A KeyError is raised when a needed variable
    is missing from it.
    """
    for subj, pred, obj in lhs:
        if pred in inferredFuncs and isinstance(obj, Variable):
            arg = bindingsBefore[subj] if isinstance(subj, Variable) else subj
            yield obj, inferredFuncObject(arg, pred, lhs, bindingsBefore)


def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[Variable, Set[Node]]:
    """Map each lhs variable to every workingSet term it could stand for.

    A rule statement lines up with a workingSet statement when all of its
    non-variable positions are equal. Its variables then pick up the terms in
    the corresponding positions. These are only candidates: combinations of
    them still have to be verified.
    """
    found: Dict[Variable, Set[Node]] = {}
    for ruleStmt in lhs:
        for dataStmt in workingSet:
            pairs = list(zip(ruleStmt, dataStmt))
            if any(not isinstance(ruleTerm, Variable) and ruleTerm != dataTerm
                   for ruleTerm, dataTerm in pairs):
                continue
            for ruleTerm, dataTerm in pairs:
                if isinstance(ruleTerm, Variable):
                    found.setdefault(ruleTerm, set()).add(dataTerm)
    return found


def withBinding(rhs: Graph, bindings: Dict[Variable, Node]) -> Iterator[Triple]:
    """Yield each statement of `rhs` with its variables replaced from `bindings`.

    A statement containing any variable missing from `bindings` is skipped
    entirely. Results are tuples, so they match the Triple type and are
    hashable. The previous version yielded mutable lists despite the
    annotation, because `cast` does not convert.
    """
    for stmt in rhs:
        try:
            bound = tuple(bindings[t] if isinstance(t, Variable) else t for t in stmt)
        except KeyError:
            # unbound variable; the original note guessed such a stmt is
            # "from another rule that we're not applying right now"
            continue
        yield cast(Triple, bound)


def verifyBinding(lhs: Graph, binding: Dict[Variable, Node], workingSet: Graph) -> bool:
    """Check that `lhs`, with `binding` substituted, holds against workingSet.

    Rules for each substituted statement:
      - predicate in filterFuncs: mathTest must return true;
      - predicate in inferredFuncs: accepted without lookup;
      - anything else: the statement must be in workingSet.
    Statements with a variable missing from `binding` are never produced by
    withBinding, so they are not checked here at all.

    NOTE(review): variable-free statements containing lhs bnodes (e.g. the
    rdf:first/rdf:rest triples of a list used as math:sum's subject) must
    literally be in workingSet. Confirm such rules can ever pass.
    """
    for stmt in withBinding(lhs, binding):
        log.debug(f'lhs verify {stmt}')
        if stmt[1] in filterFuncs:
            if mathTest(*stmt):
                continue
            return False
        if stmt in workingSet or stmt[1] in inferredFuncs:
            continue
        log.debug(f'  ver culls here')
        return False
    return True


# Predicates whose object is computed from the subject by inferredFuncObject.
# verifyBinding does not look these statements up in the working set.
inferredFuncs = {ROOM['asFarenheit'], MATH['sum']}

# Predicates that are boolean tests on bound terms, evaluated by mathTest.
filterFuncs = {MATH['greaterThan']}


def inferredFuncObject(subj, pred, graph, bindings):
    """Compute the object term for an inferred-function statement `subj pred ?obj`.

    room:asFarenheit: `subj` is a numeric Literal, converted with F = C*9/5+32.
        Returns a Decimal Literal.
    math:sum: `subj` is an RDF list node in `graph`. Variable members are
        resolved through `bindings` (KeyError if missing), and the members'
        Python values are summed into a Literal.

    Raises NotImplementedError for any other predicate.
    """
    if pred == ROOM['asFarenheit']:
        celsius = subj.toPython()
        if isinstance(celsius, float):
            # Decimal(float) keeps the binary rounding error: Decimal(21.3) is
            # 21.300000000000000710... str() gives the shortest round-trip
            # text, i.e. the value the number was written as.
            celsius = str(celsius)
        return Literal(Decimal(celsius) * 9 / 5 + 32)
    elif pred == MATH['sum']:
        operands = Collection(graph, subj)
        # shouldn't be redoing this here
        operands = [bindings[o] if isinstance(o, Variable) else o for o in operands]
        log.debug(f' sum {list(operands)}')
        # NOTE(review): sum() raises TypeError if operands mix Decimal and
        # float values (e.g. an asFarenheit result plus a float reading).
        return Literal(sum(op.toPython() for op in operands))

    else:
        raise NotImplementedError(pred)


def mathTest(subj, pred, obj):
    """Evaluate a filter built-in on two bound terms.

    Only math:greaterThan is supported. It compares the terms' Python values.
    Any other predicate raises NotImplementedError.
    """
    left = subj.toPython()
    right = obj.toPython()
    if pred == MATH['greaterThan']:
        return left > right
    raise NotImplementedError(pred)


def organize(candidateTermMatches: Dict[Variable, Set[Node]]) -> Tuple[List[Variable], List[List[Node]]]:
    """Put candidate matches in a deterministic order for itertools.product.

    Returns (orderedVars, orderedValueSets). The variables are sorted, and
    orderedValueSets[i] holds the candidates for orderedVars[i], sorted by
    their string form.
    """
    orderedVars: List[Variable] = sorted(candidateTermMatches)
    orderedValueSets: List[List[Node]] = [
        sorted(candidateTermMatches[var], key=str) for var in orderedVars
    ]
    return orderedVars, orderedValueSets


def someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
    """Return True if any of `staticRuleStmts` is absent from `workingSet`.

    Checks stop at the first missing statement.
    """
    return any(ruleStmt not in workingSet for ruleStmt in staticRuleStmts)