view service/mqtt_to_rdf/inference.py @ 1591:668958454ae2

shuffle some logging
author drewp@bigasterisk.com
date Sat, 04 Sep 2021 15:34:29 -0700
parents 327202020892
children d7b66234064b
line wrap: on
line source

"""
Copied from the `reasoning` service on 2021-08-29; the API is probably the same.
This should eventually be moved out into a shared lib/.
"""
from collections import defaultdict
import itertools
import logging
from dataclasses import dataclass
from decimal import Decimal
from typing import Dict, Iterator, List, Set, Tuple, Union, cast
from urllib.request import OpenerDirector

from prometheus_client import Summary
from rdflib import BNode, Graph, Literal, Namespace, URIRef, RDF
from rdflib.collection import Collection
from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
from rdflib.term import Node, Variable

# Module logger; debug-level messages trace rule matching step by step.
log = logging.getLogger('infer')

# One RDF statement: (subject, predicate, object).
Triple = Tuple[Node, Node, Node]
# One rule statement: (lhs formula, predicate, rhs formula). Inference only
# treats it as a rule when the predicate is LOG['implies'].
Rule = Tuple[Graph, Node, Graph]
# Rule terms that get bound to concrete terms while matching.
BindableTerm = Union[Variable, BNode]

# Prometheus metric. NOTE(review): nothing in this file observes it; it may be used elsewhere.
READ_RULES_CALLS = Summary('read_rules_calls', 'calls')

ROOM = Namespace("http://projects.bigasterisk.com/room/")
# N3 log: vocabulary (log:implies connects a rule's lhs to its rhs).
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
# N3 math: builtins (math:sum and math:greaterThan are handled in this file).
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')


@dataclass
class _RuleMatch:
    """One way that a rule can match the working set.

    NOTE(review): appears unused in this file.
    """
    # rule variable -> the term it is bound to in this match
    vars: Dict[Variable, Node]


class Inference:
    """Forward-chaining reasoner that applies N3 ``log:implies`` rules to a graph."""

    def __init__(self) -> None:
        # Rule statements, each (lhs formula, predicate, rhs formula).
        self.rules = ConjunctiveGraph()

    def setRules(self, g: ConjunctiveGraph):
        """Replace the current rule set with the statements in ``g``."""
        self.rules = g

    def infer(self, graph: Graph, maxIterations: int = 100):
        """
        returns new graph of inferred statements.

        ``graph`` is copied into a working set and is not modified. All rules
        are applied repeatedly until a round adds no new implied statements, or
        until ``maxIterations`` rounds have run (a warning is logged in that case,
        since the result may be incomplete).
        """
        log.info(f'Begin inference of graph len={len(graph)} with rules len={len(self.rules)}:')

        # everything that is true: the input graph, plus every rule conclusion we can make
        workingSet = graphCopy(graph)

        # just the statements that came from rule RHS's.
        implied = ConjunctiveGraph()

        bailout_iterations = maxIterations
        delta = 1
        while delta > 0 and bailout_iterations > 0:
            log.debug(f'  * iteration ({bailout_iterations} left)')
            bailout_iterations -= 1
            # delta = number of implied statements added by this round
            delta = -len(implied)
            self._iterateAllRules(workingSet, implied)
            delta += len(implied)
            log.info(f'  this inference round added {delta} more implied stmts')
        if delta > 0:
            # The loop stopped on the iteration limit while rules were still producing output.
            log.warning(f'inference stopped after {maxIterations} iterations; results may be incomplete')
        log.info(f'    {len(implied)} stmts implied:')
        for st in implied:
            log.info(f'        {st}')
        return implied

    def _iterateAllRules(self, workingSet: Graph, implied: Graph):
        """Apply every rule once, adding conclusions to workingSet and implied."""
        # graphDump serializes whole graphs, so only build these messages when
        # they will actually be emitted.
        debugging = log.isEnabledFor(logging.DEBUG)
        for i, r in enumerate(self.rules):
            if debugging:
                log.debug(f'      workingSet: {graphDump(workingSet)}')
                log.debug(f'      - applying rule {i}')
                log.debug(f'        lhs: {graphDump(r[0])}')
                log.debug(f'        rhs: {graphDump(r[2])}')
            if r[1] == LOG['implies']:
                applyRule(r[0], r[2], workingSet, implied)
            else:
                log.info(f'   {r} not a rule?')


def graphCopy(src: Graph) -> Graph:
    """Return a new graph holding the statements of ``src``.

    A ConjunctiveGraph is copied quad by quad, so statements keep their
    contexts. Any other Graph is copied triple by triple into a plain Graph.
    """
    if not isinstance(src, ConjunctiveGraph):
        dup = Graph()
        for stmt in src:
            dup.add(stmt)
        return dup
    dup = ConjunctiveGraph()
    dup.addN(src.quads())
    return dup


