view rdfdb/service.py @ 109:bc643d61bb7c

format & comments
author drewp@bigasterisk.com
date Mon, 30 May 2022 22:55:20 -0700
parents 19100db34354
children 3733efe1fd19
line wrap: on
line source

import functools
import itertools
import logging
from pathlib import Path
from typing import Dict, Optional, cast

from rdflib import URIRef
from starlette.applications import Starlette
from starlette.endpoints import WebSocketEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocket
from starlette_exporter import PrometheusMiddleware, handle_metrics

from rdfdb.file_vs_uri import DirUriMap
from rdfdb.patch import Patch
from rdfdb.shared_graph import SharedGraph

# Logger shared by everything in this module; makeApp() sets its level.
log = logging.getLogger('rdfdb')

# Module-wide counter that numbers websocket connections (WS0, WS1, ...).
_wsClientSerial = itertools.count(0)


class SyncedGraphSocket(WebSocketEndpoint):
    """
    Websocket endpoint that keeps one client in sync with the graph.

    Outbound: patches are sent to the client (starting with a client
    who has 0 statements) so its copy tracks the graph.

    Inbound: patches from the client are accepted, on the assumption
    that the client has already applied them to its local graph.

    A disconnect means 'out of sync'. Either the client thinks it is
    out of sync and wants to start over, or we can't apply a patch
    correctly, therefore we disconnect to make the client start over.

    This socket may also carry some special messages meant for the
    rdfdb web UI, e.g. about who is connected, etc.
    """
    encoding = 'text'

    def __init__(self, db: SharedGraph, *a, **kw):
        """Bind the endpoint to `db` and give this connection a fresh id.

        Remaining positional/keyword arguments go to WebSocketEndpoint
        unchanged.
        """
        super().__init__(*a, **kw)
        self.db = db
        # NOTE(review): original author marked this id as possibly unneeded.
        serial = next(_wsClientSerial)
        self.connectionId = f'WS{serial}'
        log.info(f'ws connection init {self.connectionId}')

    def __repr__(self):
        """Short tag for log lines, carrying the connection id."""
        return f'<WebSocket {self.connectionId}>'

    async def on_connect(self, websocket: WebSocket):
        """Accept the socket, announce its id, and register it with the graph."""
        await websocket.accept()
        hello = {'connectedAs': self.connectionId}
        await websocket.send_json(hello)
        log.info(f"new ws client {self.connectionId}")

        # The graph calls this back with each Patch destined for this client.
        deliver = functools.partial(self._onPatch, websocket)
        await self.db.addClient(self.connectionId, deliver)

    async def _onPatch(self, websocket: WebSocket, p: Patch):
        """Forward one patch to the client as its JSON text form."""
        payload = p.makeJsonRepr()
        await websocket.send_text(payload)

    async def on_receive(self, websocket: WebSocket, data: str):
        """Answer a 'PING' keepalive, or apply `data` as a JSON patch.

        If the graph rejects the patch with ValueError, the socket is
        closed.
        """
        if data == 'PING':
            await websocket.send_text('PONG')
        else:
            # Only the first 64 characters are logged.
            log.debug("%r sends patch to us: %s", self, data[:64])
            incoming = Patch(jsonRepr=data)
            try:
                await self.db.patch(incoming, sender=self.connectionId)
            except ValueError as err:
                log.warning(f'patch from {self!r} did not apply: {err!r}')
                # here we should disconnect that client and make them reset
                await websocket.close()

    async def on_disconnect(self, websocket, close_code):
        """Log the close code and tell the graph this client is gone."""
        log.info("bye ws client %r: %s", self.connectionId, close_code)
        self.db.clientDisconnected(self.connectionId)


def get_graph(db: SharedGraph, request: Request) -> Response:
    """Return the whole graph serialized as text.

    The serializer is picked by comparing the request's Accept header
    to two exact strings: 'text/plain' selects 'nt' and
    'application/n-quads' selects 'nquads'. Any other value, including
    a missing header, selects 'n3'. The result is always wrapped in a
    PlainTextResponse.
    """
    serializer_for_accept = {
        'text/plain': 'nt',
        'application/n-quads': 'nquads',
    }
    accept = request.headers.get('accept', '')
    chosen = serializer_for_accept.get(accept, 'n3')
    return PlainTextResponse(db.graph.serialize(format=chosen))


async def post_prefixes(db: SharedGraph, request: Request) -> Response:
    """Register prefixes suggested by a client for one context.

    The request body is read as JSON with a 'ctx' entry and a
    'prefixes' mapping. 'ctx' and every value of 'prefixes' are wrapped
    in URIRef before being passed to db.addPrefixes. Responds with an
    empty 204.
    """
    body = await request.json()
    context = URIRef(body['ctx'])
    prefix_map = {cast(str, name): URIRef(uri) for name, uri in body['prefixes'].items()}
    db.addPrefixes(ctx=context, prefixes=prefix_map)
    return Response(status_code=204)


def makeApp(dirUriMap: Optional[DirUriMap] = None, prefixes: Optional[Dict[str, URIRef]] = None) -> Starlette:
    """Build the Starlette application that serves one SharedGraph.

    Args:
        dirUriMap: passed to SharedGraph as `dirUriMap`. When None, a
            single mapping from 'data/' to 'http://example.com/data/'
            is used.
        prefixes: passed to SharedGraph as `prefixes`. When None, the
            rdf, rdfs and xsd prefixes are used.

    Returns:
        The app, with these routes:
          /graph        (GET)       full graph dump, see get_graph
          /syncedGraph  (websocket) patch sync, see SyncedGraphSocket
          /prefixes     (POST)      prefix suggestions, see post_prefixes
          /metrics                  Prometheus metrics
    """
    log.info('makeApp start')

    if dirUriMap is None:
        dirUriMap = {Path('data/'): URIRef('http://example.com/data/')}
    if prefixes is None:
        prefixes = {
            'rdf': URIRef('http://www.w3.org/1999/02/22-rdf-syntax-ns#'),
            'rdfs': URIRef('http://www.w3.org/2000/01/rdf-schema#'),
            'xsd': URIRef('http://www.w3.org/2001/XMLSchema#'),
        }

    # Hand-edited toggle: the constant 1 currently selects DEBUG.
    log.setLevel(logging.DEBUG if 1 else logging.INFO)

    log.info('setup watches')
    db = SharedGraph(dirUriMap=dirUriMap, prefixes=prefixes)

    app = Starlette(
        debug=True,
        routes=[
            Route('/graph', functools.partial(get_graph, db)),
            WebSocketRoute('/syncedGraph', functools.partial(SyncedGraphSocket, db)),
            # Route() only allows GET (and HEAD) for function endpoints
            # unless `methods` is given. post_prefixes reads a JSON
            # request body, so it has to be registered for POST or
            # clients get 405 Method Not Allowed.
            Route('/prefixes', functools.partial(post_prefixes, db), methods=['POST']),
        ],
    )

    app.add_middleware(PrometheusMiddleware)
    app.add_route("/metrics", handle_metrics)

    return app