view service/mqtt_to_rdf/inference.py @ 1627:ea559a846714

some shuffling, i don't know- i'm about to rewrite again
author drewp@bigasterisk.com
date Sat, 11 Sep 2021 23:27:32 -0700
parents 7b3656867185
children 2c85a4f5dd9c
line wrap: on
line source

"""
copied from reasoning 2021-08-29. probably same api. should
be able to lib/ this out
"""
import itertools
import logging
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterator, List, Set, Tuple, Union, cast

from prometheus_client import Summary, Histogram
from rdflib import BNode, Graph, Namespace, URIRef
from rdflib.graph import ConjunctiveGraph, ReadOnlyGraphAggregate
from rdflib.term import Node, Variable

from candidate_binding import CandidateBinding
from inference_types import (BindableTerm, EvaluationFailed, ReadOnlyWorkingSet, Triple)
from lhs_evaluation import Evaluation

# module logger; configure the 'infer' logger to see the (very verbose) debug trace
log = logging.getLogger('infer')
# one level of indentation in the nested debug/info trace output
INDENT = '    '

# Prometheus metrics: INFER_CALLS times each Inference.infer call (used as its
# decorator); INFER_GRAPH_SIZE records the input graph's statement count per call.
INFER_CALLS = Summary('inference_infer_calls', 'calls')
# buckets are 4, 16, 64, ..., 2**18 statements
INFER_GRAPH_SIZE = Histogram('inference_graph_size', 'statements', buckets=[2**x for x in range(2, 20, 2)])

# ROOM is bound as the default ('') prefix when graphDump renders graphs
ROOM = Namespace("http://projects.bigasterisk.com/room/")
# N3 log: vocabulary; LOG['implies'] is how Inference.setRules recognizes rules
LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
# N3 math: vocabulary; not referenced in this file (presumably kept for importers — verify)
MATH = Namespace('http://www.w3.org/2000/10/swap/math#')


