changeset 1106:fe53ca09febc

big rewrites in sse_collector
Ignore-this: 3b6278a0cfc57aa686ed39d411fdc35f
darcs-hash: d25124b5e0d3c4729ea55530cd3b3064f2af68a7
author drewp <drewp@bigasterisk.com>
date Sun, 28 Aug 2016 18:11:34 -0700
parents c8233f4b59cb
children e68f6e5712c6
files service/reasoning/sse_collector.py service/reasoning/twisted_sse_demo/eventsource.py
diffstat 2 files changed, 264 insertions(+), 128 deletions(-)
line diff
--- a/service/reasoning/sse_collector.py	Sat Aug 20 23:34:04 2016 -0700
+++ b/service/reasoning/sse_collector.py	Sun Aug 28 18:11:34 2016 -0700
@@ -23,7 +23,7 @@
 no_setup()
 
 import sys, logging, traceback, json, collections
-from twisted.internet import reactor
+from twisted.internet import reactor, defer
 import cyclone.web, cyclone.sse
 from rdflib import ConjunctiveGraph, URIRef, Namespace
 from rdflib.parser import StringInputSource
@@ -41,47 +41,35 @@
 ROOM = Namespace("http://projects.bigasterisk.com/room/")
 COLLECTOR = URIRef('http://bigasterisk.com/sse_collector/')
 
-class ConnectionLost(object):
-    pass
-
 class PatchSource(object):
     """wrap EventSource so it emits Patch objects and has an explicit stop method."""
     def __init__(self, url):
         self.url = url
+
+        # add callbacks to these to learn if we failed to connect
+        # (approximately) or if the connection was unexpectedly lost
+        self.connectionFailed = defer.Deferred()
+        self.connectionLost = defer.Deferred()
+        
         self._listeners = set()
         log.info('start read from %s', url)
+        self._fullGraphReceived = False
         self._eventSource = EventSource(url.toPython().encode('utf8'))
         self._eventSource.protocol.delimiter = '\n'
 
         self._eventSource.addEventListener('fullGraph', self._onFullGraph)
-        self._eventSource.addEventListener('patch', self._onMessage)
-
-    def _onFullGraph(self, message):
-        try:
-            g = ConjunctiveGraph()
-            g.parse(StringInputSource(message), format='json-ld')
-            p = Patch(addGraph=g)
-            self._sendPatch(p, fullGraph=True)
-        except:
-            log.error(traceback.format_exc())
-            raise
-            
-    def _onMessage(self, message):
-        try:
-            p = patchFromJson(message)
-            self._sendPatch(p, fullGraph=False)
-        except:
-            log.error(traceback.format_exc())
-            raise
-
-    def _sendPatch(self, p, fullGraph):
-        log.debug('PatchSource received patch %s', p.shortSummary())
-        for lis in self._listeners:
-            lis(p, fullGraph=fullGraph)
+        self._eventSource.addEventListener('patch', self._onPatch)
+        self._eventSource.onerror(self._onError)
+        
+        origSet = self._eventSource.protocol.setFinishedDeferred
+        def sfd(d):
+            origSet(d)
+            d.addCallback(self._onDisconnect)
+        self._eventSource.protocol.setFinishedDeferred = sfd
         
     def addPatchListener(self, func):
         """
-        func(patch or ConnectionLost, fullGraph=[true if the patch is the initial fullgraph])
+        func(patch, fullGraph=[true if the patch is the initial fullgraph])
         """
         self._listeners.add(func)
 
@@ -93,11 +81,85 @@
             pass
         self._eventSource = None
 
+    def _onDisconnect(self, a):
+        log.debug('PatchSource._onDisconnect from %s', self.url)
+        # skip this if we're doing a stop?
+        self.connectionLost.callback(None)
+
+    def _onError(self, msg):
+        log.debug('PatchSource._onError from %s %r', self.url, msg)
+        if not self._fullGraphReceived:
+            self.connectionFailed.callback(msg)
+        else:
+            self.connectionLost.callback(msg)
+
+    def _onFullGraph(self, message):
+        try:
+            g = ConjunctiveGraph()
+            g.parse(StringInputSource(message), format='json-ld')
+            p = Patch(addGraph=g)
+            self._sendPatch(p, fullGraph=True)
+        except:
+            log.error(traceback.format_exc())
+            raise
+        self._fullGraphReceived = True
+            
+    def _onPatch(self, message):
+        try:
+            p = patchFromJson(message)
+            self._sendPatch(p, fullGraph=False)
+        except:
+            log.error(traceback.format_exc())
+            raise
+
+    def _sendPatch(self, p, fullGraph):
+        log.debug('PatchSource %s received patch %s (fullGraph=%s)', self.url, p.shortSummary(), fullGraph)
+        for lis in self._listeners:
+            lis(p, fullGraph=fullGraph)
+        
     def __del__(self):
         if self._eventSource:
             raise ValueError
 
+class ReconnectingPatchSource(object):
+    """
+    PatchSource API, but reconnects automatically and takes the listener
+    at init time so no patches are missed. You'll get another
+    fullGraph=True patch if we have to reconnect.
+
+    todo: generate connection stmts in here
+    """
+    def __init__(self, url, listener):
+        self.url = url
+        self._stopped = False
+        self._listener = listener
+        self._reconnect()
+
+    def _reconnect(self):
+        if self._stopped:
+            return
+        self._ps = PatchSource(self.url)
+        self._ps.addPatchListener(self._onPatch)
+        self._ps.connectionFailed.addCallback(self._onConnectionFailed)
+        self._ps.connectionLost.addCallback(self._onConnectionLost)        
+
+    def _onPatch(self, p, fullGraph):
+        self._listener(p, fullGraph=fullGraph)
+        
+    def stop(self):
+        self._stopped = True
+        self._ps.stop()
+        
+    def _onConnectionFailed(self, arg):
+        reactor.callLater(1, self._reconnect)
+        
+    def _onConnectionLost(self, arg):
+        reactor.callLater(1, self._reconnect)        
+            
 class LocalStatements(object):
+    """
+    functions that make statements originating from sse_collector itself
+    """
     def __init__(self, applyPatch):
         self.applyPatch = applyPatch
         self._sourceState = {} # source: state URIRef
@@ -132,152 +194,226 @@
                 delQuads=[
                     (source, ROOM['state'], oldState, COLLECTOR),
                 ]))
-            
-class GraphClients(object):
-    """
-    All the active GraphClient objects
 
-    To handle all the overlapping-statement cases, we store a set of
-    true statements along with the sources that are currently
-    asserting them and the requesters who currently know them. As
-    statements come and go, we make patches to send to requesters.
+def abbrevTerm(t):
+    if isinstance(t, URIRef):
+        return (t.replace('http://projects.bigasterisk.com/room/', 'room:')
+                .replace('http://bigasterisk.com/sse_collector/', 'sc:'))
+    return t
+
+def abbrevStmt(stmt):
+    return '(%s %s %s %s)' % tuple(map(abbrevTerm, stmt))
     
-    todo: reconnect patchsources that go down and deal with their graph diffs
-    """
+class ActiveStatements(object):
     def __init__(self):
-        self.clients = {}  # url: PatchSource
-        self.handlers = set()  # handler
-        self.listeners = {}  # url: [handler]  (handler may appear under multiple urls)
 
         # This table holds statements asserted by any of our sources
         # plus local statements that we introduce (source is
         # http://bigasterisk.com/sse_collector/).
         self.statements = collections.defaultdict(lambda: (set(), set())) # (s,p,o,c): (sourceUrls, handlers)
-
-        self._localStatements = LocalStatements(self._onPatch)
-
-    def _pprintTable(self):
+    
+    def _postDeleteStatements(self):
+        statements = self.statements
+        class PostDeleter(object):
+            def __enter__(self):
+                self._garbage = []
+                return self
+            def add(self, stmt):
+                self._garbage.append(stmt)
+            def __exit__(self, type, value, traceback):
+                if type is not None:
+                    raise
+                for stmt in self._garbage:
+                    del statements[stmt]
+        return PostDeleter()
+        
+    def pprintTable(self):
         for i, (stmt, (sources, handlers)) in enumerate(sorted(self.statements.items())):
-            print "%03d. (%s, %s, %s, %s) from %s to %s" % (
-                i,
-                stmt[0].n3(),
-                stmt[1].n3(),
-                stmt[2].n3(),
-                stmt[3].n3(),
-                ','.join(s.n3() for s in sources),
-                handlers)        
-            
-    def _sendUpdatePatch(self, handler):
-        """send a patch event out this handler to bring it up to date with self.statements"""
-        p = self._makeSyncPatch(handler)
-        if not p.isNoop():
-            log.debug("send patch %s to %s", p.shortSummary(), handler)
-            handler.sendEvent(message=jsonFromPatch(p), event='patch')
+            print "%03d. %-80s from %s to %s" % (
+                i, abbrevStmt(stmt), [abbrevTerm(s) for s in sources], handlers)        
 
-    def _makeSyncPatch(self, handler):
+    def makeSyncPatch(self, handler, sources):
         # todo: this could run all handlers at once, which is how we use it anyway
         adds = []
         dels = []
-        statementsToClear = []
-        for stmt, (sources, handlers) in self.statements.iteritems():
-            relevantToHandler = handler in sum((self.listeners.get(s, []) for s in sources), [])
-            handlerHasIt = handler in handlers
-            if relevantToHandler and not handlerHasIt:
-                adds.append(stmt)
-                handlers.add(handler)
-            elif not relevantToHandler and handlerHasIt:
-                dels.append(stmt)
-                handlers.remove(handler)
-                if not handlers:
-                    statementsToClear.append(stmt)
-                    
-        for stmt in statementsToClear:
-            del self.statements[stmt]
+        
+        with self._postDeleteStatements() as garbage:
+            for stmt, (stmtSources, handlers) in self.statements.iteritems():
+                belongsInHandler = not set(sources).isdisjoint(stmtSources)
+                handlerHasIt = handler in handlers
+                #log.debug("%s %s %s", abbrevStmt(stmt), belongsInHandler, handlerHasIt)
+                if belongsInHandler and not handlerHasIt:
+                    adds.append(stmt)
+                    handlers.add(handler)
+                elif not belongsInHandler and handlerHasIt:
+                    dels.append(stmt)
+                    handlers.remove(handler)
+                    if not handlers and not stmtSources:
+                        garbage.add(stmt)
 
         return Patch(addQuads=adds, delQuads=dels)
         
-    def _onPatch(self, source, p, fullGraph=False):
+    def applySourcePatch(self, source, p):
         for stmt in p.addQuads:
             sourceUrls, handlers = self.statements[stmt]
             if source in sourceUrls:
-                raise ValueError("%s added stmt that it already had: %s" % (source, stmt))
+                raise ValueError("%s added stmt that it already had: %s" %
+                                 (source, abbrevStmt(stmt)))
             sourceUrls.add(source)
-        for stmt in p.delQuads:
-            sourceUrls, handlers = self.statements[stmt]
-            if source not in sourceUrls:
-                raise ValueError("%s deleting stmt that it didn't have: %s" % (source, stmt))
-            sourceUrls.remove(source)
+            
+        with self._postDeleteStatements() as garbage:
+            for stmt in p.delQuads:
+                sourceUrls, handlers = self.statements[stmt]
+                if source not in sourceUrls:
+                    raise ValueError("%s deleting stmt that it didn't have: %s" %
+                                     (source, abbrevStmt(stmt)))
+                sourceUrls.remove(source)
+                # this is rare, since some handler probably still has
+                # the stmt we're deleting, but it can happen e.g. when
+                # a handler was just deleted
+                if not sourceUrls and not handlers:
+                    garbage.add(stmt)
 
-        for h in self.handlers:
-            self._sendUpdatePatch(h)
+    def replaceSourceStatements(self, source, stmts):
+        log.debug('replaceSourceStatements with %s stmts', len(stmts))
+        newStmts = set(stmts)
+
+        with self._postDeleteStatements() as garbage:
+            for stmt, (sources, handlers) in self.statements.iteritems():
+                if source in sources:
+                    if stmt not in stmts:
+                        sources.remove(source)
+                        if not sources and not handlers:
+                            garbage.add(stmt)
+                else:
+                    if stmt in stmts:
+                        sources.add(source)
+                newStmts.discard(stmt)
 
-        if log.isEnabledFor(logging.DEBUG):
-            self._pprintTable()
+        self.applySourcePatch(source, Patch(addQuads=newStmts, delQuads=[]))
+
+    def discardHandler(self, handler):
+        with self._postDeleteStatements() as garbage:
+            for stmt, (sources, handlers) in self.statements.iteritems():
+                handlers.discard(handler)
+                if not sources and not handlers:
+                    garbage.add(stmt)
 
-        if source != COLLECTOR:
-            if fullGraph:
-                self._localStatements.setSourceState(source, ROOM['fullGraphReceived'])
-            else:
-                self._localStatements.setSourceState(source, ROOM['patchesReceived'])
+    def discardSource(self, source):
+        with self._postDeleteStatements() as garbage:
+            for stmt, (sources, handlers) in self.statements.iteritems():
+                sources.discard(source)
+                if not sources and not handlers:
+                    garbage.add(stmt)
+                    
+class GraphClients(object):
+    """
+    All the active PatchSources and SSEHandlers
+
+    To handle all the overlapping-statement cases, we store a set of
+    true statements along with the sources that are currently
+    asserting them and the requesters who currently know them. As
+    statements come and go, we make patches to send to requesters.
+    """
+    def __init__(self):
+        self.clients = {}  # url: PatchSource (COLLECTOR is not listed)
+        self.handlers = set()  # handler
+        self.statements = ActiveStatements()
         
-    def addSseHandler(self, handler, streamId):
-        log.info('addSseHandler %r %r', handler, streamId)
+        self._localStatements = LocalStatements(self._onPatch)
+
+    def _sourcesForHandler(self, handler):
+        streamId = handler.streamId
         matches = [s for s in config['streams'] if s['id'] == streamId]
         if len(matches) != 1:
             raise ValueError("%s matches for %r" % (len(matches), streamId))
+        return map(URIRef, matches[0]['sources']) + [COLLECTOR]
+        
+    def _onPatch(self, source, p, fullGraph=False):
+        if fullGraph:
+            # a reconnect may need to resend the full graph even
+            # though we've already sent some statements
+            self.statements.replaceSourceStatements(source, p.addQuads)
+        else:
+            self.statements.applySourcePatch(source, p)
 
+        self._sendUpdatePatch()
+
+        if log.isEnabledFor(logging.DEBUG):
+            self.statements.pprintTable()
+
+        if source != COLLECTOR:
+            self._localStatements.setSourceState(
+                source,
+                ROOM['fullGraphReceived'] if fullGraph else
+                ROOM['patchesReceived'])
+
+    def _sendUpdatePatch(self, handler=None):
+        """
+        send a patch event to the given handler (or to every handler, if
+        handler is None) to bring it up to date with self.statements
+        """
+        # reduce loops here- prepare all patches at once
+        for h in (self.handlers if handler is None else [handler]):
+            p = self.statements.makeSyncPatch(h, self._sourcesForHandler(h))
+            if not p.isNoop():
+                log.debug("send patch %s to %s", p.shortSummary(), h)
+                h.sendEvent(message=jsonFromPatch(p), event='patch')
+        
+    def addSseHandler(self, handler):
+        log.info('addSseHandler %r %r', handler, handler.streamId)
         self.handlers.add(handler)
-        for source in map(URIRef, matches[0]['sources']):
-            if source not in self.clients:
+        
+        for source in self._sourcesForHandler(handler):
+            if source not in self.clients and source != COLLECTOR:
                 self._localStatements.setSourceState(source, ROOM['connect'])
-                ps = self.clients[source] = PatchSource(source)
-                ps.addPatchListener(
-                    lambda p, fullGraph, source=source: self._onPatch(source, p, fullGraph))
-            self.listeners.setdefault(source, []).append(handler)
+                ps = self.clients[source] = ReconnectingPatchSource(
+                    source, listener=lambda p, fullGraph, source=source: self._onPatch(
+                        source, p, fullGraph))
         self._sendUpdatePatch(handler)
         
     def removeSseHandler(self, handler):
         log.info('removeSseHandler %r', handler)
-        
-        statementsToClear = []
-        for stmt, (sources, handlers) in self.statements.iteritems():
-            handlers.discard(handler)
-            if not sources and not handlers:
-                statementsToClear.append(stmt)
-        for stmt in statementsToClear:
-            del self.statements[stmt]
-                
-        for url, handlers in self.listeners.items():
-            keep = []
-            for h in handlers:
-                if h != handler:
-                    keep.append(h)
-            handlers[:] = keep
-            if not keep:
-                self._stopClient(url)
+
+        self.statements.discardHandler(handler)
+
+        for source in self._sourcesForHandler(handler):
+            for otherHandler in self.handlers:
+                if (otherHandler != handler and
+                    source in self._sourcesForHandler(otherHandler)):
+                    break
+            else:
+                self._stopClient(source)
+            
         self.handlers.remove(handler)
 
     def _stopClient(self, url):
+        if url == COLLECTOR:
+            return
+            
         self.clients[url].stop()
 
-        for stmt, (sources, handlers) in self.statements.iteritems():
-            sources.discard(url)
+        self.statements.discardSource(url)
         
         self._localStatements.setSourceState(url, None)
         del self.clients[url]
-        del self.listeners[url]
+        
 
-        
-        
-        
 class SomeGraph(cyclone.sse.SSEHandler):
+    _handlerSerial = 0
     def __init__(self, application, request):
         cyclone.sse.SSEHandler.__init__(self, application, request)
-        self.id = request.uri[len('/graph/'):]
+        self.streamId = request.uri[len('/graph/'):]
         self.graphClients = self.settings.graphClients
         
+        self._serial = SomeGraph._handlerSerial
+        SomeGraph._handlerSerial += 1
+
+    def __repr__(self):
+        return '<Handler #%s>' % self._serial
+        
     def bind(self):
-        self.graphClients.addSseHandler(self, self.id)
+        self.graphClients.addSseHandler(self)
         
     def unbind(self):
         self.graphClients.removeSseHandler(self)
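
For orientation, a minimal usage sketch (not part of this changeset) of the reworked PatchSource/ReconnectingPatchSource API introduced above; the URL is made up and the listener just prints what it receives:

    from rdflib import URIRef
    from twisted.internet import reactor

    def onPatch(p, fullGraph):
        # fullGraph=True marks the initial full graph, or a re-sent full
        # graph after an automatic reconnect
        print 'got %s (fullGraph=%s)' % (p.shortSummary(), fullGraph)

    # listener is passed at init time so no patches are missed
    src = ReconnectingPatchSource(URIRef('http://example/graph/events'),
                                  listener=onPatch)
    reactor.callLater(60, src.stop)  # explicit stop; no further reconnects
    reactor.run()
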
--- a/service/reasoning/twisted_sse_demo/eventsource.py	Sat Aug 20 23:34:04 2016 -0700
+++ b/service/reasoning/twisted_sse_demo/eventsource.py	Sun Aug 28 18:11:34 2016 -0700
@@ -6,7 +6,7 @@
 
 from sse_client import EventSourceProtocol
 
-#setup()
+setup()
 
 
 class EventSource(object):
@@ -20,7 +20,7 @@
         self.stashedError = None
         self.connect()
 
-    #@run_in_reactor
+    @run_in_reactor
     def connect(self):
         """
         Connect to the event source URL
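
The eventsource.py hunks re-enable setup() and @run_in_reactor. Judging from the no_setup() call at the top of sse_collector.py these appear to be crochet's reactor-management helpers, though that is an assumption; a minimal sketch of that pattern, with a hypothetical function name:

    # Assumed crochet pattern (hypothetical example, not from this changeset):
    # setup() starts the reactor in a background thread; @run_in_reactor
    # moves the decorated call onto that reactor thread.
    from crochet import setup, run_in_reactor
    from twisted.internet import reactor

    setup()

    @run_in_reactor
    def start_polling():
        # runs in the reactor thread, so reactor APIs are safe to use here
        reactor.callLater(1.0, lambda: None)

    start_polling()  # returns immediately; work happens on the reactor thread

A caller that drives the reactor itself (as sse_collector.py appears to, after calling no_setup()) would skip setup() and run the reactor directly.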