def graphDump(g: Union[Graph, List[Triple]]):
    """Render statements as a single line of N3 text, for debug logging.

    Accepts a Graph or a plain list of triples. The statements are copied into
    a scratch Graph, so the prefix bindings made here never alter the caller's
    graph. ``@prefix`` declarations are dropped from the output.
    """
    scratch = Graph()
    for stmt in g:
        scratch.add(stmt)
    scratch.bind('', ROOM)
    scratch.bind('ex', Namespace('http://example.com/'))
    serialized = scratch.serialize(format='n3')
    # rdflib < 6 returns bytes here; rdflib >= 6 returns str.
    if isinstance(serialized, bytes):
        serialized = serialized.decode('utf8')
    lines = [line for line in serialized.splitlines() if not line.startswith('@prefix')]
    return ' '.join(lines)


def applyRule(lhs: Graph, rhs: Graph, workingSet: Graph, implied: Graph):
    """Fire one rule.

    For every binding of ``lhs`` that findCandidateBindings finds against
    ``workingSet``, the bound lhs statements are added to workingSet. The bound
    rhs statements are added to both workingSet and ``implied``.
    """
    for bindings in findCandidateBindings(lhs, workingSet):
        log.debug(f'        rule gave {bindings=}')
        boundPremises = withBinding(lhs, bindings)
        for premise in boundPremises:
            workingSet.add(premise)
        boundConclusions = withBinding(rhs, bindings)
        for conclusion in boundConclusions:
            workingSet.add(conclusion)
            implied.add(conclusion)


def findCandidateBindings(lhs: Graph, workingSet: Graph) -> Iterator[Dict[BindableTerm, Node]]:
    """bindings that fit the LHS of a rule, using statements from workingSet and functions
    from LHS

    Yields each binding that maps every Variable/BNode in ``lhs`` to a term and
    that verifyBinding accepts. Yields nothing if a variable-free lhs statement
    is missing from workingSet.
    """
    varsToBind: Set[BindableTerm] = set()
    staticRuleStmts = Graph()
    for ruleStmt in lhs:
        varsInStmt = [v for v in ruleStmt if isinstance(v, (Variable, BNode))]
        varsToBind.update(varsInStmt)
        if not varsInStmt:
            # BNodes count as bindable above, so a static statement has only fixed terms.
            staticRuleStmts.add(ruleStmt)

    log.debug(f'        varsToBind: {sorted(varsToBind)}')

    if someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
        if log.isEnabledFor(logging.DEBUG):
            # graphDump serializes the graph; only pay for it when debug logging is on.
            log.debug(f'    someStaticStmtDoesntMatch: {graphDump(staticRuleStmts)}')
        return

    # the total set of terms each variable could possibly match
    candidateTermMatches: Dict[BindableTerm, Set[Node]] = findCandidateTermMatches(lhs, workingSet)

    orderedVars, orderedValueSets = organize(candidateTermMatches)

    log.debug(f'        candidate terms:')
    log.debug(f'            {orderedVars=}')
    log.debug(f'            {orderedValueSets=}')

    # Try every combination of candidate terms (cartesian product over the variables).
    for perm in itertools.product(*orderedValueSets):
        binding: Dict[BindableTerm, Node] = dict(zip(orderedVars, perm))
        log.debug('')
        log.debug(f'            ** trying {binding=}')
        usedByFuncs = Graph()
        # inferredFuncBindings reads `binding` while we add to it, so one function's
        # result can feed a later function in the same pass. NOTE(review): only a
        # single pass is made; chained functions resolve only if lhs iteration
        # order happens to reach the producer first.
        for v, val, used in inferredFuncBindings(lhs, binding):
            log.debug(f'            inferredFuncBindings tells us {v}={val}')
            binding[v] = val
            usedByFuncs += used
        if len(binding) != len(varsToBind):
            log.debug(f'                binding is incomplete, needs {varsToBind}')
            continue
        if not verifyBinding(lhs, binding, workingSet, usedByFuncs):
            log.debug(f'            this binding did not verify')
            continue
        yield binding


def inferredFuncBindings(lhs: Graph, bindingsBefore) -> Iterator[Tuple[Variable, Node, Graph]]:
    """Evaluate the inferred-function statements of ``lhs``.

    A statement is evaluated when its predicate is in ``inferredFuncs`` and its
    object is a Variable. For each one, this yields (that variable, the computed
    value, graph of statements the function used).

    When the subject is a Variable, its value is taken from ``bindingsBefore``.
    If that variable has no binding, the statement is skipped rather than raising
    KeyError. Its output variable then stays unbound, and findCandidateBindings
    discards the binding as incomplete.
    """
    for stmt in lhs:
        pred = stmt[1]
        if pred not in inferredFuncs:
            continue
        var = stmt[2]
        if not isinstance(var, Variable):
            continue

        x = stmt[0]
        if isinstance(x, Variable):
            if x not in bindingsBefore:
                log.debug(f'            cannot evaluate {pred}: input {x} is unbound')
                continue
            x = bindingsBefore[x]

        resultObject, usedByFunc = inferredFuncObject(x, pred, lhs, bindingsBefore)

        yield var, resultObject, usedByFunc


