view rdfdb/service.py @ 130:d195a5f50137

rename isEmpty to match js
author drewp@bigasterisk.com
date Mon, 29 May 2023 16:14:17 -0700
parents a71e4272d808
children 8fa6a47521d7
line wrap: on
line source

import functools
import itertools
import logging
from pathlib import Path
from typing import Dict, Optional, cast

from rdflib import URIRef
from starlette.applications import Starlette
from starlette.endpoints import WebSocketEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.routing import Route, WebSocketRoute
from starlette.websockets import WebSocket
from starlette_exporter import PrometheusMiddleware, handle_metrics

from rdfdb.file_vs_uri import DirUriMap
from rdfdb.patch import Patch
from rdfdb.shared_graph import SharedGraph

# Module-wide logger; the name 'rdfdb.net' is fixed rather than derived from __name__.
log = logging.getLogger('rdfdb.net')

# Process-wide counter (starting at 0) that SyncedGraphSocket.__init__ draws from
# to build its 'WS<n>' connection id.
_wsClientSerial = itertools.count(0)


class SyncedGraphSocket(WebSocketEndpoint):
    """
    Websocket endpoint that keeps one client's graph in sync with ours.

    Outbound: the client is assumed to start with 0 statements, and we
    stream patches to it so its copy follows the shared graph.

    Inbound: patches received from the client are applied to the shared
    graph; the client is assumed to have applied them locally already.

    A disconnect means 'out of sync'. Either the client decided it is
    out of sync and wants to start over, or we could not apply one of
    its patches and closed the socket to force it to start over.

    The same socket may also carry special messages meant for the
    rdfdb web UI, e.g. about who is connected.
    """
    encoding = "text"

    def __init__(self, db: SharedGraph, *a, **kw):
        """Remember the shared graph and allocate a 'WS<n>' id for this connection.

        Remaining positional/keyword arguments are passed through to
        WebSocketEndpoint unchanged.
        """
        super().__init__(*a, **kw)
        self.db = db
        self.connectionId = f'WS{next(_wsClientSerial)}'  # unneeded?
        log.info(f'ws connection init {self.connectionId}')

    def __repr__(self):
        return f'<WebSocket {self.connectionId}>'

    async def on_connect(self, websocket: WebSocket):
        """Accept the socket, announce its id to the client, then register it with the db."""
        await websocket.accept()
        greeting = {'connectedAs': self.connectionId}
        await websocket.send_json(greeting)
        log.info(f"new ws client {self.connectionId}")

        # The db only hands us a Patch, so bind this connection's socket up front.
        deliver = functools.partial(self._onPatch, websocket)
        await self.db.addClient(self.connectionId, deliver)

    async def _onPatch(self, websocket: WebSocket, p: Patch):
        """Forward patch `p` to the client as JSON text; a RuntimeError is logged, not raised."""
        try:
            payload = p.makeJsonRepr()
            await websocket.send_text(payload)
        except RuntimeError:
            # likely WS disconnect
            log.warning("onPatch failed- hope the client calls back")

    async def on_receive(self, websocket: WebSocket, data: str):
        """Handle one text frame: answer 'PING' with 'PONG', otherwise treat it as a JSON patch.

        Empty patches are dropped. If the db rejects the patch with
        ValueError, the socket is closed so the client has to resync.
        """
        if data == 'PING':
            await websocket.send_text('PONG')
            return

        log.debug("%r sends patch to us: %s", self, data[:64])
        incoming = Patch(jsonRepr=data)

        # this is very important- I caught clients piling up dozens of retries, causing
        # us to rewrite files then notice them, etc. Problem may not be fully solved.
        if incoming.isEmpty():
            return

        try:
            await self.db.patch(incoming, sender=self.connectionId)
        except ValueError as e:
            log.warning(f'patch from {self!r} did not apply: {e!r}')
            # here we should disconnect that client and make them reset
            await websocket.close()

    async def on_disconnect(self, websocket, close_code):
        """Log the close code and tell the db that this connection is gone."""
        log.info("bye ws client %r: close_code=%s", self.connectionId, close_code)
        self.db.clientDisconnected(self.connectionId)


def get_graph(db: SharedGraph, request: Request) -> Response:
    """Return the whole graph serialized as plain text.

    The serialization is chosen by an exact comparison against the Accept
    header value (no parsing of media-type lists or q-values):
    'text/plain' selects 'nt', 'application/n-quads' selects 'nquads',
    and anything else, including a missing header, selects 'n3'.
    """
    serializationByAccept = {
        'text/plain': 'nt',
        'application/n-quads': 'nquads',
    }
    accept = request.headers.get('accept', '')
    serialization = serializationByAccept.get(accept, 'n3')
    return PlainTextResponse(db.graph.serialize(format=serialization))


async def post_prefixes(db: SharedGraph, request: Request) -> Response:
    """Pass prefix suggestions from the JSON request body to the db.

    The body must be an object with a 'ctx' entry (wrapped in URIRef) and a
    'prefixes' mapping of prefix name to namespace (each value wrapped in
    URIRef). Replies with an empty 204 response.
    """
    suggestion = await request.json()
    context = URIRef(suggestion['ctx'])
    namespaces = {cast(str, name): URIRef(uri) for name, uri in suggestion['prefixes'].items()}
    db.addPrefixes(ctx=context, prefixes=namespaces)
    return Response(status_code=204)


def makeApp(dirUriMap: Optional[DirUriMap] = None, prefixes: Optional[Dict[str, URIRef]] = None) -> Starlette:
    """Build the Starlette app that serves one SharedGraph.

    Args:
        dirUriMap: passed to SharedGraph as `dirUriMap`. When None, a single
            entry mapping Path('data/') to http://example.com/data/ is used.
        prefixes: passed to SharedGraph as `prefixes`. When None, the rdf,
            rdfs and xsd namespaces are used.

    Returns:
        An app with these routes:
          /graph        GET   serialized graph (get_graph)
          /syncedGraph  WS    patch sync socket (SyncedGraphSocket)
          /prefixes     POST  prefix suggestions (post_prefixes)
          /metrics      GET   prometheus metrics
    """
    log.info('makeApp start')

    if dirUriMap is None:
        dirUriMap = {Path('data/'): URIRef('http://example.com/data/')}
    if prefixes is None:
        prefixes = {
            'rdf': URIRef('http://www.w3.org/1999/02/22-rdf-syntax-ns#'),
            'rdfs': URIRef('http://www.w3.org/2000/01/rdf-schema#'),
            'xsd': URIRef('http://www.w3.org/2001/XMLSchema#'),
        }

    db = SharedGraph(dirUriMap=dirUriMap, prefixes=prefixes)

    app = Starlette(
        debug=True,
        routes=[
            Route('/graph', functools.partial(get_graph, db)),
            WebSocketRoute('/syncedGraph', functools.partial(SyncedGraphSocket, db)),
            # Starlette registers function endpoints for GET (plus HEAD) only
            # unless told otherwise, so without an explicit method list a POST
            # to this body-reading handler is answered with 405.
            Route('/prefixes', functools.partial(post_prefixes, db), methods=['POST']),
        ],
    )

    app.add_middleware(PrometheusMiddleware)
    app.add_route("/metrics", handle_metrics)

    return app