changeset 579:603272ee3000

big rewrite. now probably works for multiple subscriptions and over reconnects Ignore-this: 301b82746e517d2a6ff212677f23ca8e
author drewp@bigasterisk.com
date Wed, 08 May 2019 00:55:58 -0700
parents 60f6d5c61003
children 9d60d3f34ddc
files lib/mqtt_client/mqtt_client.py
diffstat 1 files changed, 43 insertions(+), 38 deletions(-) [+]
line wrap: on
line diff
--- a/lib/mqtt_client/mqtt_client.py	Mon May 06 21:11:19 2019 -0700
+++ b/lib/mqtt_client/mqtt_client.py	Wed May 08 00:55:58 2019 -0700
@@ -1,8 +1,9 @@
 import logging
 from mqtt.client.factory import MQTTFactory
 from rx import Observable
+from rx.subjects import Subject
 from rx.concurrency import TwistedScheduler
-from twisted.application.internet import ClientService, backoffPolicy
+from twisted.application.internet import ClientService
 from twisted.internet import reactor
 from twisted.internet.defer import inlineCallbacks, Deferred
 from twisted.internet.endpoints import clientFromString
@@ -11,34 +12,56 @@
 
 class MQTTService(ClientService):
 
-    def __init__(self, endpoint, factory):
+    def __init__(self, endpoint, factory, observersByTopic):
         self.endpoint = endpoint
-        ClientService.__init__(self, endpoint, factory, retryPolicy=backoffPolicy())
+        self.observersByTopic = observersByTopic
+        ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5)
 
     def startService(self):
         self.whenConnected().addCallback(self.connectToBroker)
         ClientService.startService(self)
 
+    def ensureSubscribed(self, topic: bytes):
+        self.whenConnected().addCallback(self._subscribeToLatestTopic, topic)
+
+    def _subscribeToLatestTopic(self, protocol, topic: bytes):
+        if protocol.state == protocol.CONNECTED:
+            self.protocol.subscribe(topics=[(topic.decode('utf8'), 2)])
+        # else it'll get done in the next connectToBroker.
+
+    def _subscribeAll(self):
+        topics = list(self.observersByTopic)
+        log.info('subscribing %r', topics)
+        self.protocol.subscribe(topics=[(topic.decode('utf8'), 2) for topic in topics])
+
+        
     @inlineCallbacks
     def connectToBroker(self, protocol):
         self.protocol = protocol
-        self.protocol.onDisconnection = self.onDisconnection
-        # We are issuing 3 publish in a row
-        # if order matters, then set window size to 1
+        self.protocol.onDisconnection = self._onProtocolDisconnection
+
         # Publish requests beyond window size are enqueued
         self.protocol.setWindowSize(1)
 
         try:
             yield self.protocol.connect("TwistedMQTT-pub", keepalive=60)
         except Exception as e:
-            log.error("Connecting to {broker} raised {excp!s}",
-                      broker=self.endpoint, excp=e)
-        else:
-            log.info("Connected to {broker}".format(broker=self.endpoint))
-        if getattr(self, 'onMqttConnectionMade', False):
-            self.onMqttConnectionMade()
+            log.error(f"Connecting to {self.endpoint} raised {e!s}")
+            return
+        
+        log.info(f"Connected to {self.endpoint}")
 
-    def onDisconnection(self, reason):
+        self.protocol.onPublish = self._onProtocolMessage
+        self._subscribeAll()
+            
+    def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId):
+        topic = topic.encode('ascii')
+        observers = self.observersByTopic.get(topic, [])
+        log.debug(f'received {topic} payload {payload} ({len(observers)} obs)')
+        for obs in observers:
+            obs.on_next(payload)
+            
+    def _onProtocolDisconnection(self, reason):
         log.warn("Connection to broker lost: %r", reason)
         self.whenConnected().addCallback(self.connectToBroker)
 
@@ -54,38 +77,20 @@
 class MqttClient(object):
     def __init__(self, brokerHost='bang', brokerPort=1883):
 
-        #scheduler = TwistedScheduler(reactor)
+        self.observersByTopic = {} # bytes: Set(observer)
         
         factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER)
         myEndpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort))
         myEndpoint.__class__.__repr__ = lambda self: repr('%s:%s' % (self._host, self._port))
-        self.serv = MQTTService(myEndpoint, factory)
+        self.serv = MQTTService(myEndpoint, factory, self.observersByTopic)
         self.serv.startService()
-        
-    def publish(self, topic, msg):
+
+    def publish(self, topic: bytes, msg: bytes):
         return self.serv.publish(topic, msg)
 
     def subscribe(self, topic: bytes):
         """returns rx.Observable of payload strings"""
-        # This is surely broken for multiple topics and subscriptions. Might not even
-        # work over a reconnect.
-        
-        ret = Observable.create(self._observe_msgs)
-
-        self.serv.onMqttConnectionMade = lambda: self._resubscribe(topic)
-        if (hasattr(self.serv, 'protocol') and
-            self.serv.protocol.state ==self.serv.protocol.CONNECTED):
-            self._resubscribe(topic)
+        ret = Subject()
+        self.observersByTopic.setdefault(topic, set()).add(ret)
+        self.serv.ensureSubscribed(topic)
         return ret
-
-    def _resubscribe(self, topic: bytes):
-        log.info('subscribing %r', topic)
-        self.serv.protocol.onPublish = self._onPublish
-        return self.serv.protocol.subscribe(topics=[(topic.decode('utf-8'), 2)])
-        
-    def _observe_msgs(self, observer):
-        self.obs = observer
-
-    def _onPublish(self, topic, payload, qos, dup, retain, msgId):
-        log.debug('received payload %r', payload)
-        self.obs.on_next(payload)