view merge.py @ 13:bfd95926be6e default tip

initial port to starlette. missing some disconnect & cleanup functionality
author drewp@bigasterisk.com
date Sat, 26 Nov 2022 14:13:51 -0800
parents 032e59be8fe9
children
line wrap: on
line source

import logging
import collections
from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, NewType
from prometheus_client import Summary

from rdfdb.patch import Patch
from rdflib import Namespace, URIRef
from rdflib.term import Node
# root logger (getLogger() with no name); used by setSourceState and replaceSourceStatements
log = logging.getLogger()
# prometheus Summary metrics. Each one's .time() is used as a decorator below:
# LocalStatements.setSourceState, ActiveStatements.makeSyncPatch, and
# ActiveStatements.replaceSourceStatements, in that order.
LOCAL_STATEMENTS_PATCH_CALLS = Summary("local_statements_patch_calls", 'calls')
MAKE_SYNC_PATCH_CALLS = Summary("make_sync_patch_calls", 'calls')
REPLACE_SOURCE_STATEMENTS_CALLS = Summary("replace_source_statements_calls", 'calls')

# we deal with PatchSinkResponse in here, but only as objs to __contains__ and compare.
from patchsink import PatchSinkResponse
OutputHandler = Union[PatchSinkResponse, type(None)] #  A patchsink.PatchSinkResponse, but make type error if we try to look in it

# uri that identifies one source of statements (COLLECTOR below is one of them)
SourceUri = NewType('SourceUri', URIRef)

# a quad of four rdflib nodes; presumably (subject, predicate, object, graph) -- TODO confirm against rdfdb.patch
Statement = Tuple[Node, Node, Node, Node]
# stmt -> (sources currently asserting it, output handlers recorded as having it)
StatementTable = Dict[Statement, Tuple[Set[SourceUri], Set[OutputHandler]]]

# namespace of the predicates LocalStatements writes (ROOM['source'], ROOM['state'])
ROOM = Namespace("http://projects.bigasterisk.com/room/")
# the source uri, and the graph, of the statements that sse_collector makes itself
COLLECTOR = SourceUri(URIRef('http://bigasterisk.com/sse_collector/'))


def abbrevTerm(t: Union[URIRef, Node]) -> Union[str, Node]:
    """
    shorten the well-known uri prefixes of a URIRef for display. Any
    other kind of term is returned unchanged.
    """
    if not isinstance(t, URIRef):
        return t

    # applied in this order, each one on the result of the previous
    prefixes = (
        ('http://projects.bigasterisk.com/room/', 'room:'),
        ('http://projects.bigasterisk.com/device/', 'dev:'),
        ('http://bigasterisk.com/sse_collector/', 'sc:'),
    )
    short = t
    for longForm, shortForm in prefixes:
        short = short.replace(longForm, shortForm)
    return short


def abbrevStmt(stmt: Statement) -> str:
    """
    one-line display form of a statement: its first four terms, each
    passed through abbrevTerm, space-separated inside parens.
    """
    quad = (stmt[0], stmt[1], stmt[2], stmt[3])
    return '(%s %s %s %s)' % tuple(abbrevTerm(term) for term in quad)


class LocalStatements(object):
    """
    maintains the statements that sse_collector itself asserts (in the
    COLLECTOR graph) about each of its sources
    """

    def __init__(self, applyPatch: Callable[[SourceUri, Patch], None]):
        # source -> the state URIRef we last published for it
        self._sourceState: Dict[SourceUri, Optional[URIRef]] = {}
        self.applyPatch = applyPatch

    @LOCAL_STATEMENTS_PATCH_CALLS.time()
    def setSourceState(self, source: SourceUri, state: Optional[URIRef]):
        """
        publish, through self.applyPatch, a patch to the COLLECTOR graph
        that moves this source to the given state. state=None removes
        the source. Nothing happens if the state is unchanged.
        """
        previous = self._sourceState.get(source)
        if state == previous:
            return
        log.info('source state %s -> %s', source, state)

        if state is None:
            # previous can't be None here (that case returned above), so
            # this is a known source going away: retract both its quads
            del self._sourceState[source]
            retraction = Patch(delQuads=[
                (COLLECTOR, ROOM['source'], source, COLLECTOR),
                (source, ROOM['state'], previous, COLLECTOR),
            ])
            self.applyPatch(COLLECTOR, retraction)
            return

        self._sourceState[source] = state
        if previous is None:
            # first time we hear of this source: announce it and its state
            update = Patch(addQuads=[
                (COLLECTOR, ROOM['source'], source, COLLECTOR),
                (source, ROOM['state'], state, COLLECTOR),
            ])
        else:
            # known source: swap the old state quad for the new one
            update = Patch(addQuads=[
                (source, ROOM['state'], state, COLLECTOR),
            ], delQuads=[
                (source, ROOM['state'], previous, COLLECTOR),
            ])
        self.applyPatch(COLLECTOR, update)


class PostDeleter(object):
    """
    context manager that queues statements for removal from a
    statement table and removes them when the `with` block ends. This
    lets the block pick rows to drop while it is still iterating over
    the table, without changing the dict's size mid-iteration.
    """

    def __init__(self, statements: "StatementTable"):
        self.statements = statements
        # start with an empty queue so add() also works outside a `with`
        self._garbage: "List[Statement]" = []

    def __enter__(self):
        # fresh queue for each `with` block
        self._garbage = []
        return self

    def add(self, stmt: "Statement"):
        """queue stmt to be removed from the table when the block exits"""
        self._garbage.append(stmt)

    def __exit__(self, type, value, traceback):
        """
        remove every queued statement from the table.

        This runs even when the block raised: the queued rows were
        already chosen for removal, and leaving them behind would only
        keep dead rows in the table. Returns False so an exception from
        the block propagates unchanged instead of being replaced.
        """
        for stmt in self._garbage:
            # pop() rather than del: a stmt that was queued twice, or is
            # already gone, must not raise from inside __exit__
            self.statements.pop(stmt, None)
        return False


