diff service/mqtt_to_rdf/inference/inference.py @ 1727:23e6154e6c11

file moves
author drewp@bigasterisk.com
date Tue, 20 Jun 2023 23:26:24 -0700
parents service/mqtt_to_rdf/inference.py@88f6e9bf69d1
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/service/mqtt_to_rdf/inference/inference.py	Tue Jun 20 23:26:24 2023 -0700
@@ -0,0 +1,543 @@
+"""
+copied from reasoning 2021-08-29. probably same api. should
+be able to lib/ this out
+"""
+import itertools
+import logging
+import time
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, Iterator, List, Optional, Tuple, Union, cast
+
+from prometheus_client import Histogram, Summary
+from rdflib import Graph, Namespace
+from rdflib.graph import ConjunctiveGraph
+from rdflib.term import Node, URIRef
+
+from inference.candidate_binding import CandidateBinding
+from inference.inference_types import (BindingUnknown, Inconsistent, RhsBnode, RuleUnboundBnode, Triple, WorkingSetBnode)
+from inference.lhs_evaluation import functionsFor
+from inference.rdf_debug import graphDump
+from inference.stmt_chunk import (AlignedRuleChunk, Chunk, ChunkedGraph, applyChunky)
+from inference.structured_log import StructuredLog
+
+log = logging.getLogger('infer')
+odolog = logging.getLogger('infer.odo')  # the "odometer" logic
+ringlog = logging.getLogger('infer.ring')  # for ChunkLooper
+
+INDENT = '    '
+
+INFER_CALLS = Summary('inference_infer_calls', 'calls')
+INFER_GRAPH_SIZE = Histogram('inference_graph_size', 'statements', buckets=[2**x for x in range(2, 20, 2)])
+
+ROOM = Namespace("http://projects.bigasterisk.com/room/")
+LOG = Namespace('http://www.w3.org/2000/10/swap/log#')
+MATH = Namespace('http://www.w3.org/2000/10/swap/math#')
+
+
+class NoOptions(ValueError):
+    """ChunkLooper has no possibilites to add to the binding; the whole rule must therefore not apply"""
+
+
+def debug(logger, slog: Optional[StructuredLog], msg):
+    logger.debug(msg)
+    if slog:
+        slog.say(msg)
+
+
+_chunkLooperShortId = itertools.count()
+
+
+@dataclass
+class ChunkLooper:
+    """given one LHS Chunk, iterate through the possible matches for it,
+    returning what bindings they would imply. Only distinct bindings are
+    returned. The bindings build on any `prev` ChunkLooper's results.
+
+    In the odometer metaphor used below, this is one of the rings.
+
+    This iterator is restartable."""
+    lhsChunk: Chunk
+    prev: Optional['ChunkLooper']
+    workingSet: 'ChunkedGraph'
+    slog: Optional[StructuredLog]
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}{self._shortId}{"<pastEnd>" if self.pastEnd() else ""}'
+
+    def __post_init__(self):
+        self._shortId = next(_chunkLooperShortId)
+        self._alignedMatches = list(self.lhsChunk.ruleMatchesFrom(self.workingSet))
+        del self.workingSet
+
+        # only ours- do not store prev, since it could change without us
+        self._current = CandidateBinding({})
+        self.currentSourceChunk: Optional[Chunk] = None  # for debugging only
+        self._pastEnd = False
+        self._seenBindings: List[CandidateBinding] = []  # combined bindings (up to our ring) that we've returned
+
+        if ringlog.isEnabledFor(logging.DEBUG):
+            ringlog.debug('')
+            msg = f'{INDENT*6} introducing {self!r}({self.lhsChunk}, {self._alignedMatches=})'
+            msg = msg.replace('AlignedRuleChunk', f'\n{INDENT*12}AlignedRuleChunk')
+            ringlog.debug(msg)
+
+        self.restart()
+
+    def _prevBindings(self) -> CandidateBinding:
+        if not self.prev or self.prev.pastEnd():
+            return CandidateBinding({})
+
+        return self.prev.currentBinding()
+
+    def advance(self):
+        """update _current to a new set of valid bindings we haven't seen (since
+        last restart), or go into pastEnd mode. Note that _current is just our
+        contribution, but returned valid bindings include all prev rings."""
+        if self._pastEnd:
+            raise NotImplementedError('need restart')
+        ringlog.debug('')
+        debug(ringlog, self.slog, f'{INDENT*6} --> {self}.advance start:')
+
+        self._currentIsFromFunc = None
+        augmentedWorkingSet: List[AlignedRuleChunk] = []
+        if self.prev is None:
+            augmentedWorkingSet = self._alignedMatches
+        else:
+            augmentedWorkingSet = list(applyChunky(self.prev.currentBinding(), self._alignedMatches))
+
+        if self._advanceWithPlainMatches(augmentedWorkingSet):
+            debug(ringlog, self.slog, f'{INDENT*6} <-- {self}.advance finished with plain matches')
+            return
+
+        if self._advanceWithFunctions():
+            debug(ringlog, self.slog, f'{INDENT*6} <-- {self}.advance finished with function matches')
+            return
+
+        debug(ringlog, self.slog, f'{INDENT*6} <-- {self}.advance had nothing and is now past end')
+        self._pastEnd = True
+
+    def _advanceWithPlainMatches(self, augmentedWorkingSet: List[AlignedRuleChunk]) -> bool:
+        # if augmentedWorkingSet:
+        #     debug(ringlog, self.slog, f'{INDENT*7} {self} mines {len(augmentedWorkingSet)} matching augmented statements')
+        # for s in augmentedWorkingSet:
+        #     debug(ringlog, self.slog, f'{INDENT*8} {s}')
+
+        for aligned in augmentedWorkingSet:
+            try:
+                newBinding = aligned.newBindingIfMatched(self._prevBindings())
+            except Inconsistent as exc:
+                debug(ringlog, self.slog, f'{INDENT*7} ChunkLooper{self._shortId} - {aligned} would be inconsistent with prev bindings ({exc})')
+                continue
+
+            if self._testAndKeepNewBinding(newBinding, aligned.workingSetChunk):
+                return True
+        return False
+
+    def _advanceWithFunctions(self) -> bool:
+        pred: Node = self.lhsChunk.predicate
+        if not isinstance(pred, URIRef):
+            raise NotImplementedError
+
+        for functionType in functionsFor(pred):
+            fn = functionType(self.lhsChunk)
+            # debug(ringlog, self.slog, f'{INDENT*7} ChunkLooper{self._shortId} advanceWithFunctions, {functionType=}')
+
+            try:
+                log.debug(f'fn.bind {self._prevBindings()} ...')
+                #fullBinding = self._prevBindings().copy()
+                newBinding = fn.bind(self._prevBindings())
+                log.debug(f'...makes {newBinding=}')
+            except BindingUnknown:
+                pass
+            else:
+                if newBinding is not None:
+                    self._currentIsFromFunc = fn
+                    if self._testAndKeepNewBinding(newBinding, self.lhsChunk):
+                        return True
+
+        return False
+
+    def _testAndKeepNewBinding(self, newBinding: CandidateBinding, sourceChunk: Chunk):
+        fullBinding: CandidateBinding = self._prevBindings().copy()
+        fullBinding.addNewBindings(newBinding)
+        isNew = fullBinding not in self._seenBindings
+
+        if ringlog.isEnabledFor(logging.DEBUG):
+            ringlog.debug(f'{INDENT*7} {self} considering {newBinding=} to make {fullBinding}. {isNew=}')
+        # if self.slog:
+        #     self.slog.looperConsider(self, newBinding, fullBinding, isNew)
+
+        if isNew:
+            self._seenBindings.append(fullBinding.copy())
+            self._current = newBinding
+            self.currentSourceChunk = sourceChunk
+            return True
+        return False
+
+    def localBinding(self) -> CandidateBinding:
+        if self.pastEnd():
+            raise NotImplementedError()
+        return self._current
+
+    def currentBinding(self) -> CandidateBinding:
+        if self.pastEnd():
+            raise NotImplementedError()
+        together = self._prevBindings().copy()
+        together.addNewBindings(self._current)
+        return together
+
+    def pastEnd(self) -> bool:
+        return self._pastEnd
+
+    def restart(self):
+        try:
+            self._pastEnd = False
+            self._seenBindings = []
+            self.advance()
+            if self.pastEnd():
+                raise NoOptions()
+        finally:
+            debug(ringlog, self.slog, f'{INDENT*7} ChunkLooper{self._shortId} restarts: pastEnd={self.pastEnd()}')
+
+
+@dataclass
+class Lhs:
+    graph: ChunkedGraph  # our full LHS graph, as input. See below for the statements partitioned into groups.
+
+    def __post_init__(self):
+
+        self.myPreds = self.graph.allPredicatesExceptFunctions()
+
+    def __repr__(self):
+        return f"Lhs({self.graph!r})"
+
+    def findCandidateBindings(self, knownTrue: ChunkedGraph, stats, slog: Optional[StructuredLog], ruleStatementsIterationLimit) -> Iterator['BoundLhs']:
+        """distinct bindings that fit the LHS of a rule, using statements from
+        workingSet and functions from LHS"""
+        if not self.graph:
+            # special case- no LHS!
+            yield BoundLhs(self, CandidateBinding({}))
+            return
+
+        if self._checkPredicateCounts(knownTrue):
+            stats['_checkPredicateCountsCulls'] += 1
+            return
+
+        if not all(ch in knownTrue for ch in self.graph.staticChunks):
+            stats['staticStmtCulls'] += 1
+            return
+        # After this point we don't need to consider self.graph.staticChunks.
+
+        if not self.graph.patternChunks and not self.graph.chunksUsedByFuncs:
+            # static only
+            yield BoundLhs(self, CandidateBinding({}))
+            return
+
+        log.debug('')
+        try:
+            chunkStack = self._assembleRings(knownTrue, stats, slog)
+        except NoOptions:
+            ringlog.debug(f'{INDENT*5} start up with no options; 0 bindings')
+            return
+        log.debug('')
+        log.debug('')
+        self._debugChunkStack('time to spin: initial odometer is', chunkStack)
+
+        if slog:
+            slog.say('time to spin')
+            slog.odometer(chunkStack)
+        self._assertAllRingsAreValid(chunkStack)
+
+        lastRing = chunkStack[-1]
+        iterCount = 0
+        while True:
+            iterCount += 1
+            if iterCount > ruleStatementsIterationLimit:
+                raise ValueError('rule too complex')
+
+            log.debug(f'{INDENT*4} vv findCandBindings iteration {iterCount}')
+
+            yield BoundLhs(self, lastRing.currentBinding())
+
+            # self._debugChunkStack('odometer', chunkStack)
+
+            done = self._advanceTheStack(chunkStack)
+
+            self._debugChunkStack(f'odometer after ({done=})', chunkStack)
+            if slog:
+                slog.odometer(chunkStack)
+
+            log.debug(f'{INDENT*4} ^^ findCandBindings iteration done')
+            if done:
+                break
+
+    def _debugChunkStack(self, label: str, chunkStack: List[ChunkLooper]):
+        odolog.debug(f'{INDENT*4} {label}:')
+        for i, l in enumerate(chunkStack):
+            odolog.debug(f'{INDENT*5} [{i}] {l} curbind={l.localBinding() if not l.pastEnd() else "<end>"}')
+
+    def _checkPredicateCounts(self, knownTrue):
+        """raise NoOptions quickly in some cases"""
+
+        if self.graph.noPredicatesAppear(self.myPreds):
+            log.debug(f'{INDENT*3} checkPredicateCounts does cull because not all {self.myPreds=} are in knownTrue')
+            return True
+        log.debug(f'{INDENT*3} checkPredicateCounts does not cull because all {self.myPreds=} are in knownTrue')
+        return False
+
+    def _assembleRings(self, knownTrue: ChunkedGraph, stats, slog) -> List[ChunkLooper]:
+        """make ChunkLooper for each stmt in our LHS graph, but do it in a way that they all
+        start out valid (or else raise NoOptions). static chunks have already been confirmed."""
+
+        log.debug(f'{INDENT*4} stats={dict(stats)}')
+        odolog.debug(f'{INDENT*3} build new ChunkLooper stack')
+        chunks = list(self.graph.patternChunks.union(self.graph.chunksUsedByFuncs))
+        chunks.sort(key=None)
+        odolog.info(f' {INDENT*3} taking permutations of {len(chunks)=}')
+
+        permsTried = 0
+
+        for perm in self._partitionedGraphPermutations():
+            looperRings: List[ChunkLooper] = []
+            prev: Optional[ChunkLooper] = None
+            if odolog.isEnabledFor(logging.DEBUG):
+                odolog.debug(f'{INDENT*4} [perm {permsTried}] try rule chunks in this order: {"  THEN  ".join(repr(p) for p in perm)}')
+
+            for ruleChunk in perm:
+                try:
+                    # These are getting rebuilt a lot which takes time. It would
+                    # be nice if they could accept a changing `prev` order
+                    # (which might already be ok).
+                    looper = ChunkLooper(ruleChunk, prev, knownTrue, slog)
+                except NoOptions:
+                    odolog.debug(f'{INDENT*5} permutation didnt work, try another')
+                    break
+                looperRings.append(looper)
+                prev = looperRings[-1]
+            else:
+                # bug: At this point we've only shown that these are valid
+                # starting rings. The rules might be tricky enough that this
+                # permutation won't get us to the solution.
+                return looperRings
+            if permsTried > 50000:
+                raise NotImplementedError(f'trying too many permutations {len(chunks)=}')
+            permsTried += 1
+
+        stats['permsTried'] += permsTried
+        odolog.debug(f'{INDENT*5} no perms worked- rule cannot match anything')
+        raise NoOptions()
+
+    def _unpartitionedGraphPermutations(self) -> Iterator[Tuple[Chunk, ...]]:
+        for perm in itertools.permutations(sorted(list(self.graph.patternChunks.union(self.graph.chunksUsedByFuncs)))):
+            yield perm
+
+    def _partitionedGraphPermutations(self) -> Iterator[Tuple[Chunk, ...]]:
+        """always puts function chunks after pattern chunks
+
+        (and, if we cared, static chunks could go before that. Currently they're
+        culled out elsewhere, but that's done as a special case)
+        """
+        tupleOfNoChunks: Tuple[Chunk, ...] = ()
+        pats = sorted(self.graph.patternChunks)
+        funcs = sorted(self.graph.chunksUsedByFuncs)
+        for patternPart in itertools.permutations(pats) if pats else [tupleOfNoChunks]:
+            for funcPart in itertools.permutations(funcs) if funcs else [tupleOfNoChunks]:
+                perm = patternPart + funcPart
+                yield perm
+
+    def _advanceTheStack(self, looperRings: List[ChunkLooper]) -> bool:
+        toRestart: List[ChunkLooper] = []
+        pos = len(looperRings) - 1
+        while True:
+            looperRings[pos].advance()
+            if looperRings[pos].pastEnd():
+                if pos == 0:
+                    return True
+                toRestart.append(looperRings[pos])
+                pos -= 1
+            else:
+                break
+        for ring in reversed(toRestart):
+            ring.restart()
+        return False
+
+    def _assertAllRingsAreValid(self, looperRings):
+        if any(ring.pastEnd() for ring in looperRings):  # this is an unexpected debug assertion
+            odolog.warning(f'{INDENT*4} some rings started at pastEnd {looperRings}')
+            raise NoOptions()
+
+
+@dataclass
+class BoundLhs:
+    lhs: Lhs
+    binding: CandidateBinding
+
+
+@dataclass
+class Rule:
+    lhsGraph: Graph
+    rhsGraph: Graph
+
+    def __post_init__(self):
+        self.lhs = Lhs(ChunkedGraph(self.lhsGraph, RuleUnboundBnode, functionsFor))
+
+        self.maps = {}
+
+        self.rhsGraphConvert: List[Triple] = []
+        for s, p, o in self.rhsGraph:
+            from rdflib import BNode
+            if isinstance(s, BNode):
+                s = RhsBnode(s)
+            if isinstance(p, BNode):
+                p = RhsBnode(p)
+            if isinstance(o, BNode):
+                o = RhsBnode(o)
+            self.rhsGraphConvert.append((s, p, o))
+
+    def applyRule(self, workingSet: Graph, implied: Graph, stats: Dict, slog: Optional[StructuredLog], ruleStatementsIterationLimit):
+        # this does not change for the current applyRule call. The rule will be
+        # tried again in an outer loop, in case it can produce more.
+        workingSetChunked = ChunkedGraph(workingSet, WorkingSetBnode, functionsFor)
+
+        for bound in self.lhs.findCandidateBindings(workingSetChunked, stats, slog, ruleStatementsIterationLimit):
+            if slog:
+                slog.foundBinding(bound)
+            log.debug(f'{INDENT*5} +rule has a working binding: {bound}')
+
+            newStmts = self.generateImpliedFromRhs(bound.binding)
+
+            for newStmt in newStmts:
+                # log.debug(f'{INDENT*6} adding {newStmt=}')
+                workingSet.add(newStmt)
+                implied.add(newStmt)
+
+    def generateImpliedFromRhs(self, binding: CandidateBinding) -> List[Triple]:
+
+        out: List[Triple] = []
+
+        # Each time the RHS is used (in a rule firing), its own BNodes (which
+        # are subtype RhsBnode) need to be turned into distinct ones. Note that
+        # bnodes that come from the working set should not be remapped.
+        rhsBnodeMap: Dict[RhsBnode, WorkingSetBnode] = {}
+
+        # but, the iteration loop could come back with the same bindings again
+        key = binding.key()
+        rhsBnodeMap = self.maps.setdefault(key, {})
+
+        for stmt in binding.apply(self.rhsGraphConvert):
+
+            outStmt: List[Node] = []
+
+            for t in stmt:
+                if isinstance(t, RhsBnode):
+                    if t not in rhsBnodeMap:
+                        rhsBnodeMap[t] = WorkingSetBnode()
+                    t = rhsBnodeMap[t]
+
+                outStmt.append(t)
+
+            log.debug(f'{INDENT*6} rhs stmt {stmt} became {outStmt}')
+            out.append((outStmt[0], outStmt[1], outStmt[2]))
+
+        return out
+
+
+@dataclass
+class Inference:
+    rulesIterationLimit = 4
+    ruleStatementsIterationLimit = 5000
+
+    def __init__(self) -> None:
+        self.rules: List[Rule] = []
+        self._nonRuleStmts: List[Triple] = []
+
+    def setRules(self, g: ConjunctiveGraph):
+        self.rules = []
+        self._nonRuleStmts = []
+        for stmt in g:
+            if stmt[1] == LOG['implies']:
+                self.rules.append(Rule(stmt[0], stmt[2]))
+            else:
+                self._nonRuleStmts.append(stmt)
+
+    def nonRuleStatements(self) -> List[Triple]:
+        return self._nonRuleStmts
+
+    @INFER_CALLS.time()
+    def infer(self, graph: Graph, htmlLog: Optional[Path] = None):
+        """
+        returns new graph of inferred statements.
+        """
+        n = cast(int, graph.__len__())
+        INFER_GRAPH_SIZE.observe(n)
+        log.info(f'{INDENT*0} Begin inference of graph len={n} with rules len={len(self.rules)}:')
+        startTime = time.time()
+        stats: Dict[str, Union[int, float]] = defaultdict(lambda: 0)
+
+        # everything that is true: the input graph, plus every rule conclusion we can make
+        workingSet = Graph()
+        workingSet += self._nonRuleStmts
+        workingSet += graph
+
+        # just the statements that came from RHS's of rules that fired.
+        implied = ConjunctiveGraph()
+
+        slog = StructuredLog(htmlLog) if htmlLog else None
+
+        rulesIterations = 0
+        delta = 1
+        stats['initWorkingSet'] = cast(int, workingSet.__len__())
+        if slog:
+            slog.workingSet = workingSet
+
+        while delta > 0:
+            log.debug('')
+            log.info(f'{INDENT*1}*iteration {rulesIterations}')
+            if slog:
+                slog.startIteration(rulesIterations)
+
+            delta = -len(implied)
+            self._iterateAllRules(workingSet, implied, stats, slog)
+            delta += len(implied)
+            rulesIterations += 1
+            log.info(f'{INDENT*2} this inference iteration added {delta} more implied stmts')
+            if rulesIterations >= self.rulesIterationLimit:
+                raise ValueError(f"rule too complex after {rulesIterations=}")
+        stats['iterations'] = rulesIterations
+        stats['timeSpent'] = round(time.time() - startTime, 3)
+        stats['impliedStmts'] = len(implied)
+        log.info(f'{INDENT*0} Inference done {dict(stats)}.')
+        log.debug('Implied:')
+        log.debug(graphDump(implied))
+
+        if slog:
+            slog.render()
+            log.info(f'wrote {htmlLog}')
+
+        return implied
+
+    def _iterateAllRules(self, workingSet: Graph, implied: Graph, stats, slog: Optional[StructuredLog]):
+        for i, rule in enumerate(self.rules):
+            self._logRuleApplicationHeader(workingSet, i, rule)
+            if slog:
+                slog.rule(workingSet, i, rule)
+            rule.applyRule(workingSet, implied, stats, slog, self.ruleStatementsIterationLimit)
+
+    def _logRuleApplicationHeader(self, workingSet, i, r: Rule):
+        if not log.isEnabledFor(logging.DEBUG):
+            return
+
+        log.debug('')
+        log.debug(f'{INDENT*2} workingSet:')
+        # for j, stmt in enumerate(sorted(workingSet)):
+        #     log.debug(f'{INDENT*3} ({j}) {stmt}')
+        log.debug(f'{INDENT*3} {graphDump(workingSet, oneLine=False)}')
+
+        log.debug('')
+        log.debug(f'{INDENT*2}-applying rule {i}')
+        log.debug(f'{INDENT*3} rule def lhs:')
+        for stmt in sorted(r.lhs.graph.allChunks()):
+            log.debug(f'{INDENT*4} {stmt}')
+        log.debug(f'{INDENT*3} rule def rhs: {graphDump(r.rhsGraph)}')