changeset 105:4681ea3fcdf6

port to starlette asyncio
author drewp@bigasterisk.com
date Mon, 30 May 2022 20:38:53 -0700
parents d1fd6aeffb27
children b0f922c8c728
files rdfdb/service.py rdfdb/syncedgraph/currentstategraphapi.py rdfdb/syncedgraph/syncedgraph_base.py
diffstat 3 files changed, 112 insertions(+), 185 deletions(-) [+]
line wrap: on
line diff
--- a/rdfdb/service.py	Mon May 30 20:38:06 2022 -0700
+++ b/rdfdb/service.py	Mon May 30 20:38:53 2022 -0700
@@ -1,62 +1,29 @@
+import functools
 import itertools
-import json
 import logging
-import optparse
-import os
-import sys
 from pathlib import Path
-from typing import Dict, List, Optional, cast
-
-import cyclone.web
-import cyclone.websocket
-import twisted.internet.error
-import twisted.internet.reactor
-from rdflib import ConjunctiveGraph, Graph, URIRef
-from twisted.internet.inotify import IN_CREATE, INotify
-from twisted.internet.interfaces import IReactorCore
-from twisted.python.failure import Failure
-from twisted.python.filepath import FilePath
+from typing import Dict, Optional, cast
 
-from rdfdb.file_vs_uri import (DirUriMap, correctToTopdirPrefix, fileForUri, uriFromFile)
-from rdfdb.graphfile import GetSubgraph, GraphFile, PatchCb
-from rdfdb.patch import ALLSTMTS, Patch
-from rdfdb.rdflibpatch import patchQuads
-from prometheus_client import Gauge
-from prometheus_client.exposition import generate_latest
-from prometheus_client.registry import REGISTRY
-reactor = cast(IReactorCore, twisted.internet.reactor)
+from rdflib import URIRef
+from starlette.applications import Starlette
+from starlette.endpoints import WebSocketEndpoint
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response
+from starlette.routing import Route, WebSocketRoute
+from starlette.websockets import WebSocket
+from starlette_exporter import PrometheusMiddleware, handle_metrics
 
-STAT_CLIENTS = Gauge('clients', 'connected clients')
-# gatherProcessStats()
-# stats = scales.collection(
-#     '/webServer',
-#     scales.IntStat('liveClients'),
-#     scales.PmfStat('setAttr'),
-# )
-# graphStats = scales.collection(
-#     '/graph',
-#     scales.IntStat('statements'),
-#     scales.RecentFpsStat('patchFps'),
-# )
-# fileStats = scales.collection(
-#     '/file',
-#     scales.IntStat('mappedGraphFiles'),
-# )
+from rdfdb.file_vs_uri import DirUriMap
+from rdfdb.patch import Patch
+
+from .shared_graph import Db
 
 log = logging.getLogger('rdfdb')
 
-
-class WebsocketDisconnect(ValueError):
-    pass
-
-
-CtxPrefixes = Dict[Optional[URIRef], Dict[str, URIRef]]
-
-
 _wsClientSerial = itertools.count(0)
 
 
-class WebsocketClient(cyclone.websocket.WebSocketHandler):
+class SyncedGraphSocket(WebSocketEndpoint):
     """
     Send patches to the client (starting with a client who has 0
     statements) to keep it in sync with the graph.
@@ -71,80 +38,66 @@
     This socket may also carry some special messages meant for the
     rdfdb web UI, e.g. about who is connected, etc.
     """
-    connectionId: str
-
-    def connectionMade(self, *args, **kwargs) -> None:
-        self.connectionId = f'WS{next(_wsClientSerial)}'
-
-        self.sendMessage(json.dumps({'connectedAs': self.connectionId}))
-        log.info("new ws client %r", self)
-        self.settings.db.addClient(self)
+    encoding = "text"
 
