Mercurial > code > home > repos > homeauto
view lib/mqtt_client/mqtt_client.py @ 579:603272ee3000
Big rewrite. Now probably works with multiple subscriptions and across reconnects.
Ignore-this: 301b82746e517d2a6ff212677f23ca8e
author | drewp@bigasterisk.com |
---|---|
date | Wed, 08 May 2019 00:55:58 -0700 |
parents | 3501410e4cc7 |
children | 6b6a7d06691e |
line wrap: on
line source
import logging

from mqtt.client.factory import MQTTFactory
from rx import Observable
from rx.subjects import Subject
from rx.concurrency import TwistedScheduler
from twisted.application.internet import ClientService
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, Deferred, fail
from twisted.internet.endpoints import clientFromString

log = logging.getLogger('mqtt_client')


class MQTTService(ClientService):
    """Reconnecting MQTT connection that fans received messages out to observers.

    `observersByTopic` is a dict shared with the owner (see MqttClient):
    keys are topic names as bytes, values are collections of observers whose
    `on_next` is called with the payload of every message received on exactly
    that topic (plain dict lookup, so wildcard filters such as '+' or '#' will
    not match incoming concrete topic names). The dict is re-read on every
    (re)connect, so topics added while disconnected get subscribed once the
    broker session is back. Subscriptions use QoS 2; publishes use QoS 0.
    """

    def __init__(self, endpoint, factory, observersByTopic,
                 clientId="TwistedMQTT-pub", label=None):
        """
        :param endpoint: Twisted client endpoint for the broker.
        :param factory: protocol factory handed to ClientService.
        :param observersByTopic: shared dict {topic bytes: observers}.
        :param clientId: MQTT client identifier sent in CONNECT. Processes
            sharing one broker should use distinct ids, since a broker drops
            an existing session when another client connects with the same id.
        :param label: text used for the broker in log messages; defaults to
            the endpoint object itself.
        """
        self.endpoint = endpoint
        self.label = endpoint if label is None else label
        self.observersByTopic = observersByTopic
        self.clientId = clientId
        # Most recently connected protocol; None until the first connection.
        self.protocol = None
        # retryPolicy gets the count of consecutive failed attempts and
        # returns the seconds to wait before retrying: always 5 here.
        ClientService.__init__(self, endpoint, factory,
                               retryPolicy=lambda _: 5)

    def startService(self):
        # Register the handshake handler before the first connection attempt.
        self.whenConnected().addCallback(self.connectToBroker)
        ClientService.startService(self)

    def ensureSubscribed(self, topic: bytes):
        """Subscribe to `topic` on the current connection, if one is ready.

        If the MQTT session isn't up yet, nothing is sent here; the topic is
        expected to already be in observersByTopic, so the next
        connectToBroker subscribes it along with the others.
        """
        self.whenConnected().addCallback(self._subscribeToLatestTopic, topic)

    def _subscribeToLatestTopic(self, protocol, topic: bytes):
        if protocol.state == protocol.CONNECTED:
            protocol.subscribe(topics=[(topic.decode('utf8'), 2)])
        # else it'll get done in the next connectToBroker.

    def _subscribeAll(self):
        """Subscribe (QoS 2) to every topic currently in observersByTopic."""
        topics = list(self.observersByTopic)
        if not topics:
            # An MQTT SUBSCRIBE must carry at least one topic filter, so with
            # nothing registered yet there is nothing to send.
            return
        log.info('subscribing %r', topics)
        self.protocol.subscribe(
            topics=[(topic.decode('utf8'), 2) for topic in topics])

    @inlineCallbacks
    def connectToBroker(self, protocol):
        """Run the MQTT CONNECT handshake on a fresh transport, then resubscribe.

        Called via whenConnected() for the first connection and again after
        each disconnection. A failed handshake is logged and abandoned.
        """
        self.protocol = protocol
        self.protocol.onDisconnection = self._onProtocolDisconnection
        # Publish requests beyond window size are enqueued
        self.protocol.setWindowSize(1)

        try:
            yield self.protocol.connect(self.clientId, keepalive=60)
        except Exception as e:
            log.error(f"Connecting to {self.label} raised {e!s}")
            return

        log.info(f"Connected to {self.label}")
        self.protocol.onPublish = self._onProtocolMessage
        self._subscribeAll()

    def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId):
        # Topics are stored as bytes and were decoded as UTF-8 when
        # subscribing, so encode the same way here; 'ascii' would raise on
        # any non-ASCII topic name.
        topic = topic.encode('utf8')
        observers = self.observersByTopic.get(topic, [])
        log.debug(f'received {topic} payload {payload} ({len(observers)} obs)')
        # Iterate a snapshot: an observer reacting by subscribing to this
        # same topic would otherwise mutate the set mid-iteration.
        for obs in list(observers):
            obs.on_next(payload)

    def _onProtocolDisconnection(self, reason):
        log.warning("Connection to broker lost: %r", reason)
        # Redo the handshake (and resubscribe) on the next connection.
        self.whenConnected().addCallback(self.connectToBroker)

    def publish(self, topic: bytes, msg: bytes):
        """Publish `msg` on `topic` at QoS 0.

        Returns a Deferred. Failures are logged and still propagated to the
        caller; publishing before any connection exists yields a failed
        Deferred rather than raising synchronously.
        """
        def _logFailure(failure):
            log.warning("publish failed: %s", failure.getErrorMessage())
            return failure

        if self.protocol is None:
            d = fail(RuntimeError(f"not connected to {self.label} yet"))
        else:
            d = self.protocol.publish(topic=topic.decode('utf-8'), qos=0,
                                      message=bytearray(msg))
        return d.addErrback(_logFailure)


class MqttClient(object):
    """Publish/subscribe facade over a single reconnecting MQTTService."""

    def __init__(self, brokerHost='bang', brokerPort=1883,
                 clientId="TwistedMQTT-pub"):
        """Create the service for tcp:brokerHost:brokerPort and start it.

        `clientId` is passed to the broker in CONNECT; give each process its
        own id so two clients don't keep evicting each other's session.
        """
        self.observersByTopic = {}  # bytes: Set(observer)
        factory = MQTTFactory(
            profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER)
        myEndpoint = clientFromString(
            reactor, 'tcp:%s:%s' % (brokerHost, brokerPort))
        # Pass a readable label for log messages instead of patching
        # __repr__ on the endpoint's class, which affected every endpoint of
        # that type in the process.
        self.serv = MQTTService(myEndpoint, factory, self.observersByTopic,
                                clientId=clientId,
                                label='%s:%s' % (brokerHost, brokerPort))
        self.serv.startService()

    def publish(self, topic: bytes, msg: bytes):
        """Publish `msg` on `topic`; returns the service's Deferred."""
        return self.serv.publish(topic, msg)

    def subscribe(self, topic: bytes):
        """Return an rx Subject that emits each payload received on `topic`.

        Matching is by exact topic name. Payload values are whatever the MQTT
        protocol's onPublish delivers (presumably bytes-like; verify).
        """
        ret = Subject()
        self.observersByTopic.setdefault(topic, set()).add(ret)
        self.serv.ensureSubscribed(topic)
        return ret