changeset 4:dc4f852d0d70

reformat and add some types
author drewp@bigasterisk.com
date Wed, 24 Nov 2021 19:47:35 -0800
parents 703adc4f78b1
children 506f6941a38c
files patchablegraph.py patchsource.py
diffstat 2 files changed, 91 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/patchablegraph.py	Wed Nov 24 19:47:06 2021 -0800
+++ b/patchablegraph.py	Wed Nov 24 19:47:35 2021 -0800
@@ -20,18 +20,23 @@
 differences between RDF graphs
 
 """
-import json, logging, itertools, html
+import html
+import itertools
+import json
+import logging
+from typing import Callable, List, Optional, cast
 
+import cyclone.sse
+import cyclone.web
+from cycloneerr import PrettyErrorHandler
 from prometheus_client import Counter, Gauge, Summary
 from rdfdb.grapheditapi import GraphEditApi
+from rdfdb.patch import Patch
+from rdfdb.rdflibpatch import inGraph, patchQuads
 from rdflib import ConjunctiveGraph
 from rdflib.namespace import NamespaceManager
 from rdflib.parser import StringInputSource
 from rdflib.plugins.serializers.jsonld import from_rdf
-import cyclone.sse
-from cycloneerr import PrettyErrorHandler
-from rdfdb.patch import Patch
-from rdfdb.rdflibpatch import patchQuads, inGraph
 
 log = logging.getLogger('patchablegraph')
 
@@ -46,20 +51,24 @@
 def _graphFromQuads2(q):
     g = ConjunctiveGraph()
     #g.addN(q) # no effect on nquad output
-    for s,p,o,c in q:
-        g.get_context(c).add((s,p,o)) # kind of works with broken rdflib nquad serializer code
+    for s, p, o, c in q:
+        g.get_context(c).add((s, p, o))  # kind of works with broken rdflib nquad serializer code
         #g.store.add((s,p,o), c) # no effect on nquad output
     return g
 
-def jsonFromPatch(p):
-    return json.dumps({'patch': {
-        'adds': from_rdf(_graphFromQuads2(p.addQuads)),
-        'deletes': from_rdf(_graphFromQuads2(p.delQuads)),
-    }})
-patchAsJson = jsonFromPatch # deprecated name
+
+def jsonFromPatch(p: Patch) -> str:
+    return json.dumps(
+        {'patch': {
+            'adds': from_rdf(_graphFromQuads2(p.addQuads)),
+            'deletes': from_rdf(_graphFromQuads2(p.delQuads)),
+        }})
 
 
-def patchFromJson(j):
+patchAsJson = jsonFromPatch  # deprecated name
+
+
+def patchFromJson(j: str) -> Patch:
     body = json.loads(j)['patch']
     a = ConjunctiveGraph()
     a.parse(StringInputSource(json.dumps(body['adds']).encode('utf8')), format='json-ld')
@@ -67,28 +76,34 @@
     d.parse(StringInputSource(json.dumps(body['deletes']).encode('utf8')), format='json-ld')
     return Patch(addGraph=a, delGraph=d)
 
-def graphAsJson(g):
+
+def graphAsJson(g: ConjunctiveGraph) -> str:
     # This is not the same as g.serialize(format='json-ld')! That
     # version omits literal datatypes.
     return json.dumps(from_rdf(g))
 
+
 _graphsInProcess = itertools.count()
+
+
 class PatchableGraph(GraphEditApi):
     """
     Master graph that you modify with self.patch, and we send the
     updates to all current listeners.
     """
-    def __init__(self):
-        self._graph = ConjunctiveGraph()
-        self._observers = []
-        scales.init(self, '/patchableGraph%s' % next(_graphsInProcess))
 
-    _serialize = scales.PmfStat('serialize')
-    def serialize(self, *arg, **kw):
-        with self._serialize.time():
-            return self._graph.serialize(*arg, **kw)
+    def __init__(self, label: Optional[str] = None):
+        self._graph = ConjunctiveGraph()
+        self._observers: List[Callable[[str], None]] = []
+        if label is None:
+            label = f'patchableGraph{next(_graphsInProcess)}'
+        self.label = label
 
-    def patch(self, p):
+    def serialize(self, *arg, **kw) -> bytes:
+        with SERIALIZE_CALLS.labels(graph=self.label).time():
+            return cast(bytes, self._graph.serialize(*arg, **kw))
+
+    def patch(self, p: Patch):
         with PATCH_CALLS.labels(graph=self.label).time():
             # assuming no stmt is both in p.addQuads and p.delQuads.
             dels = set([q for q in p.delQuads if inGraph(q, self._graph)])
@@ -96,30 +111,27 @@
             minimizedP = Patch(addQuads=adds, delQuads=dels)
             if minimizedP.isNoop():
                 return
-            patchQuads(self._graph,
-                       deleteQuads=dels,
-                       addQuads=adds,
-                       perfect=False) # true?
+            patchQuads(self._graph, deleteQuads=dels, addQuads=adds, perfect=False)  # true?
             for ob in self._observers:
                 ob(patchAsJson(p))
             STATEMENT_COUNT.labels(graph=self.label).set(len(self._graph))
 
-    def asJsonLd(self):
+    def asJsonLd(self) -> str:
         return graphAsJson(self._graph)
 
-    def addObserver(self, onPatch):
+    def addObserver(self, onPatch: Callable[[str], None]):
         self._observers.append(onPatch)
         OBSERVERS_CURRENT.labels(graph=self.label).set(len(self._observers))
         OBSERVERS_ADDED.labels(graph=self.label).inc()
 
-    def removeObserver(self, onPatch):
+    def removeObserver(self, onPatch: Callable[[str], None]):
         try:
             self._observers.remove(onPatch)
         except ValueError:
             pass
         self._currentObservers = len(self._observers)
 
-    def setToGraph(self, newGraph):
+    def setToGraph(self, newGraph: ConjunctiveGraph):
         self.patch(Patch.fromDiff(self._graph, newGraph))
 
 
@@ -127,6 +139,7 @@
 
 
 class CycloneGraphHandler(PrettyErrorHandler, cyclone.web.RequestHandler):
+
     def initialize(self, masterGraph: PatchableGraph):
         self.masterGraph = masterGraph
 
@@ -227,6 +240,7 @@
     response to send 'x-accel-buffering: no', per
     http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_buffering
     """