def findCandidateTermMatches(lhs: Graph, workingSet: Graph) -> Dict[BindableTerm, Set[Node]]:
    """For each Variable/BNode in ``lhs``, collect the terms it could possibly be bound to.

    Variables are matched positionally. If every fixed term of an lhs statement
    equals the corresponding term of a workingSet statement, each Variable picks
    up the term at its position. BNodes act as wildcards during that match. Each
    BNode then gets every URIRef/BNode appearing as a subject or object in
    workingSet or lhs.

    Variables that never match anything are absent from the result.
    """
    candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
    # graphDump is expensive and this loop is |lhs| x |workingSet|, so decide
    # once whether the per-pair debug line is wanted.
    debugging = log.isEnabledFor(logging.DEBUG)

    # Collect lhs BNodes directly from lhs, so they get candidates even when
    # workingSet is empty or an early term mismatch would have skipped them.
    lhsBnodes: Set[BNode] = {t for stmt in lhs for t in stmt if isinstance(t, BNode)}

    for lhsStmt in lhs:
        for trueStmt in workingSet:
            if debugging:
                log.debug(f'            lhsStmt={graphDump([lhsStmt])} trueStmt={graphDump([trueStmt])}')
            bindingsFromStatement: Dict[Variable, Set[Node]] = {}
            for lhsTerm, trueTerm in zip(lhsStmt, trueStmt):
                if isinstance(lhsTerm, BNode):
                    continue  # wildcard here; BNode candidates are assigned below
                if isinstance(lhsTerm, Variable):
                    bindingsFromStatement.setdefault(lhsTerm, set()).add(trueTerm)
                elif lhsTerm != trueTerm:
                    break
            else:
                # every fixed term matched, so this statement's variable values are candidates
                for v, vals in bindingsFromStatement.items():
                    candidateTermMatches[v].update(vals)

    for trueStmt in itertools.chain(workingSet, lhs):
        for b in lhsBnodes:
            for t in (trueStmt[0], trueStmt[2]):
                if isinstance(t, (URIRef, BNode)):
                    candidateTermMatches[b].add(t)
    return candidateTermMatches


def withBinding(toBind: Graph, bindings: Dict[BindableTerm, Node], includeStaticStmts=True) -> Iterator[Triple]:
    """Yield each statement of ``toBind`` with its Variables/BNodes replaced from ``bindings``.

    Each result is a 3-element list (cast to Triple). When includeStaticStmts
    is False, statements with no Variable/BNode are skipped. A Variable/BNode
    missing from ``bindings`` raises KeyError.
    """
    bindable = (Variable, BNode)
    for spo in toBind:
        needsBinding = any(isinstance(term, bindable) for term in spo)
        if not (includeStaticStmts or needsBinding):
            continue
        yield cast(Triple, [bindings[term] if isinstance(term, bindable) else term for term in spo])


def verifyBinding(lhs: Graph, binding: Dict[BindableTerm, Node], workingSet: Graph, usedByFuncs: Graph) -> bool:
    """Can this lhs be true all at once in workingSet? Does it match with these bindings?

    Each lhs statement, with ``binding`` substituted in, must be one of:
      - a filter-function statement (predicate in ``filterFuncs``) that mathTest says is true,
      - a statement used by an inferred function (``usedByFuncs``, also bound with ``binding``),
      - a statement already in ``workingSet``.
    Returns False at the first statement that is none of these.

    Raises KeyError if ``binding`` lacks a Variable/BNode used in lhs or usedByFuncs.
    """
    boundLhs = list(withBinding(lhs, binding))
    boundUsedByFuncs = list(withBinding(usedByFuncs, binding))
    if log.isEnabledFor(logging.DEBUG):
        log.debug(f'                verify all bindings against this lhs:')
        for stmt in boundLhs:
            log.debug(f'                    {stmt}')

        log.debug(f'                and against this workingSet:')
        for stmt in workingSet:
            log.debug(f'                    {stmt}')

        log.debug(f'                ignoring these usedByFuncs:')
        for stmt in boundUsedByFuncs:
            log.debug(f'                    {stmt}')
    # Every bound lhs statement is checked, including variable-free ones.
    # findCandidateBindings has usually already checked those against
    # workingSet, so re-checking them is redundant but cheap.
    for stmt in boundLhs:
        log.debug(f'                check for {stmt}')

        if stmt[1] in filterFuncs:
            if not mathTest(*stmt):
                log.debug(f'                    binding was invalid because {stmt} is not true')
                return False
        elif stmt in boundUsedByFuncs or stmt in workingSet:
            continue
        else:
            log.debug(f'                    binding was invalid because {stmt} cannot be true')
            return False
    return True


