Mercurial > code > home > repos > homeauto
changeset 1382:f883166f7ca1
big rewrite. now probably works for multiple subscriptions and over reconnects
Ignore-this: 301b82746e517d2a6ff212677f23ca8e
darcs-hash:c3badf7258cd931afa0f9b0507482d0c6a702407
author | drewp <drewp@bigasterisk.com> |
---|---|
date | Wed, 08 May 2019 00:55:58 -0700 |
parents | c9aedf4c6e52 |
children | d66790a031ea |
files | lib/mqtt_client/mqtt_client.py |
diffstat | 1 files changed, 43 insertions(+), 38 deletions(-) [+] |
line wrap: on
line diff
--- a/lib/mqtt_client/mqtt_client.py Mon May 06 21:11:19 2019 -0700 +++ b/lib/mqtt_client/mqtt_client.py Wed May 08 00:55:58 2019 -0700 @@ -1,8 +1,9 @@ import logging from mqtt.client.factory import MQTTFactory from rx import Observable +from rx.subjects import Subject from rx.concurrency import TwistedScheduler -from twisted.application.internet import ClientService, backoffPolicy +from twisted.application.internet import ClientService from twisted.internet import reactor from twisted.internet.defer import inlineCallbacks, Deferred from twisted.internet.endpoints import clientFromString @@ -11,34 +12,56 @@ class MQTTService(ClientService): - def __init__(self, endpoint, factory): + def __init__(self, endpoint, factory, observersByTopic): self.endpoint = endpoint - ClientService.__init__(self, endpoint, factory, retryPolicy=backoffPolicy()) + self.observersByTopic = observersByTopic + ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5) def startService(self): self.whenConnected().addCallback(self.connectToBroker) ClientService.startService(self) + def ensureSubscribed(self, topic: bytes): + self.whenConnected().addCallback(self._subscribeToLatestTopic, topic) + + def _subscribeToLatestTopic(self, protocol, topic: bytes): + if protocol.state == protocol.CONNECTED: + self.protocol.subscribe(topics=[(topic.decode('utf8'), 2)]) + # else it'll get done in the next connectToBroker. + + def _subscribeAll(self): + topics = list(self.observersByTopic) + log.info('subscribing %r', topics) + self.protocol.subscribe(topics=[(topic.decode('utf8'), 2) for topic in topics]) + + @inlineCallbacks def connectToBroker(self, protocol): self.protocol = protocol - self.protocol.onDisconnection = self.onDisconnection - # We are issuing 3 publish in a row - # if order matters, then set window size to 1 + self.protocol.onDisconnection = self._onProtocolDisconnection + # Publish requests beyond window size are enqueued self.protocol.setWindowSize(1) try: yield self.protocol.connect("TwistedMQTT-pub", keepalive=60) except Exception as e: - log.error("Connecting to {broker} raised {excp!s}", - broker=self.endpoint, excp=e) - else: - log.info("Connected to {broker}".format(broker=self.endpoint)) - if getattr(self, 'onMqttConnectionMade', False): - self.onMqttConnectionMade() + log.error(f"Connecting to {self.endpoint} raised {e!s}") + return + + log.info(f"Connected to {self.endpoint}") - def onDisconnection(self, reason): + self.protocol.onPublish = self._onProtocolMessage + self._subscribeAll() + + def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId): + topic = topic.encode('ascii') + observers = self.observersByTopic.get(topic, []) + log.debug(f'received {topic} payload {payload} ({len(observers)} obs)') + for obs in observers: + obs.on_next(payload) + + def _onProtocolDisconnection(self, reason): log.warn("Connection to broker lost: %r", reason) self.whenConnected().addCallback(self.connectToBroker) @@ -54,38 +77,20 @@ class MqttClient(object): def __init__(self, brokerHost='bang', brokerPort=1883): - #scheduler = TwistedScheduler(reactor) + self.observersByTopic = {} # bytes: Set(observer) factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER) myEndpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort)) myEndpoint.__class__.__repr__ = lambda self: repr('%s:%s' % (self._host, self._port)) - self.serv = MQTTService(myEndpoint, factory) + self.serv = MQTTService(myEndpoint, factory, self.observersByTopic) self.serv.startService() - - def publish(self, topic, msg): + + def publish(self, topic: bytes, msg: bytes): return self.serv.publish(topic, msg) def subscribe(self, topic: bytes): """returns rx.Observable of payload strings""" - # This is surely broken for multiple topics and subscriptions. Might not even - # work over a reconnect. - - ret = Observable.create(self._observe_msgs) - - self.serv.onMqttConnectionMade = lambda: self._resubscribe(topic) - if (hasattr(self.serv, 'protocol') and - self.serv.protocol.state ==self.serv.protocol.CONNECTED): - self._resubscribe(topic) + ret = Subject() + self.observersByTopic.setdefault(topic, set()).add(ret) + self.serv.ensureSubscribed(topic) return ret - - def _resubscribe(self, topic: bytes): - log.info('subscribing %r', topic) - self.serv.protocol.onPublish = self._onPublish - return self.serv.protocol.subscribe(topics=[(topic.decode('utf-8'), 2)]) - - def _observe_msgs(self, observer): - self.obs = observer - - def _onPublish(self, topic, payload, qos, dup, retain, msgId): - log.debug('received payload %r', payload) - self.obs.on_next(payload)