view patchsource.py @ 4:dc4f852d0d70

reformat and add some types
author drewp@bigasterisk.com
date Wed, 24 Nov 2021 19:47:35 -0800
parents c3f0a692c4cb
children
line wrap: on
line source

import logging
import time
import traceback
from typing import Dict, Optional, Protocol

from rdfdb.patch import Patch
from rdflib import ConjunctiveGraph, URIRef
from rdflib.parser import StringInputSource
from twisted.internet import defer, reactor
from twisted_sse.eventsource import EventSource

from .patchablegraph import patchFromJson

log = logging.getLogger('fetch')


class _Listener(Protocol):
    def __call__(
            self,
            p: Patch,
            fullGraph: bool,  # True if the  patch is the initial full graph.
    ) -> None:
        ...


class PatchSource(object):
    """wrap EventSource so it emits Patch objects and has an explicit stop method."""
    def __init__(self, url: str, agent: str):
        self.url = url

        # add callbacks to these to learn if we failed to connect
        # (approximately) or if the ccnnection was unexpectedly lost
        self.connectionFailed = defer.Deferred()
        self.connectionLost = defer.Deferred()

        self._listeners = set()
        log.info('start read from %s', url)
        self._startReadTime = time.time()
        self._patchesReceived = 0  # including fullgraph
        # note: fullGraphReceived isn't guaranteed- the stream could
        # start with patches
        self._fullGraphReceived = False
        self._eventSource: Optional[EventSource] = EventSource(
            url.encode('utf8'), userAgent=agent)

        self._eventSource.addEventListener(b'fullGraph', self._onFullGraph)
        self._eventSource.addEventListener(b'patch', self._onPatch)
        self._eventSource.onerror(self._onError)
        self._eventSource.onConnectionLost = self._onDisconnect

    def state(self) -> Dict:
        return {
            'url': self.url,
            'fullGraphReceived': self._fullGraphReceived,
            'patchesReceived': self._patchesReceived,
            'time': {
                'open': getattr(self, '_startReadTime', None),
                'fullGraph': getattr(self, '_fullGraphTime', None),
                'latestPatch': getattr(self, '_latestPatchTime', None),
            },
            'closed': self._eventSource is None,
        }

    def addPatchListener(self, func: _Listener):
        """
        func(patch, fullGraph=[true if the patch is the initial fullgraph])
        """
        self._listeners.add(func)

    def stop(self):
        log.info('stop read from %s', self.url)
        if self._eventSource is not None:
            self._eventSource.protocol.stopProducing()  # needed?
        self._eventSource = None

    def _onDisconnect(self, reason):
        log.debug('PatchSource._onDisconnect from %s (%s)', self.url, reason)
        # skip this if we're doing a stop?
        self.connectionLost.callback(None)

    def _onError(self, msg):
        log.debug('PatchSource._onError from %s %r', self.url, msg)
        if not self._fullGraphReceived:
            self.connectionFailed.callback(msg)
        else:
            self.connectionLost.callback(msg)

    def _onFullGraph(self, message: str):
        try:
            g = ConjunctiveGraph()
            g.parse(StringInputSource(message), format='json-ld')
            p = Patch(addGraph=g)
            self._sendPatch(p, fullGraph=True)
        except Exception:
            log.error(traceback.format_exc())
            raise
        self._fullGraphReceived = True
        self._fullGraphTime = time.time()
        self._patchesReceived += 1

    def _onPatch(self, message: str):
        try:
            p = patchFromJson(message)
            self._sendPatch(p, fullGraph=False)
        except:
            log.error(traceback.format_exc())
            raise
        self._latestPatchTime = time.time()
        self._patchesReceived += 1

    def _sendPatch(self, p: Patch, fullGraph: bool):
        log.debug('PatchSource %s received patch %s (fullGraph=%s)', self.url,
                  p.shortSummary(), fullGraph)
        for lis in self._listeners:
            lis(p, fullGraph=fullGraph)

    def __del__(self):
        if self._eventSource:
            raise ValueError("PatchSource wasn't stopped before del")


class ReconnectingPatchSource(object):
    """
    PatchSource api, but auto-reconnects internally and takes listener
    at init time to not miss any patches. You'll get another
    fullGraph=True patch if we have to reconnect.

    todo: generate connection stmts in here
    """
    def __init__(self,
                 url: str,
                 listener: _Listener,
                 reconnectSecs=60,
                 agent='unset'):
        self.url = url
        self._stopped = False
        self._listener = listener
        self.reconnectSecs = reconnectSecs
        self.agent = agent
        self._reconnect()

    def _reconnect(self):
        if self._stopped:
            return
        self._ps = PatchSource(self.url, agent=self.agent)
        self._ps.addPatchListener(self._onPatch)
        self._ps.connectionFailed.addCallback(self._onConnectionFailed)
        self._ps.connectionLost.addCallback(self._onConnectionLost)

    def _onPatch(self, p, fullGraph):
        self._listener(p, fullGraph=fullGraph)

    def state(self):
        return {
            'reconnectedPatchSource': self._ps.state(),
        }

    def stop(self):
        self._stopped = True
        self._ps.stop()

    def _onConnectionFailed(self, arg):
        reactor.callLater(self.reconnectSecs, self._reconnect)

    def _onConnectionLost(self, arg):
        reactor.callLater(self.reconnectSecs, self._reconnect)