changeset 452:a8073bcddd8b

rewrite sse_demo for py3, better connection close behavior Ignore-this: eda5b7fcd8914eb9a751ec8471626cea
author drewp@bigasterisk.com
date Fri, 19 Apr 2019 04:18:44 -0700
parents 17a556ddc5ac
children 9fd92202c886
files lib/twisted_sse_demo/eventsource.py lib/twisted_sse_demo/sse_client.py
diffstat 2 files changed, 36 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/lib/twisted_sse_demo/eventsource.py	Fri Apr 19 01:08:01 2019 -0700
+++ b/lib/twisted_sse_demo/eventsource.py	Fri Apr 19 04:18:44 2019 -0700
@@ -3,16 +3,18 @@
 from twisted.web.client import Agent
 from twisted.web.http_headers import Headers
 
-from sse_client import EventSourceProtocol
+from .sse_client import EventSourceProtocol
 
 
 class EventSource(object):
     """
     The main EventSource class
     """
-    def __init__(self, url):
+    def __init__(self, url, userAgent):
+        # type: (str, bytes)
         self.url = url
-        self.protocol = EventSourceProtocol()
+        self.userAgent = userAgent
+        self.protocol = EventSourceProtocol(self.onConnectionLost)
         self.errorHandler = None
         self.stashedError = None
         self.connect()
@@ -21,34 +23,36 @@
         """
         Connect to the event source URL
         """
-        agent = Agent(reactor)
+        agent = Agent(reactor, connectTimeout=5)
+        self.agent = agent
         d = agent.request(
-            'GET',
+            b'GET',
             self.url,
             Headers({
-                'User-Agent': ['Twisted SSE Client'],
-                'Cache-Control': ['no-cache'],
-                'Accept': ['text/event-stream; charset=utf-8'],
+                b'User-Agent': [self.userAgent],
+                b'Cache-Control': [b'no-cache'],
+                b'Accept': [b'text/event-stream; charset=utf-8'],
             }),
             None)
-        d.addErrback(self.connectError)
-        d.addCallback(self.cbRequest)
+        d.addCallbacks(self.cbRequest, self.connectError)
 
     def cbRequest(self, response):
         if response is None:
+            # seems out of spec, according to https://twistedmatrix.com/documents/current/api/twisted.web.iweb.IAgent.html
             raise ValueError('no response for url %r' % self.url)
         elif response.code != 200:
             self.callErrorHandler("non 200 response received: %d" %
                                   response.code)
         else:
-            finished = Deferred()
-            self.protocol.setFinishedDeferred(finished)
             response.deliverBody(self.protocol)
-            return finished
 
     def connectError(self, ignored):
         self.callErrorHandler("error connecting to endpoint: %s" % self.url)
 
+    def onConnectionLost(self, reason):
+        # overridden
+        reason.printDetailedTraceback()
+        
     def callErrorHandler(self, msg):
         if self.errorHandler:
             func, callInThread = self.errorHandler
@@ -68,6 +72,7 @@
         self.addEventListener('message', func, callInThread)
 
     def addEventListener(self, event, func, callInThread=False):
+        assert isinstance(event, bytes), event
         callback = func
         if callInThread:
             callback = lambda data: reactor.callInThread(func, data)
--- a/lib/twisted_sse_demo/sse_client.py	Fri Apr 19 01:08:01 2019 -0700
+++ b/lib/twisted_sse_demo/sse_client.py	Fri Apr 19 04:18:44 2019 -0700
@@ -2,13 +2,15 @@
 
 
 class EventSourceProtocol(LineReceiver):
-    def __init__(self):
+    def __init__(self, onConnectionLost):
+        self.onConnectionLost = onConnectionLost
+        self.delimiter = b'\n'
         self.MAX_LENGTH = 1 << 20
         self.callbacks = {}
         self.finished = None
         # Initialize the event and data buffers
-        self.event = 'message'
-        self.data = ''
+        self.event = b'message'
+        self.data = b''
 
     def lineLengthExceeded(self, line):
         raise NotImplementedError('line too long')
@@ -20,47 +22,46 @@
         self.callbacks[event] = func
         
     def lineReceived(self, line):
-        if line == '':
+        if line == b'':
             # Dispatch event
             self.dispatchEvent()
         else:
             try:
-                field, value = line.split(':', 1)
+                field, value = line.split(b':', 1)
                 # If value starts with a space, strip it.
                 value = lstrip(value)
             except ValueError:
                 # We got a line with no colon, treat it as a field(ignore)
                 return
 
-            if field == '':
+            if field == b'':
                 # This is a comment; ignore
                 pass
-            elif field == 'data':
-                self.data += value + '\n'
-            elif field == 'event':
+            elif field == b'data':
+                self.data += value + b'\n'
+            elif field == b'event':
                 self.event = value
-            elif field == 'id':
+            elif field == b'id':
                 # Not implemented
                 pass
-            elif field == 'retry':
+            elif field == b'retry':
                 # Not implemented
                 pass
 
     def connectionLost(self, reason):
-        if self.finished:
-            self.finished.callback(None)
+        self.onConnectionLost(reason)
 
     def dispatchEvent(self):
         """
         Dispatch the event
         """
         # If last character is LF, strip it.
-        if self.data.endswith('\n'):
+        if self.data.endswith(b'\n'):
             self.data = self.data[:-1]
         if self.event in self.callbacks:
             self.callbacks[self.event](self.data)
-        self.data = ''
-        self.event = 'message'
+        self.data = b''
+        self.event = b'message'
 
 def lstrip(value):
-    return value[1:] if value.startswith(' ') else value
+    return value[1:] if value.startswith(b' ') else value