changeset 11:0bc06da6bf74

start ferry1 patch protocol
author drewp@bigasterisk.com
date Mon, 18 Mar 2024 16:42:21 -0700
parents 52e1bb1532f2
children ba73d8ba81dc
files examples/_run_server_child.py examples/serve_inline_graph_test.py pdm.lock pyproject.toml src/rdferry/patch/patch.py src/rdferry/patchablegraph.py src/rdferry/server.py
diffstat 7 files changed, 209 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/examples/_run_server_child.py	Sat Mar 16 16:02:23 2024 -0700
+++ b/examples/_run_server_child.py	Mon Mar 18 16:42:21 2024 -0700
@@ -1,8 +1,14 @@
 import asyncio
+import contextlib
+import logging
 from dataclasses import dataclass
+from datetime import timedelta
 from pathlib import Path
 
 import aiohttp
+from aiohttp_sse_client import client as sse_client
+
+log = logging.getLogger('chil')
 
 
 @dataclass
@@ -26,3 +32,19 @@
                 return await self._session.get(url, headers=headers)
             except aiohttp.ClientConnectorError:
                 await asyncio.sleep(0.05)
+
+    @contextlib.asynccontextmanager
+    async def eventSource(self, url: str):
+        async with sse_client.EventSource(
+                url, reconnection_time=timedelta(seconds=.05)) as es:
+            yield es
+
+
+async def assert_event_stream_starts_with(http_server, url, expected_events):
+    events_left = list(expected_events)
+    async with http_server.eventSource(url) as es:
+        async for event in es:
+            assert (event.message, event.data) == events_left.pop(0)
+            if not events_left:
+                break
+    assert not events_left, f'stream ended early; missing {events_left}'
--- a/examples/serve_inline_graph_test.py	Sat Mar 16 16:02:23 2024 -0700
+++ b/examples/serve_inline_graph_test.py	Mon Mar 18 16:42:21 2024 -0700
@@ -1,9 +1,11 @@
 from pathlib import Path
-
+import logging
 import pytest
 
-from examples._run_server_child import RunHttpServerChildProcess
+from examples._run_server_child import RunHttpServerChildProcess, assert_event_stream_starts_with
 
+log = logging.getLogger('test')
+logging.basicConfig(level=logging.INFO)
 server_path = Path('examples/serve_inline_graph.py')
 
 