+
     def __init__(self, application, request, masterGraph):
         cyclone.sse.SSEHandler.__init__(self, application, request)
         self.masterGraph = masterGraph
--- a/patchsource.py	Wed Nov 24 19:47:06 2021 -0800
+++ b/patchsource.py	Wed Nov 24 19:47:35 2021 -0800
@@ -1,42 +1,54 @@
-import logging, time
+import logging
+import time
 import traceback
-from rdflib import ConjunctiveGraph
-from rdflib.parser import StringInputSource
-from twisted.internet import reactor, defer
+from typing import Dict, Optional, Protocol
 
 from rdfdb.patch import Patch
+from rdflib import ConjunctiveGraph, URIRef
+from rdflib.parser import StringInputSource
+from twisted.internet import defer, reactor
 from twisted_sse.eventsource import EventSource
 
 from .patchablegraph import patchFromJson
 
 log = logging.getLogger('fetch')
 
+
+class _Listener(Protocol):
+    def __call__(
+            self,
+            p: Patch,
+            fullGraph: bool,  # True if the patch is the initial full graph.
+    ) -> None:
+        ...
+
+
 class PatchSource(object):
     """wrap EventSource so it emits Patch objects and has an explicit stop method."""
-    def __init__(self, url, agent):
-        self.url = str(url)
+    def __init__(self, url: str, agent: str):
+        self.url = url
 
         # add callbacks to these to learn if we failed to connect
         # (approximately) or if the connection was unexpectedly lost
         self.connectionFailed = defer.Deferred()
         self.connectionLost = defer.Deferred()
-        
+
         self._listeners = set()
         log.info('start read from %s', url)
         self._startReadTime = time.time()
-        self._patchesReceived = 0 # including fullgraph
+        self._patchesReceived = 0  # including fullgraph
         # note: fullGraphReceived isn't guaranteed- the stream could
         # start with patches
         self._fullGraphReceived = False