-    def connectionLost(self, reason):
-        log.info("bye ws client %r: %s", self, reason)
-        self.settings.db.clientErrored(Failure(WebsocketDisconnect(reason)), self)
-
-    def messageReceived(self, message: bytes):
-        if message == b'PING':
-            self.sendMessage('PONG')
-            return
-        log.debug("got message from %r: %s", self, message[:32])
-        p = Patch(jsonRepr=message.decode('utf8'))
-        self.settings.db.patch(p, sender=self.connectionId)
-
-    def sendPatch(self, p: Patch):
-        self.sendMessage(p.makeJsonRepr())
+    def __init__(self, db: Db, *a, **kw):
+        WebSocketEndpoint.__init__(self, *a, **kw)
+        self.db = db
+        self.connectionId = f'WS{next(_wsClientSerial)}'  # unneeded?
+        log.info(f'ws connection init {self.connectionId}')
 
     def __repr__(self):
-        return f"<SyncedGraph client {self.connectionId}>"
+        return f'<WebSocket {self.connectionId}>'
+
+    async def on_connect(self, websocket: WebSocket):
+        await websocket.accept()
+        await websocket.send_json({'connectedAs': self.connectionId})
+        log.info(f"new ws client {self.connectionId}")
+
+        await self.db.addClient(self.connectionId, functools.partial(self._onPatch,websocket))
+
+    async def _onPatch(self, websocket: WebSocket, p: Patch):
+        await websocket.send_text(p.makeJsonRepr())
 
+    async def on_receive(self, websocket: WebSocket, data: str):
+        if data == 'PING':
+            await websocket.send_text('PONG')
+            return
+        log.debug("%r sends patch to us: %s", self, data[:64])
+        p = Patch(jsonRepr=data)
+        try:
+            await self.db.patch(p, sender=self.connectionId)
+        except ValueError as e:
+            log.warning(f'patch from {self!r} did not apply: {e!r}')
+            # here we should disconnect that client and make them reset
+            await websocket.close()
 
+    async def on_disconnect(self, websocket, close_code):
+        log.info("bye ws client %r: %s", self.connectionId, close_code)
+        self.db.clientDisconnected(self.connectionId)
 
 
-class GraphResource(cyclone.web.RequestHandler):
-
-    def get(self):
-        accept = self.request.headers.get('accept', '')
+def get_graph(db: Db, request: Request) -> Response:
+    accept = request.headers.get('accept', '')
+    if accept == 'text/plain':
+        format = 'nt'
+    elif accept == 'application/n-quads':
+        format = 'nquads'
+    else:
         format = 'n3'
-        if accept == 'text/plain':
-            format = 'nt'
-        elif accept == 'application/n-quads':
-            format = 'nquads'
-        elif accept == 'pickle':
-            # don't use this; it's just for speed comparison
-            import pickle as pickle
-            pickle.dump(self.settings.db.graph, self, protocol=2)
-            return
-        elif accept == 'msgpack':
-            self.write(repr(self.settings.db.graph.__getstate__))
-            return
-        self.write(self.settings.db.graph.serialize(format=format))
+    return PlainTextResponse(db.graph.serialize(format=format))
 
 
-class Prefixes(cyclone.web.RequestHandler):
-
-    def post(self):
-        suggestion = json.loads(self.request.body)
-        addlPrefixes = self.settings.db.watchedFiles.addlPrefixes
-        addlPrefixes.setdefault(URIRef(suggestion['ctx']), {}).update(suggestion['prefixes'])
+async def post_prefixes(db: Db, request: Request) -> Response:
+    suggestion = await request.json()
+    db.addPrefixes(
+        ctx=URIRef(suggestion['ctx']),  #
+        prefixes=dict((cast(str, k), URIRef(v)) for k, v in suggestion['prefixes'].items()))
+    return Response(status_code=204)
 
 
-class NoExts(cyclone.web.StaticFileHandler):
-    # .html pages can be get() without .html on them
-    def get(self, path, *args, **kw):
-        if path and '.' not in path:
-            path = path + ".html"
-        cyclone.web.StaticFileHandler.get(self, path, *args, **kw)
-
-
-class Metrics(cyclone.web.RequestHandler):
-
-    def get(self):
-        self.add_header('content-type', 'text/plain')
-        self.write(generate_latest(REGISTRY))
-
-
-def main(dirUriMap: Optional[DirUriMap] = None, prefixes: Optional[Dict[str, URIRef]] = None, port=9999):
+def makeApp(dirUriMap: Optional[DirUriMap] = None, prefixes: Optional[Dict[str, URIRef]] = None) -> Starlette:
+    log.info('makeApp start')
 
     if dirUriMap is None:
         dirUriMap = {Path('data/'): URIRef('http://example.com/data/')}
