
"""
copied from reasoning 2021-08-29. probably same api. should
be able to lib/ this out
"""
import itertools
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast

from prometheus_client import Histogram, Summary
from rdflib import BNode, Graph, Namespace
from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
from rdflib.term import Literal, Node, Variable

from candidate_binding import CandidateBinding
from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
from lhs_evaluation import Decimal, Evaluation, numericNode

log = logging.getLogger('infer')
INDENT = '    '

INFER_CALLS = Summary('inference_infer_calls', 'calls')
INFER_GRAPH_SIZE = Histogram('inference_graph_size', 'statements', buckets=[2**x for x in range(2, 20, 2)])

ROOM = Namespace("http://projects.bigasterisk.com/room/")
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')


def stmtTemplate(stmt: Triple) -> Tuple[Optional[Node], Optional[Node], Optional[Node]]:
    return (
        None if isinstance(stmt[0], (Variable, BNode)) else stmt[0],
        None if isinstance(stmt[1], (Variable, BNode)) else stmt[1],
        None if isinstance(stmt[2], (Variable, BNode)) else stmt[2],
    )
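# Illustrative example (these particular names aren't required by this module):
# stmtTemplate((Variable('x'), ROOM['temperatureF'], Literal(70))) returns
# (None, ROOM['temperatureF'], Literal(70)), the wildcard pattern form that
# Graph.triples() accepts.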


class NoOptions(ValueError):
    """stmtlooper has no possibilites to add to the binding; the whole rule must therefore not apply"""


class Inconsistent(ValueError):
    """adding this stmt would be inconsistent with an existing binding"""


_stmtLooperShortId = itertools.count()


@dataclass
class StmtLooper:
    """given one LHS stmt, iterate through the possible matches for it,
    returning what bindings they would imply. Only distinct bindings are
    returned. The bindings build on any `prev` StmtLooper's results.

    This iterator is restartable."""
    lhsStmt: Triple
    prev: Optional['StmtLooper']
    workingSet: ReadOnlyWorkingSet
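    # Rough sketch of how loopers chain (names below are illustrative): for an
    # LHS with stmts s1 and s2, findCandidateBindings builds
    #   l1 = StmtLooper(s1, None, workingSet)
    #   l2 = StmtLooper(s2, l1, workingSet)
    # and l2's bindings always build on whatever l1 is currently bound to.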

    def __repr__(self):
        return f'StmtLooper{self._shortId}({graphDump([self.lhsStmt])} {"<pastEnd>" if self.pastEnd() else ""})'

    def __post_init__(self):
        self._shortId = next(_stmtLooperShortId)
        self._myWorkingSetMatches = self._myMatches(self.workingSet)

        self._current = CandidateBinding({})
        self._pastEnd = False
        self._seenBindings: List[Dict[BindableTerm, Node]] = []
        self.restart()

    def _myMatches(self, g: Graph) -> List[Triple]:
        template = stmtTemplate(self.lhsStmt)

        stmts = sorted(cast(Iterator[Triple], list(g.triples(template))))
        # plus new lhs possibilities...
        # log.debug(f'{INDENT*6} {self} find {len(stmts)=} in {len(self.workingSet)=}')

        return stmts

    def _prevBindings(self) -> Dict[BindableTerm, Node]:
        if not self.prev or self.prev.pastEnd():
            return {}

        return self.prev.currentBinding().binding

    def advance(self):
        """update to a new set of bindings we haven't seen (since last restart), or go into pastEnd mode"""
        log.debug(f'{INDENT*6} {self} mines {len(self._myWorkingSetMatches)} matching statements')
        for i, stmt in enumerate(self._myWorkingSetMatches):
            try:
                outBinding = self._totalBindingIfThisStmtWereTrue(stmt)
            except Inconsistent:
                log.debug(f'{INDENT*7} {self} - {stmt} would be inconsistent with prev bindings')
                continue

            log.debug(f'{INDENT*6} {outBinding=} {self._seenBindings=}')
            if outBinding.binding not in self._seenBindings:
                self._seenBindings.append(outBinding.binding.copy())
                self._current = outBinding
                log.debug(f'{INDENT*7} new binding from {self} -> {outBinding}')
                return
            log.debug(f'{INDENT*7} {self} already saw this binding')

        log.debug(f'{INDENT*6} {self} mines rules')

        if self.lhsStmt[1] == ROOM['asFarenheit']:
            pb: Dict[BindableTerm, Node] = self._prevBindings()
            log.debug(f'{INDENT*6} {self} consider ?x faren ?y where ?x={self.lhsStmt[0]} and {pb=}')

            if self.lhsStmt[0] in pb:
                operands = [pb[cast(BindableTerm, self.lhsStmt[0])]]
                # convert the bound numeric value to Fahrenheit (x * 9/5 + 32), e.g. 20 -> 68
                f = Literal(Decimal(numericNode(operands[0])) * 9 / 5 + 32)
                objVar = self.lhsStmt[2]
                if not isinstance(objVar, Variable):
                    raise TypeError(f'expected Variable, got {objVar!r}')
                newBindings = {cast(BindableTerm, objVar): cast(Node, f)}
                self._current.addNewBindings(CandidateBinding(newBindings))
                if newBindings not in self._seenBindings:
                    self._seenBindings.append(newBindings)
                    self._current = CandidateBinding(newBindings)
                    # return with this fresh binding instead of falling through
                    # to the pastEnd marking below
                    return

        log.debug(f'{INDENT*6} {self} is past end')
        self._pastEnd = True

    def _totalBindingIfThisStmtWereTrue(self, newStmt: Triple) -> CandidateBinding:
        outBinding = self._prevBindings().copy()
        for rt, ct in zip(self.lhsStmt, newStmt):
            if isinstance(rt, (Variable, BNode)):
                if rt in outBinding and outBinding[rt] != ct:
                    raise Inconsistent()
                outBinding[rt] = ct
        return CandidateBinding(outBinding)
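    # Illustrative example: with lhsStmt (?x :temp ?y) and newStmt
    # (:sensor1 :temp 70), this returns the prev looper's bindings plus
    # {?x: :sensor1, ?y: 70}; if prev had already bound ?x to something else,
    # Inconsistent is raised instead.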

    def currentBinding(self) -> CandidateBinding:
        if self.pastEnd():
            raise NotImplementedError()
        return self._current

    def newLhsStmts(self) -> List[Triple]:
        """under the curent bindings, what new stmts beyond workingSet are also true? includes all `prev`"""
        return []

    def pastEnd(self) -> bool:
        return self._pastEnd

    def restart(self):
        self._pastEnd = False
        self._seenBindings = []
        self.advance()
        if self.pastEnd():
            raise NoOptions()


@dataclass
class Lhs:
    graph: Graph

    def __post_init__(self):
        # do precomputation in here that's not specific to the workingSet
        # self.staticRuleStmts = Graph()
        # self.nonStaticRuleStmts = Graph()

        # self.lhsBindables: Set[BindableTerm] = set()
        # self.lhsBnodes: Set[BNode] = set()
        # for ruleStmt in self.graph:
        #     varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))]
        #     self.lhsBindables.update(varsAndBnodesInStmt)
        #     self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode))
        #     if not varsAndBnodesInStmt:
        #         self.staticRuleStmts.add(ruleStmt)
        #     else:
        #         self.nonStaticRuleStmts.add(ruleStmt)

        # self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts)

        self.evaluations = list(Evaluation.findEvals(self.graph))

    def __repr__(self):
        return f"Lhs({graphDump(self.graph)})"

    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
        """bindings that fit the LHS of a rule, using statements from workingSet and functions
        from LHS"""
        log.debug(f'{INDENT*4} build new StmtLooper stack')

        stmtStack: List[StmtLooper] = []
        try:
            prev: Optional[StmtLooper] = None
            for s in sorted(self.graph):  # order of this matters! :(
                stmtStack.append(StmtLooper(s, prev, knownTrue))
                prev = stmtStack[-1]
        except NoOptions:
            log.debug(f'{INDENT*5} start up with no options; 0 bindings')
            return
        self._debugStmtStack('initial odometer', stmtStack)


        if any(ring.pastEnd() for ring in stmtStack):
            log.debug(f'{INDENT*5} some rings started at pastEnd {stmtStack}')

            raise NoOptions()
        sl = stmtStack[-1]
        iterCount = 0
        while True:
            iterCount += 1
            if iterCount > 10:
                raise ValueError('stuck')

            log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')

            yield BoundLhs(self, sl.currentBinding())

            self._debugStmtStack('odometer', stmtStack)

            done = self._advanceAll(stmtStack)

            self._debugStmtStack(f'odometer after ({done=})', stmtStack)

            log.debug(f'{INDENT*4} ^^ findCandBindings iteration done')
            if done:
                break

    def _debugStmtStack(self, label, stmtStack):
        log.debug(f'{INDENT*5} {label}:')
        for l in stmtStack:
            log.debug(f'{INDENT*6} {l} curbind={l.currentBinding() if not l.pastEnd() else "<end>"}')

    def _advanceAll(self, stmtStack: List[StmtLooper]) -> bool:
        carry = True  # 1st elem always must advance
        for i, ring in enumerate(stmtStack):
            # unlike normal odometer, advancing any earlier ring could invalidate later ones
            if carry:
                log.debug(f'{INDENT*5} advanceAll [{i}] {ring} carry/advance')
                ring.advance()
                carry = False
            if ring.pastEnd():
                if ring is stmtStack[-1]:
                    log.debug(f'{INDENT*5} advanceAll [{i}] {ring} says we done')
                    return True
                log.debug(f'{INDENT*5} advanceAll [{i}] {ring} restart')
                ring.restart()
                carry = True
        return False
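    # Odometer sketch (illustrative): with a stack [l0, l1, l2], each call
    # advances l0; when a looper runs past its end it is restarted and the
    # carry advances the next one. Returning True means the last looper ran
    # past its end, i.e. all combinations have been visited.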

    def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool:
        # bug: see TestSelfFulfillingRule.test3 for a case where this rule's
        # static stmt is matched by a non-static stmt in the rule itself
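        # note: self.staticRuleStmts is only built by the precomputation that is
        # currently commented out in __post_init__, so this helper can't run
        # as-is in this revision.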
        for ruleStmt in self.staticRuleStmts:
            if ruleStmt not in knownTrue:
                log.debug(f'{INDENT*3} {ruleStmt} not in working set - skip rule')
                return False
        return True

    def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']:
        """this yields at least the working bindings, and possibly others"""
        candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet)
        for bindRow in self._product(candidateTermMatches):
            try:
                yield BoundLhs(self, bindRow)
            except EvaluationFailed:
                stats['permCountFailingEval'] += 1

    def _allCandidateTermMatches(self, workingSet: ReadOnlyWorkingSet) -> Dict[BindableTerm, Set[Node]]:
        """the total set of terms each variable could possibly match"""

        candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
        for lhsStmt in self.graph:
            log.debug(f'{INDENT*4} possibles for this lhs stmt: {lhsStmt}')
            for i, trueStmt in enumerate(workingSet):
                # log.debug(f'{INDENT*5} consider this true stmt ({i}): {trueStmt}')

                for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt):
                    candidateTermMatches[v].update(vals)

        return candidateTermMatches

    def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[Variable, Set[Node]]]:
        """if these stmts match otherwise, what BNode or Variable mappings do we learn?
        
        e.g. stmt1=(?x B ?y) and stmt2=(A B C), then we yield (?x, {A}) and (?y, {C})
        or   stmt1=(_:x B C) and stmt2=(A B C), then we yield (_:x, {A})
        or   stmt1=(?x B C)  and stmt2=(A B D), then we yield nothing
        """
        bindingsFromStatement = {}
        for term1, term2 in zip(stmt1, stmt2):
            if isinstance(term1, (BNode, Variable)):
                bindingsFromStatement.setdefault(term1, set()).add(term2)
            elif term1 != term2:
                break
        else:
            for v, vals in bindingsFromStatement.items():
                log.debug(f'{INDENT*5} {v=} {vals=}')
                yield v, vals

    def _product(self, candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Iterator[CandidateBinding]:
        orderedVars, orderedValueSets = _organize(candidateTermMatches)
        self._logCandidates(orderedVars, orderedValueSets)
        log.debug(f'{INDENT*3} trying all permutations:')
        if not orderedValueSets:
            yield CandidateBinding({})
            return

        if not all(orderedValueSets):
            # some var or bnode has no options at all
            return
        rings: List[Iterator[Node]] = [itertools.cycle(valSet) for valSet in orderedValueSets]
        currentSet: List[Node] = [next(ring) for ring in rings]
        starts = [valSet[-1] for valSet in orderedValueSets]
        while True:
            for col, curr in enumerate(currentSet):
                currentSet[col] = next(rings[col])
                log.debug(f'{INDENT*4} currentSet: {repr(currentSet)}')
                yield CandidateBinding(dict(zip(orderedVars, currentSet)))
                if curr is not starts[col]:
                    break
                if col == len(orderedValueSets) - 1:
                    return
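    # Emission sketch (illustrative): each column is an itertools.cycle over one
    # variable's candidate values; advancing a column yields a binding, and the
    # carry moves to the next column once a cycle wraps back to its start, so
    # every combination is emitted (a few may repeat at the carry points).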

    def _logCandidates(self, orderedVars, orderedValueSets):
        if not log.isEnabledFor(logging.DEBUG):
            return
        log.debug(f'{INDENT*3} resulting candidate terms:')
        for v, vals in zip(orderedVars, orderedValueSets):
            log.debug(f'{INDENT*4} {v!r} could be:')
            for val in vals:
                log.debug(f'{INDENT*5}{val!r}')


@dataclass
class BoundLhs:
    lhs: Lhs
    binding: CandidateBinding

    def __post_init__(self):
        self.usedByFuncs = Graph()
        # self._applyFunctions()

    def lhsStmtsWithoutEvals(self):
        for stmt in self.lhs.graph:
            if stmt in self.usedByFuncs:
                continue
            yield stmt

    def _applyFunctions(self):
        """may grow the binding with some results"""
        while True:
            delta = self._applyFunctionsIteration()
            if delta == 0:
                break

    def _applyFunctionsIteration(self):
        before = len(self.binding.binding)
        delta = 0
        for ev in self.lhs.evaluations:
            newBindings, usedGraph = ev.resultBindings(self.binding)
            self.usedByFuncs += usedGraph
            self.binding.addNewBindings(newBindings)
        delta = len(self.binding.binding) - before
        log.debug(f'{INDENT*4} eval rules made {delta} new bindings')
        return delta

    def verify(self, workingSet: ReadOnlyWorkingSet) -> bool:
        """Can this bound lhs be true all at once in workingSet?"""
        rem = cast(Set[Triple], self.lhs.nonStaticRuleStmtsSet.difference(self.usedByFuncs))
        boundLhs = self.binding.apply(rem)

        if log.isEnabledFor(logging.DEBUG):
            boundLhs = list(boundLhs)
            self._logVerifyBanner(boundLhs, workingSet)

        for stmt in boundLhs:
            log.debug(f'{INDENT*4} check for {stmt}')

            if stmt not in workingSet:
                log.debug(f'{INDENT*5} stmt not known to be true')
                return False
        return True

    def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet):
        log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:')
        for stmt in sorted(boundLhs):
            log.debug(f'{INDENT*4}|{INDENT} {stmt}')

        # log.debug(f'{INDENT*4}| and against this workingSet:')
        # for stmt in sorted(workingSet):
        #     log.debug(f'{INDENT*4}|{INDENT} {stmt}')

        log.debug(f'{INDENT*4}\\')


@dataclass
class Rule:
    lhsGraph: Graph
    rhsGraph: Graph

    def __post_init__(self):
        self.lhs = Lhs(self.lhsGraph)
        #
        self.rhsBnodeMap = {}

    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict):
        for bound in self.lhs.findCandidateBindings(ReadOnlyGraphAggregate([workingSet]), stats):
            log.debug(f'{INDENT*5} +rule has a working binding: {bound}')

            # the rhs may introduce its own bnodes; each one gets a stable
            # replacement per distinct binding, so repeated firings of the same
            # rule don't mint duplicate nodes
            existingRhsBnodes = set()
            for stmt in self.rhsGraph:
                for t in stmt:
                    if isinstance(t, BNode):
                        existingRhsBnodes.add(t)
            # if existingRhsBnodes:
            # log.debug(f'{INDENT*6} mapping rhs bnodes {existingRhsBnodes} to new ones')

            for b in existingRhsBnodes:

                key = tuple(sorted(bound.binding.binding.items())), b
                self.rhsBnodeMap.setdefault(key, BNode())

                bound.binding.addNewBindings(CandidateBinding({b: self.rhsBnodeMap[key]}))

            # for lhsBoundStmt in bound.binding.apply(bound.lhsStmtsWithoutEvals()):
            #     log.debug(f'{INDENT*6} adding to workingSet {lhsBoundStmt=}')
            #     workingSet.add(lhsBoundStmt)
            # log.debug(f'{INDENT*6} rhsGraph is good: {list(self.rhsGraph)}')

            for newStmt in bound.binding.apply(self.rhsGraph):
                # log.debug(f'{INDENT*6} adding {newStmt=}')
                workingSet.add(newStmt)
                implied.add(newStmt)


class Inference:

    def __init__(self) -> None:
        self.rules = []

    def setRules(self, g: ConjunctiveGraph):
        self.rules: List[Rule] = []
        for stmt in g:
            if stmt[1] == LOG['implies']:
                self.rules.append(Rule(stmt[0], stmt[2]))
            # other stmts should go to a default working set?

    @INFER_CALLS.time()
    def infer(self, graph: Graph):
        """
        returns new graph of inferred statements.
        """
        n = len(graph)
        INFER_GRAPH_SIZE.observe(n)
        log.info(f'{INDENT*0} Begin inference of graph len={n} with rules len={len(self.rules)}:')
        startTime = time.time()
        stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0)
        # everything that is true: the input graph, plus every rule conclusion we can make
        workingSet = Graph()
        workingSet += graph

        # just the statements that came from RHS's of rules that fired.
        implied = ConjunctiveGraph()

        bailout_iterations = 100
        delta = 1
        stats['initWorkingSet'] = len(workingSet)
        while delta > 0 and bailout_iterations > 0:
            log.debug('')
            log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
            bailout_iterations -= 1
            delta = -len(implied)
            self._iterateAllRules(workingSet, implied, stats)
            delta += len(implied)
            stats['iterations'] += 1
            log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
        stats['timeSpent'] = round(time.time() - startTime, 3)
        stats['impliedStmts'] = len(implied)
        log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:')
        for st in implied:
            log.info(f'{INDENT*1} {st}')
        return implied
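    # Minimal usage sketch (rule text and graph names here are illustrative,
    # not from this repo); rules are any statements whose predicate is
    # log:implies, e.g. N3 "{ ?x :a ?y } => { ?x :b ?y } .":
    #
    #   rules = ConjunctiveGraph()
    #   rules.parse(data=rulesN3, format='n3')
    #   inf = Inference()
    #   inf.setRules(rules)
    #   implied = inf.infer(inputGraph)  # returns only the newly implied stmts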

    def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats):
        for i, rule in enumerate(self.rules):
            self._logRuleApplicationHeader(workingSet, i, rule)
            rule.applyRule(workingSet, implied, stats)

    def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
        if not log.isEnabledFor(logging.DEBUG):
            return

        log.debug('')
        log.debug(f'{INDENT*2} workingSet:')
        for j, stmt in enumerate(sorted(workingSet)):
            log.debug(f'{INDENT*3} ({j}) {stmt}')

        log.debug('')
        log.debug(f'{INDENT*2}-applying rule {i}')
        log.debug(f'{INDENT*3} rule def lhs:')
        for stmt in r.lhsGraph:
            log.debug(f'{INDENT*4} {stmt}')
        log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')


def graphDump(g: Union[Graph, List[Triple]]):
    if not isinstance(g, Graph):
        g2 = Graph()
        g2 += g
        g = g2
    g.bind('', ROOM)
    g.bind('ex', Namespace('http://example.com/'))
    lines = cast(bytes, g.serialize(format='n3')).decode('utf8').splitlines()
    lines = [line.strip() for line in lines if not line.startswith('@prefix')]
    return ' '.join(lines)
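# Illustrative example: graphDump([(ROOM['sensor1'], ROOM['temperatureF'], Literal(70))])
# returns a one-line n3 rendering along the lines of
# ":sensor1 :temperatureF 70 ." for embedding in log messages.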


def _organize(candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Tuple[List[BindableTerm], List[List[Node]]]:
    items = list(candidateTermMatches.items())
    items.sort()
    orderedVars: List[BindableTerm] = []
    orderedValueSets: List[List[Node]] = []
    for v, vals in items:
        orderedVars.append(v)
        orderedValues: List[Node] = list(vals)
        orderedValues.sort(key=str)
        orderedValueSets.append(orderedValues)

    return orderedVars, orderedValueSets
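# Illustrative example: _organize({Variable('y'): {Literal(2)}, Variable('x'): {Literal(1)}})
# returns ([?x, ?y], [[Literal(1)], [Literal(2)]]): variables are sorted, and
# each value list is sorted by its string form, so _product's output order is
# deterministic.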