# Predicates whose object is computed from the subject by inferredFuncObject;
# inferredFuncBindings binds the object variable to that result.
inferredFuncs = {
    ROOM['asFarenheit'],
    MATH['sum'],
}
# Predicates evaluated as boolean tests by mathTest; verifyBinding rejects a
# binding whose test is false.
filterFuncs = {
    MATH['greaterThan'],
}


def isStatic(spo: Triple):
    """True when no term of ``spo`` is a Variable or BNode (nothing left to bind)."""
    return not any(isinstance(term, (Variable, BNode)) for term in spo)


def inferredFuncObject(subj, pred, graph, bindings) -> Tuple[Literal, Graph]:
    """return result from like `(1 2) math:sum ?out .` plus a graph of all the
    statements involved in that function rule (including the bound answer)

    subj: the function's input term (inferredFuncBindings has already replaced
        a Variable subject with its bound value).
    pred: one of ``inferredFuncs``.
    graph: the graph the math:sum operand list is read from (the rule lhs, as
        called from inferredFuncBindings).
    bindings: used to replace Variable list members with their bound terms.

    Raises NotImplementedError for any other predicate.
    """
    used = Graph()
    if pred == ROOM['asFarenheit']:
        celsius = subj.toPython()
        if isinstance(celsius, float):
            # Decimal(float) keeps the float's full binary expansion
            # (Decimal(21.1) == 21.10000000000000142...). Converting via the
            # shortest round-trip text makes 21.1 convert as exactly 21.1.
            celsius = repr(celsius)
        obj = Literal(Decimal(celsius) * 9 / 5 + 32)
    elif pred == MATH['sum']:
        operands, operandsStmts = parseList(graph, subj)
        # shouldn't be redoing this here
        operands = [bindings[o] if isinstance(o, Variable) else o for o in operands]
        log.debug(f'                sum {[op.toPython() for op in operands]}')
        used += operandsStmts
        obj = Literal(sum(op.toPython() for op in operands))
    else:
        raise NotImplementedError(pred)

    used.add((subj, pred, obj))
    return obj, used


def parseList(graph, subj) -> Tuple[List[Node], Set[Triple]]:
    """Read the RDF collection that starts at ``subj`` in ``graph``.

    Returns (the members in order, the rdf:first/rdf:rest statements that were
    read). ``rdf:nil`` is the empty list and gives ``([], set())``.

    Raises ValueError if a list node lacks rdf:first or rdf:rest in ``graph``
    (e.g. ``subj`` is not a list head there), or if the list is cyclic.
    """
    out: List[Node] = []
    used: Set[Triple] = set()
    visited = set()
    cur = subj
    while cur != RDF.nil:
        if cur in visited:
            raise ValueError(f'cyclic RDF list at {cur!r}')
        visited.add(cur)
        first = graph.value(cur, RDF.first)
        rest = graph.value(cur, RDF.rest)
        # graph.value(None, ...) treats None as a wildcard, so a missing link
        # must stop the walk instead of continuing from None.
        if first is None or rest is None:
            raise ValueError(f'{cur!r} is not a well-formed RDF list node')
        out.append(first)
        used.add((cur, RDF.first, first))
        used.add((cur, RDF.rest, rest))
        cur = rest
    return out, used


def mathTest(subj, pred, obj):
    """Evaluate the filter builtin ``subj pred obj`` on the terms' Python values.

    Only math:greaterThan is supported; any other predicate raises
    NotImplementedError.
    """
    left, right = subj.toPython(), obj.toPython()
    if pred != MATH['greaterThan']:
        raise NotImplementedError(pred)
    return left > right


def organize(candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Tuple[List[BindableTerm], List[List[Node]]]:
    """Put candidate terms in a deterministic order for itertools.product.

    Returns (the variables, sorted; and, for each variable in that order, its
    candidate terms sorted by their string form).
    """
    orderedVars: List[BindableTerm] = sorted(candidateTermMatches)
    orderedValueSets: List[List[Node]] = [
        sorted(candidateTermMatches[var], key=str) for var in orderedVars
    ]
    return orderedVars, orderedValueSets


def someStaticStmtDoesntMatch(staticRuleStmts, workingSet):
    """True if any statement in staticRuleStmts is absent from workingSet.

    Stops at the first absent statement and logs it at debug level.
    """
    firstMissing = next((stmt for stmt in staticRuleStmts if stmt not in workingSet), None)
    if firstMissing is None:
        return False
    log.debug(f'            {firstMissing} not in working set- skip rule')
    return True