@@ -155,41 +108,21 @@
             'xsd': URIRef('http://www.w3.org/2001/XMLSchema#'),
         }
 
-    logging.basicConfig()
-    log = logging.getLogger()
-
-    parser = optparse.OptionParser()
-    parser.add_option("-v", "--verbose", action="store_true", help="logging.DEBUG")
-    (options, args) = parser.parse_args()
+    log.setLevel(logging.DEBUG if 1 else logging.INFO)
 
-    log.setLevel(logging.DEBUG if options.verbose else logging.INFO)
-
-    db = Db(dirUriMap=dirUriMap, addlPrefixes={None: prefixes})
-
-    from twisted.python import log as twlog
-    twlog.startLogging(sys.stdout)
+    log.info('setup watches')
+    db = Db(dirUriMap=dirUriMap, prefixes=prefixes)
 
-    reactor.listenTCP(
-        port,
-        cyclone.web.Application(
-            handlers=[
-                (r'/graph', GraphResource),
-                (r'/syncedGraph', WebsocketClient),
-                (r'/prefixes', Prefixes),
-                (r'/metrics', Metrics),
-                # (r'/stats/(.*)', StatsHandler, {
-                #     'serverName': 'rdfdb'
-                # }),
-                (r'/(.*)', NoExts, {
-                    "path": FilePath(__file__).sibling("web").path,
-                    "default_filename": "index.html"
-                }),
-            ],
-            debug=True,
-            db=db))
-    log.info("serving on %s" % port)
-    reactor.run()
+    app = Starlette(
+        debug=True,
+        routes=[
+            Route('/graph', functools.partial(get_graph, db)),
+            WebSocketRoute('/syncedGraph', functools.partial(SyncedGraphSocket, db)),
+            Route('/prefixes', functools.partial(post_prefixes, db)),
+        ],
+    )
 
+    app.add_middleware(PrometheusMiddleware)
+    app.add_route("/metrics", handle_metrics)
 
-if __name__ == '__main__':
-    main()
+    return app
--- a/rdfdb/syncedgraph/currentstategraphapi.py	Mon May 30 20:38:06 2022 -0700
+++ b/rdfdb/syncedgraph/currentstategraphapi.py	Mon May 30 20:38:53 2022 -0700
@@ -5,7 +5,7 @@
 from typing import Optional, Set, Tuple
 
 from rdfdb.rdflibpatch import contextsForStatement as rp_contextsForStatement
-from rdfdb.readonly_graph import ReadOnlyConjunctiveGraph
+from rdfdb.syncedgraph.readonly_graph import ReadOnlyConjunctiveGraph
 from rdfdb.syncedgraph.syncedgraph_base import SyncedGraphBase
 from rdflib import ConjunctiveGraph, URIRef
 from rdflib.term import Node
