Mercurial > code > home > repos > homeauto
view lib/mqtt_client/mqtt_client.py @ 580:9d60d3f34ddc
release 0.5.0
Ignore-this: 1ccfaf82e1b7a9d0fe4eda652e27a3dd
author | drewp@bigasterisk.com |
---|---|
date | Wed, 08 May 2019 00:56:54 -0700 |
parents | 603272ee3000 |
children | 6b6a7d06691e |
line wrap: on
line source
import logging

from mqtt.client.factory import MQTTFactory
from rx import Observable
from rx.subjects import Subject
from rx.concurrency import TwistedScheduler
from twisted.application.internet import ClientService
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, Deferred
from twisted.internet.endpoints import clientFromString

log = logging.getLogger('mqtt_client')


class MQTTService(ClientService):
    """Reconnecting MQTT connection that fans incoming messages out to observers.

    Built on Twisted's ClientService (retry delay fixed at 5 seconds). On
    every (re)connection it performs the MQTT CONNECT handshake and then
    subscribes to every topic currently in ``observersByTopic``.
    """

    def __init__(self, endpoint, factory, observersByTopic, description=None):
        """
        :param endpoint: Twisted client endpoint for the broker.
        :param factory: protocol factory handed to ClientService.
        :param observersByTopic: dict of topic (bytes) -> set of observers.
            It is shared with, and mutated by, the owner (e.g. MqttClient).
        :param description: optional broker label (e.g. ``'host:1883'``) used
            in log messages; when omitted, the endpoint itself is logged.
        """
        self.endpoint = endpoint
        self.observersByTopic = observersByTopic
        self._description = description
        ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5)

    def _endpointLabel(self):
        """Return the text that identifies the broker in log messages."""
        if self._description is not None:
            # repr() gives the quoted 'host:port' form these log lines use.
            return repr(self._description)
        return str(self.endpoint)

    def startService(self):
        # Do the MQTT handshake as soon as the first transport connection is up.
        self.whenConnected().addCallback(self.connectToBroker)
        ClientService.startService(self)

    def ensureSubscribed(self, topic: bytes):
        """Subscribe to ``topic`` once a connection exists.

        If the MQTT session is not fully connected at that point, the
        subscription is left to the next connectToBroker, which subscribes
        to everything in observersByTopic.
        """
        self.whenConnected().addCallback(self._subscribeToLatestTopic, topic)

    def _subscribeToLatestTopic(self, protocol, topic: bytes):
        if protocol.state == protocol.CONNECTED:
            # Subscribe on the same protocol instance whose state we checked.
            protocol.subscribe(topics=[(topic.decode('utf8'), 2)])
        # else it'll get done in the next connectToBroker.

    def _subscribeAll(self):
        """Subscribe (QoS 2) to every topic that currently has observers."""
        topics = list(self.observersByTopic)
        log.info('subscribing %r', topics)
        if not topics:
            # An MQTT SUBSCRIBE with no topic filters is a protocol violation
            # (MQTT 3.1.1 section 3.8.3), so a publish-only client sends none.
            return
        self.protocol.subscribe(topics=[(topic.decode('utf8'), 2) for topic in topics])

    @inlineCallbacks
    def connectToBroker(self, protocol):
        """Run MQTT CONNECT on a newly connected transport.

        On success, installs the incoming-message handler and resubscribes
        to all known topics. A failed CONNECT is logged and abandoned.
        """
        self.protocol = protocol
        self.protocol.onDisconnection = self._onProtocolDisconnection
        # Publish requests beyond window size are enqueued
        self.protocol.setWindowSize(1)
        try:
            yield self.protocol.connect("TwistedMQTT-pub", keepalive=60)
        except Exception as e:
            log.error(f"Connecting to {self._endpointLabel()} raised {e!s}")
            return
        log.info(f"Connected to {self._endpointLabel()}")
        self.protocol.onPublish = self._onProtocolMessage
        self._subscribeAll()

    def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId):
        # Topics are subscribed with .decode('utf8'), so map back the same way.
        # Encoding with 'ascii' would raise on any non-ASCII topic.
        topic = topic.encode('utf8')
        observers = self.observersByTopic.get(topic, [])
        log.debug(f'received {topic} payload {payload} ({len(observers)} obs)')
        # Iterate over a snapshot. An observer whose on_next subscribes to this
        # same topic would otherwise grow the set while we iterate it.
        for obs in list(observers):
            obs.on_next(payload)

    def _onProtocolDisconnection(self, reason):
        log.warning("Connection to broker lost: %r", reason)
        # Redo the MQTT handshake when ClientService's next connection comes up.
        self.whenConnected().addCallback(self.connectToBroker)

    def publish(self, topic: bytes, msg: bytes):
        """Publish ``msg`` on ``topic`` at QoS 0.

        Returns the protocol's publish Deferred. Failures are logged and then
        passed through to the caller's errbacks.
        """
        def _logFailure(failure):
            log.warning("publish failed: %s", failure.getErrorMessage())
            return failure

        return self.protocol.publish(topic=topic.decode('utf-8'), qos=0,
                                     message=bytearray(msg)).addErrback(_logFailure)


class MqttClient(object):
    """One reconnecting broker connection plus one Rx Subject per subscribe() call."""

    def __init__(self, brokerHost='bang', brokerPort=1883):
        self.observersByTopic = {}  # bytes: Set(observer)
        factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER)
        myEndpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort))
        # Hand the service a label for its log lines. This avoids overriding
        # __repr__ on the endpoint's class, which would change repr() for every
        # endpoint of that type in the whole process.
        self.serv = MQTTService(myEndpoint, factory, self.observersByTopic,
                                description='%s:%s' % (brokerHost, brokerPort))
        self.serv.startService()

    def publish(self, topic: bytes, msg: bytes):
        """Publish ``msg`` on ``topic``; returns the underlying Deferred."""
        return self.serv.publish(topic, msg)

    def subscribe(self, topic: bytes):
        """Return a new rx Subject that emits each payload received on ``topic``.

        The payload type is whatever the MQTT protocol delivers to onPublish
        (presumably bytes-like; confirm against the protocol library).
        """
        ret = Subject()
        self.observersByTopic.setdefault(topic, set()).add(ret)
        self.serv.ensureSubscribed(topic)
        return ret