@@ -34,9 +36,15 @@
 '''
 
 
-# @pytest.mark.asyncio
-# async def test_server_returns_events():
-#     async with RunHttpServerChildProcess(server_path) as http_server:
-#         response = await http_server.get('http://localhost:8005/g1/events')
-#         assert response.headers['content-type'] == 'x-sse-todo'
-#         assert (await response.text()) == 'clear event then add-patch event'
+@pytest.mark.asyncio
+async def test_server_returns_startup_events():
+    async with RunHttpServerChildProcess(server_path) as http_server:
+        await assert_event_stream_starts_with(
+            http_server,
+            'http://localhost:8005/g1/events',
+            expected_events=[
+                ('clear', 'ferry1'),
+                ('patch',
+                 '-\n+\n["http://example.com/greeting", "http://www.w3.org/2000/01/rdf-schema#label", "hello world", "http://www.w3.org/2001/XMLSchema#string", "", "http://example.com/process"]'
+                 ),
+            ])
--- a/pdm.lock	Sat Mar 16 16:02:23 2024 -0700
+++ b/pdm.lock	Mon Mar 18 16:42:21 2024 -0700
@@ -5,14 +5,14 @@
 groups = ["default", "dev"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:2dff41a22992283aea0d01ebc12301a09015f812422f1cb9be573613699a74f6"
+content_hash = "sha256:3ef7575dc740e12a788693e22dd6bef4cba46cbdc9ece6fb77b4b66881f1a589"
 
 [[package]]
 name = "aiohttp"
 version = "3.9.3"
 requires_python = ">=3.8"
 summary = "Async http client/server framework (asyncio)"
-groups = ["dev"]
+groups = ["default", "dev"]
 dependencies = [
     "aiosignal>=1.1.2",
     "attrs>=17.3.0",
@@ -55,11 +55,27 @@
 ]
 
 [[package]]
+name = "aiohttp-sse-client"
+version = "0.2.1"
+summary = "A Server-Sent Event python client base on aiohttp"
+groups = ["default"]
+dependencies = [
+    "aiohttp>=3",
+    "attrs",
+    "multidict",
+    "yarl",
+]
+files = [
+    {file = "aiohttp-sse-client-0.2.1.tar.gz", hash = "sha256:5004e29271624af586158dc7166cb0687a7a5997aab5b808f4b53400e1b72e3b"},
+    {file = "aiohttp_sse_client-0.2.1-py2.py3-none-any.whl", hash = "sha256:42c81ee9213e9fc8bc412b063bac3a813e02e75250c4c8049222234d41c9b024"},
+]
+
+[[package]]
 name = "aiosignal"
 version = "1.3.1"
 requires_python = ">=3.7"
 summary = "aiosignal: a list of registered asynchronous callbacks"
-groups = ["dev"]
+groups = ["default", "dev"]
 dependencies = [
     "frozenlist>=1.1.0",
 ]
@@ -88,7 +104,7 @@
 version = "23.2.0"
 requires_python = ">=3.7"
 summary = "Classes Without Boilerplate"
-groups = ["dev"]
+groups = ["default", "dev"]
 files = [
     {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
     {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
@@ -125,7 +141,7 @@
 version = "1.4.1"
 requires_python = ">=3.8"
 summary = "A list-like structure which implements collections.abc.MutableSequence"
-groups = ["dev"]
+groups = ["default", "dev"]
 files = [
     {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"},
     {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"},
@@ -212,7 +228,7 @@
 version = "6.0.5"
 requires_python = ">=3.7"
 summary = "multidict implementation"
-groups = ["dev"]
+groups = ["default", "dev"]
 files = [
     {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"},
     {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"},
@@ -375,6 +391,22 @@
 ]
 
 [[package]]
+name = "sse-starlette"
+version = "2.0.0"
+requires_python = ">=3.8"
+summary = "SSE plugin for Starlette"
+groups = ["default"]
+dependencies = [
+    "anyio",
+    "starlette",
+    "uvicorn",
+]
+files = [
+    {file = "sse_starlette-2.0.0-py3-none-any.whl", hash = "sha256:c4dd134302cb9708d47cae23c365fe0a089aa2a875d2f887ac80f235a9ee5744"},
+    {file = "sse_starlette-2.0.0.tar.gz", hash = "sha256:0c43cc43aca4884c88c8416b65777c4de874cc4773e6458d3579c0a353dc2fb7"},
+]
+
+[[package]]
 name = "starlette"
 version = "0.37.2"
 requires_python = ">=3.8"
@@ -437,7 +469,7 @@
 version = "1.9.4"
 requires_python = ">=3.7"
 summary = "Yet another URL library"
-groups = ["dev"]
+groups = ["default", "dev"]
 dependencies = [
     "idna>=2.0",
     "multidict>=4.0",
--- a/pyproject.toml	Sat Mar 16 16:02:23 2024 -0700
+++ b/pyproject.toml	Mon Mar 18 16:42:21 2024 -0700
@@ -10,6 +10,8 @@
     "uvicorn>=0.28.0",
     "starlette>=0.37.2",
     "prometheus-client>=0.20.0",
+    "sse-starlette>=2.0.0",
+    "aiohttp-sse-client>=0.2.1",
 ]
 requires-python = ">=3.11"
 readme = "README.md"
--- a/src/rdferry/patch/patch.py	Sat Mar 16 16:02:23 2024 -0700
+++ b/src/rdferry/patch/patch.py	Mon Mar 18 16:42:21 2024 -0700
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
 from typing import Collection
 
+from rdflib import ConjunctiveGraph
+
 from rdflib.graph import _QuadType as Quad
 
 
@@ -14,3 +16,15 @@
     """
     dels: Collection[Quad] = field(default_factory=set, hash=True)
     adds: Collection[Quad] = field(default_factory=set, hash=True)
+
+    def delsGraph(self) -> ConjunctiveGraph:
+        return self._toGraph(self.dels)
+
+    def addsGraph(self) -> ConjunctiveGraph:
+        return self._toGraph(self.adds)
+
+    def _toGraph(self, quads: Collection[Quad]) -> ConjunctiveGraph:
+        g = ConjunctiveGraph()
+        g.addN(quads)
+        return g
+
--- a/src/rdferry/patchablegraph.py	Sat Mar 16 16:02:23 2024 -0700
+++ b/src/rdferry/patchablegraph.py	Mon Mar 18 16:42:21 2024 -0700
@@ -1,9 +1,11 @@
 import asyncio
 import itertools
+import json
+from typing import NewType
 import weakref
 from rdferry.patch_quads import patchQuads
 from rdferry.rdflib_issues.contains_with_context_398 import inGraph
-from rdflib import ConjunctiveGraph
+from rdflib import ConjunctiveGraph, Graph, URIRef
 import logging
 from rdferry.patch.patch import Patch
 from prometheus_client import Counter, Gauge, Summary
@@ -15,8 +17,18 @@
 PATCH_CALLS = Summary('patch_calls',
                       'PatchableGraph.patch calls',
                       labelnames=['graph'])
+OBSERVERS_CURRENT = Gauge('observers_current',
+                          'current observer count',
+                          labelnames=['graph'])
+OBSERVERS_ADDED = Counter('observers_added',
+                          'observers added',
+                          labelnames=['graph'])
 _graphsInProcess = itertools.count()
 
+# Message type and data string to be sent to all listening SSE clients.
+SseEvent = NewType('SseEvent', tuple[str, str])
+SseEventQueue = asyncio.Queue[SseEvent]
+
 
 class PatchableGraph:
     """
@@ -26,7 +38,7 @@
 
     def __init__(self, label: str | None = None):
         self._graph = ConjunctiveGraph()
-        # self._subscriptions: weakref.WeakSet[asyncio.Queue] = weakref.WeakSet()
+        self._subscriptions: weakref.WeakSet[SseEventQueue] = weakref.WeakSet()
 
         if label is None:
             label = f'patchableGraph{next(_graphsInProcess)}'
@@ -45,11 +57,54 @@
                        deleteQuads=dels,
                        addQuads=adds,
                        perfect=False)  # true?
-            # if self._subscriptions:
-            #     log.debug('PatchableGraph: patched; telling %s observers',
-            #               len(self._subscriptions))
-            # j = patchAsJson(p)
-            # for q in self._subscriptions:
             #     q.put_nowait(('patch', j))
             STATEMENT_COUNT.labels(graph=self.label).set(len(self._graph))
 
+    def subscribeToPatches(self) -> SseEventQueue:
+        q = SseEventQueue()
+        # unlike a local weakref.ref, finalize() survives until q is collected
+        weakref.finalize(q, self._onUnsubscribe, weakref.ref(q))
+        self._initialSubscribeEvents(weakref.ref(q))
+        return q
+
+    def _initialSubscribeEvents(self, qref: weakref.ref[SseEventQueue]):
+        q = qref()
+        if q is None:
+            raise TypeError
+        log.info('new sub queue %s', q)
+        self._subscriptions.add(q)  # weak: forgotten when the caller drops q
+        OBSERVERS_CURRENT.labels(graph=self.label).set(len(
+            self._subscriptions))
+        OBSERVERS_ADDED.labels(graph=self.label).inc()
+        q.put_nowait(clearEvent())
+        q.put_nowait(patchEvent(addWholeGraphPatch(self._graph)))
+
+    def _onUnsubscribe(self, qref: weakref.ref[SseEventQueue]):
+        log.info("bye sub %s", qref)
+        OBSERVERS_CURRENT.labels(graph=self.label).set(len(
+            self._subscriptions))  # minus one?
+
+def clearEvent() -> SseEvent:
+    return SseEvent(('clear', 'ferry1'))
+
+
+
+
+def quadsWithGraphContexts(quads):
+    for s, p, o, c in quads:
+        if isinstance(c, URIRef):
+            c = Graph(identifier=c)
+        if not isinstance(c, Graph):
+            raise TypeError("bad quad context type in %r" % ((s, p, o, c), ))
+        yield s, p, o, c
+
+def addWholeGraphPatch(graph: ConjunctiveGraph) -> Patch:
+    """a patch that adds every quad in the graph"""
+    return Patch(adds=list(quadsWithGraphContexts(graph.quads())))
+
+
+def patchEvent(p: Patch) -> SseEvent:
+    return SseEvent((
+        'patch',  #
+        ('-\n' + p.delsGraph().serialize(format='hext')) +
+        ('+\n' + p.addsGraph().serialize(format='hext'))))
--- a/src/rdferry/server.py	Sat Mar 16 16:02:23 2024 -0700
+++ b/src/rdferry/server.py	Mon Mar 18 16:42:21 2024 -0700
@@ -1,13 +1,20 @@
+import logging
 from functools import partial
+from typing import Awaitable, Callable
 
 import uvicorn
-from rdflib import plugin
-from rdflib.serializer import Serializer
 from starlette.applications import Starlette
 from starlette.requests import Request
-from starlette.responses import PlainTextResponse
-
+from starlette.responses import PlainTextResponse, Response
+from sse_starlette.sse import EventSourceResponse
 from rdferry.patchablegraph import PatchableGraph
+from sse_starlette import ServerSentEvent
+from prometheus_client import Summary
+log = logging.getLogger('serv')
+SEND_SIMPLE_GRAPH = Summary('send_simple_graph',
+                            'calls to _writeGraphResponse')
+SEND_CLEAR = Summary('send_full_graph', 'fullGraph SSE events')
+SEND_PATCH = Summary('send_patch', 'patch SSE events')
 
 
 class StarletteServer:
@@ -16,7 +23,9 @@
         self.app = Starlette()
         self.root_route_is_set = False
 
-    def add_route(self, path: str, route, **kw):
+    def add_route(self, path: str,
+                  route: Callable[[Request],
+                                  Awaitable[Response] | Response], **kw):
         self.app.add_route(path, route, **kw)
         if path == '/':
             self.root_route_is_set = True
@@ -24,16 +33,55 @@
     def add_graph_routes(self, path: str, graph: PatchableGraph):
         """Adds {path} and {path}/events"""
         self.add_route(path, partial(self._on_graph_request, graph))
+        self.add_route(path + '/events',
+                       self._graph_events_request_handler(graph))
 
     def _on_graph_request(self, graph: PatchableGraph,
                           request: Request) -> PlainTextResponse:
-        format = request.headers.get('Accept', '*/*') 
+        format = request.headers.get('Accept', '*/*')
         if format == '*/*':
             format = 'application/trig'
         return PlainTextResponse(content=graph._graph.serialize(format=format),
                                  media_type=format)
 
+    def _graph_events_request_handler(
+        self,
+        graph: PatchableGraph,
+    ) -> Callable[[Request], EventSourceResponse]:
+        return _GraphEvents(graph)
+
     def serve(self):
         if not self.root_route_is_set:
             self.add_route('/', lambda req: PlainTextResponse('todo'))
         uvicorn.run(self.app, host="0.0.0.0", port=8005)
+
+
+def _GraphEvents(
+        masterGraph: PatchableGraph
+) -> Callable[[Request], EventSourceResponse]:
+
+    async def generateEvents():
+        events = masterGraph.subscribeToPatches()
+        while True:  # we'll get cancelled by EventSourceResponse when the conn drops
+            etype, data = await events.get()
+            # Are there more to get? We might throttle and combine patches here- ideally we could see how
+            # long the latency to the client is to make a better rate choice
+            metric = {'clear': SEND_CLEAR, 'patch': SEND_PATCH}[etype]
+            with metric.time():
+                log.warning(f'yielding {etype=}  {data=} event')
+                yield ServerSentEvent(event=etype, data=data)
+
+    def handle(request: Request) -> EventSourceResponse:
+        """
+        One session with one client.
+
+        returns current graph plus future patches to keep remote version
+        in sync with ours.
+
+        instead of turning off buffering all over, it may work for this
+        response to send 'x-accel-buffering: no', per
+        http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_buffering
+        """
+        return EventSourceResponse(generateEvents())
+
+    return handle