@dataclass
class Lhs:
    """The left-hand side (premise graph) of one rule, plus precomputation
    about it that doesn't depend on any particular working set."""
    graph: Graph

    def __post_init__(self):
        # do precomputation in here that's not specific to the workingSet
        # stmts with no Variable/BNode terms: checked verbatim against the working set
        self.staticRuleStmts = Graph()
        # stmts with at least one Variable/BNode: can only be checked after binding
        self.nonStaticRuleStmts = Graph()

        self.lhsBindables: Set[BindableTerm] = set()  # every Variable/BNode in the lhs
        self.lhsBnodes: Set[BNode] = set()  # just the BNodes among lhsBindables
        for ruleStmt in self.graph:
            varsAndBnodesInStmt = [term for term in ruleStmt if isinstance(term, (Variable, BNode))]
            self.lhsBindables.update(varsAndBnodesInStmt)
            self.lhsBnodes.update(x for x in varsAndBnodesInStmt if isinstance(x, BNode))
            if not varsAndBnodesInStmt:
                self.staticRuleStmts.add(ruleStmt)
            else:
                self.nonStaticRuleStmts.add(ruleStmt)

        # plain-set copy; BoundLhs.verify takes a set difference against it
        self.nonStaticRuleStmtsSet = set(self.nonStaticRuleStmts)

        self.evaluations = list(Evaluation.findEvals(self.graph))

    def __repr__(self):
        return f"Lhs({graphDump(self.graph)})"

    def findCandidateBindings(self, knownTrue: ReadOnlyWorkingSet, stats) -> Iterator['BoundLhs']:
        """bindings that fit the LHS of a rule, using statements from workingSet and functions
        from LHS

        Yields every BoundLhs whose verify(knownTrue) succeeds. Counters in
        stats are incremented in place, so missing keys must default to 0
        (Inference.infer passes a defaultdict).
        """
        log.debug('%s nodesToBind: %s', INDENT * 3, self.lhsBindables)
        stats['findCandidateBindingsCalls'] += 1

        # cheap rejection: every var-free lhs stmt must already be known true
        if not self._allStaticStatementsMatch(knownTrue):
            stats['findCandidateBindingEarlyExits'] += 1
            return

        for binding in self._possibleBindings(knownTrue, stats):
            log.debug('')
            log.debug('%s*trying %s', INDENT * 4, binding.binding)

            if not binding.verify(knownTrue):
                log.debug('%s this binding did not verify', INDENT * 4)
                stats['permCountFailingVerify'] += 1
                continue

            stats['permCountSucceeding'] += 1
            yield binding

    def _allStaticStatementsMatch(self, knownTrue: ReadOnlyWorkingSet) -> bool:
        """True if every lhs stmt without vars/bnodes is in knownTrue."""
        # bug: see TestSelfFulfillingRule.test3 for a case where this rule's
        # static stmt is matched by a non-static stmt in the rule itself
        for ruleStmt in self.staticRuleStmts:
            if ruleStmt not in knownTrue:
                log.debug('%s %s not in working set- skip rule', INDENT * 3, ruleStmt)
                return False
        return True

    def _possibleBindings(self, workingSet, stats) -> Iterator['BoundLhs']:
        """this yields at least the working bindings, and possibly others

        Candidates whose evaluations raise EvaluationFailed are counted in
        stats['permCountFailingEval'] and skipped.
        """
        candidateTermMatches: Dict[BindableTerm, Set[Node]] = self._allCandidateTermMatches(workingSet)
        for bindRow in self._product(candidateTermMatches):
            try:
                yield BoundLhs(self, bindRow)
            except EvaluationFailed:
                stats['permCountFailingEval'] += 1

    def _allCandidateTermMatches(self, workingSet: ReadOnlyWorkingSet) -> Dict[BindableTerm, Set[Node]]:
        """the total set of terms each variable could possibly match

        Every (lhs stmt, working-set stmt) pair whose fixed terms agree
        contributes its values for the lhs stmt's vars/bnodes.
        """
        candidateTermMatches: Dict[BindableTerm, Set[Node]] = defaultdict(set)
        for lhsStmt in self.graph:
            log.debug('%s possibles for this lhs stmt: %s', INDENT * 4, lhsStmt)
            for trueStmt in workingSet:
                for v, vals in self._bindingsFromStatement(lhsStmt, trueStmt):
                    candidateTermMatches[v].update(vals)

        return candidateTermMatches

    def _bindingsFromStatement(self, stmt1: Triple, stmt2: Triple) -> Iterator[Tuple[BindableTerm, Set[Node]]]:
        """if these stmts match otherwise, what BNode or Variable mappings do we learn?

        e.g. stmt1=(?x B ?y) and stmt2=(A B C), then we yield (?x, {A}) and (?y, {C})
        or   stmt1=(_:x B C) and stmt2=(A B C), then we yield (_:x, {A})
        or   stmt1=(?x B C)  and stmt2=(A B D), then we yield nothing
        """
        bindingsFromStatement: Dict[BindableTerm, Set[Node]] = {}
        for term1, term2 in zip(stmt1, stmt2):
            if isinstance(term1, (BNode, Variable)):
                bindingsFromStatement.setdefault(term1, set()).add(term2)
            elif term1 != term2:
                # a fixed term disagrees, so stmt2 can't match stmt1: learn nothing
                return
        for v, vals in bindingsFromStatement.items():
            # same text as the f'{v=} {vals=}' form, but only formatted when DEBUG is on
            log.debug('%s v=%r vals=%r', INDENT * 5, v, vals)
            yield v, vals

    def _product(self, candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Iterator[CandidateBinding]:
        """Yield one CandidateBinding per combination of candidate values.

        Each combination is yielded exactly once, in a deterministic order
        (terms and values are ordered by _organize). With no terms at all,
        one empty binding is yielded; if any term has no candidates, nothing is.
        """
        orderedVars, orderedValueSets = _organize(candidateTermMatches)
        self._logCandidates(orderedVars, orderedValueSets)
        log.debug(f'{INDENT*3} trying all permutations:')
        # itertools.product() of zero iterables yields a single empty tuple,
        # and of any empty iterable yields nothing: exactly the edge cases above.
        for combo in itertools.product(*orderedValueSets):
            log.debug('%s currentSet: %r', INDENT * 4, list(combo))
            yield CandidateBinding(dict(zip(orderedVars, combo)))

    def _logCandidates(self, orderedVars, orderedValueSets):
        if not log.isEnabledFor(logging.DEBUG):
            return
        log.debug(f'{INDENT*3} resulting candidate terms:')
        for v, vals in zip(orderedVars, orderedValueSets):
            log.debug(f'{INDENT*4} {v!r} could be:')
            for val in vals:
                log.debug(f'{INDENT*5}{val!r}')


# @dataclass
# class CandidateTermMatches:
#     """lazily find the possible matches for this term"""
#     terms: List[BindableTerm]
#     lhs: Lhs
#     knownTrue: Graph
#     boundSoFar: CandidateBinding

#     def __post_init__(self):
#         self.results: List[Node] = []  # we have to be able to repeat the results

#         res: Set[Node] = set()
#         for trueStmt in self.knownTrue:  # all bound
#             lStmts = list(self.lhsStmtsContainingTerm())
#             log.debug(f'{INDENT*4} {trueStmt=} {len(lStmts)}')
#             for pat in self.boundSoFar.apply(lStmts, returnBoundStatementsOnly=False):
#                 log.debug(f'{INDENT*4} {pat=}')
#                 implied = self._stmtImplies(pat, trueStmt)
#                 if implied is not None:
#                     res.add(implied)
#         self.results = list(res)
#         # self.results.sort()

#         log.debug(f'{INDENT*3} CandTermMatches: {self.term} {graphDump(self.lhs.graph)} {self.boundSoFar=} ===> {self.results=}')

#     def lhsStmtsContainingTerm(self):
#         # lhs could precompute this
#         for lhsStmt in self.lhs.graph:
#             if self.term in lhsStmt:
#                 yield lhsStmt

#     def __iter__(self):
#         return iter(self.results)


@dataclass
class BoundLhs:
    """An Lhs together with one candidate binding, after the lhs's
    evaluations have been run (possibly extending that binding)."""
    lhs: Lhs
    binding: CandidateBinding

    def __post_init__(self):
        # lhs stmts that evaluations reported using; verify() skips them
        self.usedByFuncs = Graph()
        self._applyFunctions()

    def lhsStmtsWithoutEvals(self):
        """Yield the lhs stmts that no evaluation reported using."""
        for candidate in self.lhs.graph:
            if candidate not in self.usedByFuncs:
                yield candidate

    def _applyFunctions(self):
        """may grow the binding with some results

        Passes over all evaluations repeat until a pass changes nothing,
        presumably so one evaluation's output can feed another's input.
        """
        while self._applyFunctionsIteration() != 0:
            pass

    def _applyFunctionsIteration(self):
        """Run each evaluation once; return the change in the binding's size."""
        sizeBefore = len(self.binding.binding)
        for evaluation in self.lhs.evaluations:
            newBindings, usedGraph = evaluation.resultBindings(self.binding)
            self.usedByFuncs += usedGraph
            self.binding.addNewBindings(newBindings)
        delta = len(self.binding.binding) - sizeBefore
        log.debug(f'{INDENT*4} eval rules made {delta} new bindings')
        return delta

    def verify(self, workingSet: ReadOnlyWorkingSet) -> bool:
        """Can this bound lhs be true all at once in workingSet?

        Only the non-static lhs stmts that evaluations didn't consume are
        checked; static stmts are checked by Lhs before binding.
        """
        remaining = cast(Set[Triple], self.lhs.nonStaticRuleStmtsSet.difference(self.usedByFuncs))
        boundStmts = self.binding.apply(remaining)

        if log.isEnabledFor(logging.DEBUG):
            # materialize so both the banner and the check below can iterate it
            boundStmts = list(boundStmts)
            self._logVerifyBanner(boundStmts, workingSet)

        for candidate in boundStmts:
            log.debug(f'{INDENT*4} check for %s', candidate)
            if candidate in workingSet:
                continue
            log.debug(f'{INDENT*5} stmt not known to be true')
            return False
        return True

    def _logVerifyBanner(self, boundLhs, workingSet: ReadOnlyWorkingSet):
        """Debug-log the bound stmts about to be verified, in sorted order."""
        log.debug(f'{INDENT*4}/ verify all bindings against this boundLhs:')
        for line in sorted(boundLhs):
            log.debug(f'{INDENT*4}|{INDENT} {line}')
        # workingSet is accepted so the banner could also dump it; that dump is
        # disabled because it's very long.
        log.debug(f'{INDENT*4}\\')


@dataclass
class Rule:
    """One `lhsGraph log:implies rhsGraph` rule."""
    lhsGraph: Graph
    rhsGraph: Graph

    def __post_init__(self):
        # precompute the lhs matching data once per rule
        self.lhs = Lhs(self.lhsGraph)

    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict):
        """Fire this rule for every lhs binding that verifies against workingSet.

        For each such binding, the bound lhs stmts (minus those consumed by
        evaluations) and the bound rhs stmts are added to workingSet, and the
        rhs stmts are also added to implied. Candidates are produced lazily
        from a read-only view over workingSet (not a copy), so stmts added for
        one binding are visible when later bindings are verified.

        stats is passed through to Lhs.findCandidateBindings for its counters.
        """
        trueSoFar = ReadOnlyGraphAggregate([workingSet])
        for bound in self.lhs.findCandidateBindings(trueSoFar, stats):
            log.debug(f'{INDENT*3} rule has a working binding:')
            binding = bound.binding

            # NB: the variable names below appear in the log text via f'{x=}'
            for lhsBoundStmt in binding.apply(bound.lhsStmtsWithoutEvals()):
                log.debug(f'{INDENT*4} adding {lhsBoundStmt=}')
                workingSet.add(lhsBoundStmt)

            for newStmt in binding.apply(self.rhsGraph):
                log.debug(f'{INDENT*4} adding {newStmt=}')
                workingSet.add(newStmt)
                implied.add(newStmt)