--- a/rdfdb/syncedgraph/syncedgraph_base.py	Mon May 30 20:38:06 2022 -0700
+++ b/rdfdb/syncedgraph/syncedgraph_base.py	Mon May 30 20:38:53 2022 -0700
@@ -91,10 +91,9 @@
         receiverHost is the hostname other nodes can use to talk to me
         """
         self.rdfdbRoot = rdfdbRoot
-        self.httpSession = aiohttp.ClientSession()
-        self._senderTask = asyncio.create_task(self._sender())
+        self._commTask = asyncio.create_task(self._communicate())
 
-        self._initiallySynced = asyncio.Future()
+        self.initiallySynced = asyncio.Future()
         self._graph = ConjunctiveGraph()
 
         # todo:
@@ -103,38 +102,42 @@
         # this needs more state to track if we're doing a resync (and
         # everything has to error or wait) or if we're live
 
-    async def _sender(self):
+    async def _communicate(self):
         while True:
-            with self.httpSession.ws_connect(self.rdfdbRoot.replace('http://', 'ws://') + 'syncedGraph') as ws:
-                async for msg in ws:
-                    log.info(f"server sent us {msg=}")
-                    # if msg.type == aiohttp.WSMsgType.TEXT:
-                    #     if msg.data == 'close cmd':
-                    #         await ws.close()
-                    #         break
-                    #     else:
-                    #         await ws.send_str(msg.data + '/answer')
-                    # elif msg.type == aiohttp.WSMsgType.ERROR:
-                    #     break
-            self.lostRdfdbConnection()
+            async with aiohttp.ClientSession() as sess:
+                async with sess.ws_connect(self.rdfdbRoot.replace('http://', 'ws://') + 'syncedGraph') as ws:
+                    self.ws = ws
+                    async for msg in ws:
+                        log.info(f"server sent us {repr(msg)[:200]}")
+                        try:
+                            self._onIncomingMsg(msg.data)
+                        except Exception:
+                            traceback.print_exc()
+                            raise
+
+            await self._lostRdfdbConnection()
             log.info("lost connection- retry")
             await asyncio.sleep(4)
 
-    async def init(self):
-        """return when we have the initial graph from server.
+    def _onIncomingMsg(self, body: str):
+        j = json.loads(body)
+        if 'connectedAs' in j:
+            self.connectionId = j['connectedAs']
+            log.info(f'rdfdb calls us {self.connectionId}')
+        elif 'patch' in j:
+            p = Patch(jsonRepr=body)  # todo: repeated parse
+            log.debug("received patch %s", p.shortSummary())
+            self._onPatchFromDb(p)
+        else:
+            log.warn('unknown msg from websocket: %s...', body[:32])
 
-        maybe this isn't really needed, as everything ought to be resilent
-        to the intial graph pouring in.
-        """
-        await self._initiallySynced
-
-    def lostRdfdbConnection(self) -> None:
+    async def _lostRdfdbConnection(self) -> None:
         self.isConnected = False
-        self.patch(Patch(delQuads=self._graph.quads()))
+        await self.patch(Patch(delQuads=self._graph.quads()))
         log.info(f'cleared graph to {len(self._graph)}')
         log.error('graph is not updating- you need to restart')
 
-    def resync(self):
+    async def _resync(self):
         """
         get the whole graph again from the server (e.g. we had a
         conflict while applying a patch and want to return to the
@@ -149,14 +152,8 @@
         UIs who want to show that we're doing a resync.
         """
         log.info('resync')
-        if self.currentClient:
-            self.currentClient.dropConnection()
+        await self.ws.close()
 
-    def _resyncGraph(self, response):
-        log.warn("new graph in")
-
-        if self.currentClient:
-            self.currentClient.dropConnection()
         # diff against old entire graph
         # broadcast that change
 
@@ -164,7 +161,7 @@
         # See AutoDepGraphApi
         pass
 
-    def patch(self, p: Patch) -> None:
+    async def patch(self, p: Patch) -> None:
         """send this patch to the server and apply it to our local
         graph and run handlers"""
 
@@ -185,12 +182,12 @@
             self._applyPatchLocally(p)
         except ValueError as e:
             log.error(e)
-            self.resync()
+            await self._resync()
             return
         log.debug('runDepsOnNewPatch')
         self.runDepsOnNewPatch(p)
         log.debug('sendPatch')
-        self.currentClient.sendPatch(p)
+        await self.ws.send_str(p.jsonRepr)
         log.debug('patch is done %s', debugKey)
 
     async def suggestPrefixes(self, ctx, prefixes):
@@ -206,7 +203,7 @@
         patchQuads(self._graph, p.delQuads, p.addQuads, perfect=True)
         log.debug("graph now has %s statements" % len(self._graph))
 
-    def onPatchFromDb(self, p):
+    def _onPatchFromDb(self, p: Patch):
         """
         central server has sent us a patch
         """
@@ -225,6 +222,3 @@
             # state since some dependencies may not have rerun
             traceback.print_exc()
             log.warn("some graph dependencies may not have completely run")
-
-        if not self._initiallySynced.done():
-            self._initiallySynced.set_result(None)