-        self._eventSource = EventSource(url.toPython().encode('utf8'),
-                                        userAgent=agent)
+        self._eventSource: Optional[EventSource] = EventSource(
+            url.encode('utf8'), userAgent=agent)
 
         self._eventSource.addEventListener(b'fullGraph', self._onFullGraph)
         self._eventSource.addEventListener(b'patch', self._onPatch)
         self._eventSource.onerror(self._onError)
         self._eventSource.onConnectionLost = self._onDisconnect
 
-    def state(self):
+    def state(self) -> Dict:
         return {
             'url': self.url,
             'fullGraphReceived': self._fullGraphReceived,
@@ -48,8 +60,8 @@
             },
             'closed': self._eventSource is None,
         }
-        
-    def addPatchListener(self, func):
+
+    def addPatchListener(self, func: _Listener):
         """
         func(patch, fullGraph=[true if the patch is the initial fullgraph])
         """
@@ -57,10 +69,8 @@
 
     def stop(self):
         log.info('stop read from %s', self.url)
-        try:
-            self._eventSource.protocol.stopProducing() # needed?
-        except AttributeError:
-            pass
+        if self._eventSource is not None:
+            self._eventSource.protocol.stopProducing()  # needed?
         self._eventSource = None
 
     def _onDisconnect(self, reason):
@@ -75,7 +85,7 @@
         else:
             self.connectionLost.callback(msg)
 
-    def _onFullGraph(self, message):
+    def _onFullGraph(self, message: str):
         try:
             g = ConjunctiveGraph()
             g.parse(StringInputSource(message), format='json-ld')
@@ -87,8 +97,8 @@
         self._fullGraphReceived = True
         self._fullGraphTime = time.time()
         self._patchesReceived += 1
-            
-    def _onPatch(self, message):
+
+    def _onPatch(self, message: str):
         try:
             p = patchFromJson(message)
             self._sendPatch(p, fullGraph=False)
@@ -98,16 +108,17 @@
         self._latestPatchTime = time.time()
         self._patchesReceived += 1
 
-    def _sendPatch(self, p, fullGraph):
-        log.debug('PatchSource %s received patch %s (fullGraph=%s)',
-                  self.url, p.shortSummary(), fullGraph)
+    def _sendPatch(self, p: Patch, fullGraph: bool):
+        log.debug('PatchSource %s received patch %s (fullGraph=%s)', self.url,
+                  p.shortSummary(), fullGraph)
         for lis in self._listeners:
             lis(p, fullGraph=fullGraph)
-        
+
     def __del__(self):
         if self._eventSource:
             raise ValueError("PatchSource wasn't stopped before del")
 
+
 class ReconnectingPatchSource(object):
     """
     PatchSource api, but auto-reconnects internally and takes listener
@@ -116,8 +127,11 @@
 
     todo: generate connection stmts in here
     """
-    def __init__(self, url, listener, reconnectSecs=60, agent='unset'):
-        # type: (str, Any, Any, str)
+    def __init__(self,
+                 url: str,
+                 listener: _Listener,
+                 reconnectSecs=60,
+                 agent='unset'):
         self.url = url
         self._stopped = False
         self._listener = listener
@@ -131,7 +145,7 @@
         self._ps = PatchSource(self.url, agent=self.agent)
         self._ps.addPatchListener(self._onPatch)
         self._ps.connectionFailed.addCallback(self._onConnectionFailed)
-        self._ps.connectionLost.addCallback(self._onConnectionLost)        
+        self._ps.connectionLost.addCallback(self._onConnectionLost)
 
     def _onPatch(self, p, fullGraph):
         self._listener(p, fullGraph=fullGraph)
@@ -140,14 +154,13 @@
         return {
             'reconnectedPatchSource': self._ps.state(),
         }
-        
+
     def stop(self):
         self._stopped = True
         self._ps.stop()
-        
+
     def _onConnectionFailed(self, arg):
         reactor.callLater(self.reconnectSecs, self._reconnect)
-        
+
     def _onConnectionLost(self, arg):
-        reactor.callLater(self.reconnectSecs, self._reconnect)        
- 
+        reactor.callLater(self.reconnectSecs, self._reconnect)