class Inference:
    """Forward-chaining inference over a list of `lhs log:implies rhs` rules."""

    def __init__(self) -> None:
        self.rules: List[Rule] = []

    def setRules(self, g: ConjunctiveGraph):
        """Replace the rule list with one Rule per `(lhs, log:implies, rhs)` stmt in g.

        The stmt's subject and object are passed as the Rule's lhs and rhs
        graphs (presumably N3 formulas from parsing; not checked here).
        """
        self.rules = []
        for stmt in g:
            if stmt[1] == LOG['implies']:
                self.rules.append(Rule(stmt[0], stmt[2]))
            # other stmts should go to a default working set?

    @INFER_CALLS.time()
    def infer(self, graph: Graph):
        """
        returns new graph of inferred statements.

        graph is not modified; the rules run against a private copy of it.
        All rules are applied in turn, and that pass repeats until a pass
        implies nothing new or the iteration limit is reached. Hitting the
        limit logs a warning, since the result may then be incomplete.
        """
        n = len(graph)
        INFER_GRAPH_SIZE.observe(n)
        log.info(f'{INDENT*0} Begin inference of graph len={n} with rules len={len(self.rules)}:')
        startTime = time.time()
        stats: Dict[str, Union[int, float]] = defaultdict(int)
        # everything that is true: the input graph, plus every rule conclusion we can make
        workingSet = Graph()
        workingSet += graph

        # just the statements that came from RHS's of rules that fired.
        implied = ConjunctiveGraph()

        bailout_iterations = 100
        delta = 1
        stats['initWorkingSet'] = len(workingSet)
        while delta > 0 and bailout_iterations > 0:
            log.debug('')
            log.info(f'{INDENT*1}*iteration ({bailout_iterations} left)')
            bailout_iterations -= 1
            # delta = how many stmts this pass added to implied
            delta = -len(implied)
            self._iterateAllRules(workingSet, implied, stats)
            delta += len(implied)
            stats['iterations'] += 1
            log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
        if delta > 0:
            # loop exited on the iteration limit while rules were still firing
            log.warning(f'{INDENT*0} Inference stopped after {stats["iterations"]} iterations '
                        f'while still implying new stmts; result may be incomplete')
        stats['timeSpent'] = round(time.time() - startTime, 3)
        stats['impliedStmts'] = len(implied)
        log.info(f'{INDENT*0} Inference done {dict(stats)}. Implied:')
        for st in implied:
            log.info(f'{INDENT*1} {st}')
        return implied

    def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats):
        """One pass: apply each rule, in order, against the growing workingSet."""
        for i, rule in enumerate(self.rules):
            self._logRuleApplicationHeader(workingSet, i, rule)
            rule.applyRule(workingSet, implied, stats)

    def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
        """Debug-log the whole working set and the rule about to be applied."""
        if not log.isEnabledFor(logging.DEBUG):
            return

        log.debug('')
        log.debug(f'{INDENT*2} workingSet:')
        for j, stmt in enumerate(sorted(workingSet)):
            log.debug(f'{INDENT*3} ({j}) {stmt}')

        log.debug('')
        log.debug(f'{INDENT*2}-applying rule {i}')
        log.debug(f'{INDENT*3} rule def lhs: {graphDump(r.lhsGraph)}')
        log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')


