diff merge.py @ 12:032e59be8fe9

refactor to separate the nonweb stuff a bit, in prep for cyclone->starlette
author drewp@bigasterisk.com
date Fri, 25 Nov 2022 20:58:08 -0800
parents
children bfd95926be6e
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/merge.py	Fri Nov 25 20:58:08 2022 -0800
@@ -0,0 +1,187 @@
+import collections
+from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, NewType
+from prometheus_client import Summary
+
+from rdfdb.patch import Patch
+from rdflib import Namespace, URIRef
+from rdflib.term import Node
+from standardservice.logsetup import enableTwistedLog, log
+from patchsink import PatchSink
+LOCAL_STATEMENTS_PATCH_CALLS = Summary("local_statements_patch_calls", 'calls')
+MAKE_SYNC_PATCH_CALLS = Summary("make_sync_patch_calls", 'calls')
+REPLACE_SOURCE_STATEMENTS_CALLS = Summary("replace_source_statements_calls", 'calls')
+
+SourceUri = NewType('SourceUri', URIRef)
+
+Statement = Tuple[Node, Node, Node, Node]
+StatementTable = Dict[Statement, Tuple[Set[SourceUri], Set[PatchSink]]]
+
+ROOM = Namespace("http://projects.bigasterisk.com/room/")
+COLLECTOR = SourceUri(URIRef('http://bigasterisk.com/sse_collector/'))
+
+
+def abbrevTerm(t: Union[URIRef, Node]) -> Union[str, Node]:
+    if isinstance(t, URIRef):
+        return (t.replace('http://projects.bigasterisk.com/room/', 'room:').replace('http://projects.bigasterisk.com/device/',
+                                                                                    'dev:').replace('http://bigasterisk.com/sse_collector/', 'sc:'))
+    return t
+
+
+def abbrevStmt(stmt: Statement) -> str:
+    return '(%s %s %s %s)' % (abbrevTerm(stmt[0]), abbrevTerm(stmt[1]), abbrevTerm(stmt[2]), abbrevTerm(stmt[3]))
+
+
+class LocalStatements(object):
+    """
+    functions that make statements originating from sse_collector itself
+    """
+
+    def __init__(self, applyPatch: Callable[[SourceUri, Patch], None]):
+        self.applyPatch = applyPatch
+        self._sourceState: Dict[SourceUri, Optional[URIRef]] = {}  # source: state URIRef
+
+    @LOCAL_STATEMENTS_PATCH_CALLS.time()
+    def setSourceState(self, source: SourceUri, state: Optional[URIRef]):
+        """
+        add a patch to the COLLECTOR graph about the state of this
+        source. state=None to remove the source.
+        """
+        oldState = self._sourceState.get(source, None)
+        if state == oldState:
+            return
+        log.info('source state %s -> %s', source, state)
+        if oldState is None:
+            self._sourceState[source] = state
+            self.applyPatch(COLLECTOR, Patch(addQuads=[
+                (COLLECTOR, ROOM['source'], source, COLLECTOR),
+                (source, ROOM['state'], state, COLLECTOR),
+            ]))
+        elif state is None:
+            del self._sourceState[source]
+            self.applyPatch(COLLECTOR, Patch(delQuads=[
+                (COLLECTOR, ROOM['source'], source, COLLECTOR),
+                (source, ROOM['state'], oldState, COLLECTOR),
+            ]))
+        else:
+            self._sourceState[source] = state
+            self.applyPatch(COLLECTOR, Patch(addQuads=[
+                (source, ROOM['state'], state, COLLECTOR),
+            ], delQuads=[
+                (source, ROOM['state'], oldState, COLLECTOR),
+            ]))
+
+
+class PostDeleter(object):
+
+    def __init__(self, statements: StatementTable):
+        self.statements = statements
+
+    def __enter__(self):
+        self._garbage: List[Statement] = []
+        return self
+
+    def add(self, stmt: Statement):
+        self._garbage.append(stmt)
+
+    def __exit__(self, type, value, traceback):
+        if type is not None:
+            raise NotImplementedError()
+        for stmt in self._garbage:
+            del self.statements[stmt]
+
+
+class ActiveStatements(object):
+
+    def __init__(self):
+        # This table holds statements asserted by any of our sources
+        # plus local statements that we introduce (source is
+        # http://bigasterisk.com/sse_collector/).
+        self.table: StatementTable = collections.defaultdict(lambda: (set(), set()))
+
+    def state(self) -> Dict:
+        return {
+            'len': len(self.table),
+        }
+
+    def postDeleteStatements(self) -> PostDeleter:
+        return PostDeleter(self.table)
+
+    def pprintTable(self) -> None:
+        for i, (stmt, (sources, handlers)) in enumerate(sorted(self.table.items())):
+            print("%03d. %-80s from %s to %s" % (i, abbrevStmt(stmt), [abbrevTerm(s) for s in sources], handlers))
+
+    @MAKE_SYNC_PATCH_CALLS.time()
+    def makeSyncPatch(self, handler: PatchSink, sources: Set[SourceUri]):
+        # todo: this could run all handlers at once, which is how we
+        # use it anyway
+        adds = []
+        dels = []
+
+        with self.postDeleteStatements() as garbage:
+            for stmt, (stmtSources, handlers) in self.table.items():
+                belongsInHandler = not sources.isdisjoint(stmtSources)
+                handlerHasIt = handler in handlers
+                # log.debug("%s belong=%s has=%s",
+                #           abbrevStmt(stmt), belongsInHandler, handlerHasIt)
+                if belongsInHandler and not handlerHasIt:
+                    adds.append(stmt)
+                    handlers.add(handler)
+                elif not belongsInHandler and handlerHasIt:
+                    dels.append(stmt)
+                    handlers.remove(handler)
+                    if not handlers and not stmtSources:
+                        garbage.add(stmt)
+
+        return Patch(addQuads=adds, delQuads=dels)
+
+    def applySourcePatch(self, source: SourceUri, p: Patch):
+        for stmt in p.addQuads:
+            sourceUrls, handlers = self.table[stmt]
+            if source in sourceUrls:
+                raise ValueError("%s added stmt that it already had: %s" % (source, abbrevStmt(stmt)))
+            sourceUrls.add(source)
+
+        with self.postDeleteStatements() as garbage:
+            for stmt in p.delQuads:
+                sourceUrls, handlers = self.table[stmt]
+                if source not in sourceUrls:
+                    raise ValueError("%s deleting stmt that it didn't have: %s" % (source, abbrevStmt(stmt)))
+                sourceUrls.remove(source)
+                # this is rare, since some handler probably still has
+                # the stmt we're deleting, but it can happen e.g. when
+                # a handler was just deleted
+                if not sourceUrls and not handlers:
+                    garbage.add(stmt)
+
+    @REPLACE_SOURCE_STATEMENTS_CALLS.time()
+    def replaceSourceStatements(self, source: SourceUri, stmts: Sequence[Statement]):
+        log.debug('replaceSourceStatements with %s stmts', len(stmts))
+        newStmts = set(stmts)
+
+        with self.postDeleteStatements() as garbage:
+            for stmt, (sources, handlers) in self.table.items():
+                if source in sources:
+                    if stmt not in stmts:
+                        sources.remove(source)
+                        if not sources and not handlers:
+                            garbage.add(stmt)
+                else:
+                    if stmt in stmts:
+                        sources.add(source)
+                newStmts.discard(stmt)
+
+        self.applySourcePatch(source, Patch(addQuads=newStmts, delQuads=[]))
+
+    def discardHandler(self, handler: PatchSink):
+        with self.postDeleteStatements() as garbage:
+            for stmt, (sources, handlers) in self.table.items():
+                handlers.discard(handler)
+                if not sources and not handlers:
+                    garbage.add(stmt)
+
+    def discardSource(self, source: SourceUri):
+        with self.postDeleteStatements() as garbage:
+            for stmt, (sources, handlers) in self.table.items():
+                sources.discard(source)
+                if not sources and not handlers:
+                    garbage.add(stmt)