class ActiveStatements(object):
    """
    table of every statement we currently know, with the sources that
    assert each one and the output handlers recorded as having it
    """

    def __init__(self):
        # This table holds statements asserted by any of our sources
        # plus local statements that we introduce (source is
        # http://bigasterisk.com/sse_collector/).
        #
        # It is a defaultdict: indexing an unknown stmt inserts an empty
        # (sources, handlers) row, so only index it when a row is wanted.
        self.table: StatementTable = collections.defaultdict(lambda: (set(), set()))

    def state(self) -> Dict:
        """summary of the table for status reporting (currently just its size)"""
        return {
            'len': len(self.table),
        }

    def postDeleteStatements(self) -> PostDeleter:
        """
        context manager for queueing rows to delete from the table
        after a loop over it has finished
        """
        return PostDeleter(self.table)

    def pprintTable(self) -> None:
        """print one numbered line per statement, in sorted order, with its sources and handlers"""
        for i, (stmt, (sources, handlers)) in enumerate(sorted(self.table.items())):
            print("%03d. %-80s from %s to %s" % (i, abbrevStmt(stmt), [abbrevTerm(s) for s in sources], handlers))

    @MAKE_SYNC_PATCH_CALLS.time()
    def makeSyncPatch(self, handler: OutputHandler, sources: Set[SourceUri]):
        """
        return the Patch that brings handler in line with the given
        sources, and update the table's record of what handler has.

        addQuads are the statements asserted by at least one of
        `sources` that handler was not recorded as having. delQuads are
        the statements handler was recorded as having that none of
        `sources` assert any more.
        """
        # todo: this could run all handlers at once, which is how we
        # use it anyway
        adds = []
        dels = []

        with self.postDeleteStatements() as garbage:
            for stmt, (stmtSources, handlers) in self.table.items():
                belongsInHandler = not sources.isdisjoint(stmtSources)
                handlerHasIt = handler in handlers
                # log.debug("%s belong=%s has=%s",
                #           abbrevStmt(stmt), belongsInHandler, handlerHasIt)
                if belongsInHandler and not handlerHasIt:
                    adds.append(stmt)
                    handlers.add(handler)
                elif not belongsInHandler and handlerHasIt:
                    dels.append(stmt)
                    handlers.remove(handler)
                    # nobody asserts it and nobody has it: drop the row
                    if not handlers and not stmtSources:
                        garbage.add(stmt)

        return Patch(addQuads=adds, delQuads=dels)

    def applySourcePatch(self, source: SourceUri, p: Patch):
        """
        record that source now asserts p.addQuads and no longer
        asserts p.delQuads.

        Raises ValueError if source adds a statement it already had,
        or deletes one it didn't have. The error is raised at the first
        offending quad; quads before it have already been applied.
        """
        for stmt in p.addQuads:
            # indexing creates the row if this is a new statement
            sourceUrls, handlers = self.table[stmt]
            if source in sourceUrls:
                raise ValueError("%s added stmt that it already had: %s" % (source, abbrevStmt(stmt)))
            sourceUrls.add(source)

        with self.postDeleteStatements() as garbage:
            for stmt in p.delQuads:
                # .get(), not self.table[stmt]: indexing the defaultdict
                # would insert an empty row for a stmt we never had, and
                # that row would then be left behind by the error below
                row = self.table.get(stmt)
                if row is None or source not in row[0]:
                    raise ValueError("%s deleting stmt that it didn't have: %s" % (source, abbrevStmt(stmt)))
                sourceUrls, handlers = row
                sourceUrls.remove(source)
                # this is rare, since some handler probably still has
                # the stmt we're deleting, but it can happen e.g. when
                # a handler was just deleted
                if not sourceUrls and not handlers:
                    garbage.add(stmt)

    @REPLACE_SOURCE_STATEMENTS_CALLS.time()
    def replaceSourceStatements(self, source: SourceUri, stmts: Sequence[Statement]):
        """
        make the set of statements asserted by source be exactly
        stmts: source is removed from rows not in stmts, added to
        existing rows that are, and statements not yet in the table
        are added through applySourcePatch.
        """
        log.debug('replaceSourceStatements with %s stmts', len(stmts))
        # set for O(1) membership tests; stmts itself may be a long list
        wanted = set(stmts)
        # whittled down in the loop to the stmts that have no row yet
        newStmts = set(wanted)

        with self.postDeleteStatements() as garbage:
            for stmt, (sources, handlers) in self.table.items():
                if source in sources:
                    if stmt not in wanted:
                        sources.remove(source)
                        if not sources and not handlers:
                            garbage.add(stmt)
                else:
                    if stmt in wanted:
                        sources.add(source)
                newStmts.discard(stmt)

        self.applySourcePatch(source, Patch(addQuads=newStmts, delQuads=[]))

    def discardHandler(self, handler: OutputHandler):
        """forget handler everywhere, dropping rows left with no sources and no handlers"""
        with self.postDeleteStatements() as garbage:
            for stmt, (sources, handlers) in self.table.items():
                handlers.discard(handler)
                if not sources and not handlers:
                    garbage.add(stmt)

    def discardSource(self, source: SourceUri):
        """forget source everywhere, dropping rows left with no sources and no handlers"""
        with self.postDeleteStatements() as garbage:
            for stmt, (sources, handlers) in self.table.items():
                sources.discard(source)
                if not sources and not handlers:
                    garbage.add(stmt)