def graphDump(g: Union[Graph, List[Triple]]):
    """Render g as one line of N3 text, without @prefix lines, for logging.

    The triples are copied into a scratch Graph before prefixes are bound.
    That way dumping (e.g. from Lhs.__repr__ or a debug log) doesn't rebind
    prefixes on the caller's graph or on any store it shares. The caller's
    existing prefix bindings are carried over, so terms still render with the
    same prefixes.
    """
    scratch = Graph()
    if isinstance(g, Graph):
        for prefix, ns in g.namespaces():
            scratch.bind(prefix, ns)
    else:
        log.warning(f"it's a {type(g)}")
    scratch += g
    scratch.bind('', ROOM)
    scratch.bind('ex', Namespace('http://example.com/'))
    out = scratch.serialize(format='n3')
    if isinstance(out, bytes):  # rdflib < 6 returns bytes; rdflib >= 6 returns str
        out = out.decode('utf8')
    lines = [line.strip() for line in out.splitlines() if not line.startswith('@prefix')]
    return ' '.join(lines)


def _organize(candidateTermMatches: Dict[BindableTerm, Set[Node]]) -> Tuple[List[BindableTerm], List[List[Node]]]:
    """Split a {term: candidate values} mapping into two parallel lists.

    The terms come out in their natural sort order. Each term's candidates are
    sorted by their string form, so callers iterating the result see a
    deterministic order regardless of dict/set iteration order.
    """
    orderedVars: List[BindableTerm] = sorted(candidateTermMatches)
    orderedValueSets: List[List[Node]] = [sorted(candidateTermMatches[term], key=str) for term in orderedVars]
    return orderedVars, orderedValueSets