view lib/mqtt_client/mqtt_client.py @ 1382:f883166f7ca1

big rewrite. now probably works for multiple subscriptions and over reconnects Ignore-this: 301b82746e517d2a6ff212677f23ca8e darcs-hash:c3badf7258cd931afa0f9b0507482d0c6a702407
author drewp <drewp@bigasterisk.com>
date Wed, 08 May 2019 00:55:58 -0700
parents 3cf19717cb6f
children c887b1cc5e83
line wrap: on
line source

import logging
from mqtt.client.factory import MQTTFactory
from rx import Observable
from rx.subjects import Subject
from rx.concurrency import TwistedScheduler
from twisted.application.internet import ClientService
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, Deferred
from twisted.internet.endpoints import clientFromString

log = logging.getLogger('mqtt_client')

class MQTTService(ClientService):
    """Keep an MQTT session alive over a reconnecting TCP connection.

    ClientService owns the TCP connection and retries it with a fixed 5-second
    delay. Each time a connection comes up, connectToBroker performs the MQTT
    CONNECT and then subscribes to every topic currently in observersByTopic,
    so subscriptions survive reconnects.

    observersByTopic is shared with (and mutated by) the caller. It maps a
    topic (bytes) to a collection of observers. Each observer gets
    ``on_next(payload)`` for every message whose topic exactly equals that key.
    NOTE(review): wildcard subscriptions ('+', '#') will not match incoming
    topics under this exact-key lookup.
    """

    def __init__(self, endpoint, factory, observersByTopic):
        self.endpoint = endpoint
        self.observersByTopic = observersByTopic
        # Constant 5 second delay between reconnect attempts, regardless of
        # how many attempts have already failed.
        ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5)

    def startService(self):
        """Arrange the MQTT handshake for the first connection, then start."""
        self.whenConnected().addCallback(self.connectToBroker)
        ClientService.startService(self)

    def ensureSubscribed(self, topic: bytes):
        """Subscribe to *topic* on the live session, if there is one.

        If the MQTT session is not up yet, nothing is sent now. The next
        connectToBroker subscribes to every key of observersByTopic, which
        the caller is expected to have added *topic* to.
        """
        self.whenConnected().addCallback(self._subscribeToLatestTopic, topic)

    def _subscribeToLatestTopic(self, protocol, topic: bytes):
        if protocol.state == protocol.CONNECTED:
            # Subscribe on the protocol whose state we just checked.
            protocol.subscribe(topics=[(topic.decode('utf8'), 2)])
        # else it'll get done in the next connectToBroker.

    def _subscribeAll(self):
        """(Re)subscribe to every topic in observersByTopic at QoS 2."""
        topics = list(self.observersByTopic)
        if not topics:
            # A SUBSCRIBE with no topic filters is a protocol violation in
            # MQTT, so a client with no subscribers (e.g. publish-only) must
            # not send one.
            log.debug('no topics to subscribe')
            return
        log.info('subscribing %r', topics)
        self.protocol.subscribe(topics=[(topic.decode('utf8'), 2) for topic in topics])

    @inlineCallbacks
    def connectToBroker(self, protocol):
        """Run the MQTT CONNECT on a freshly connected *protocol*, then subscribe.

        If CONNECT fails, the error is logged and no subscriptions are made.
        """
        self.protocol = protocol
        # Re-arm this handshake for the next connection when this one drops.
        self.protocol.onDisconnection = self._onProtocolDisconnection

        # Publish requests beyond window size are enqueued
        self.protocol.setWindowSize(1)

        try:
            yield self.protocol.connect("TwistedMQTT-pub", keepalive=60)
        except Exception as e:
            log.error(f"Connecting to {self.endpoint} raised {e!s}")
            return

        log.info(f"Connected to {self.endpoint}")

        self.protocol.onPublish = self._onProtocolMessage
        self._subscribeAll()

    def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId):
        """Deliver an incoming message to every observer of its topic."""
        # Subscriptions decode topic bytes as UTF-8, so encode the same way to
        # find the matching key. 'ascii' would raise on non-ASCII topics.
        topic = topic.encode('utf8')
        observers = self.observersByTopic.get(topic, [])
        log.debug('received %s payload %s (%s obs)', topic, payload, len(observers))
        # Iterate a snapshot: an observer may add another observer for this
        # topic from inside on_next, which would mutate the set mid-loop.
        for obs in list(observers):
            obs.on_next(payload)

    def _onProtocolDisconnection(self, reason):
        log.warning("Connection to broker lost: %r", reason)
        # ClientService reconnects TCP on its own; redo the MQTT handshake
        # once it does.
        self.whenConnected().addCallback(self.connectToBroker)

    def publish(self, topic: bytes, msg: bytes):
        """Publish *msg* on *topic* at QoS 0 and return the publish Deferred.

        Failures are logged and still propagated to the caller. Must be called
        after the first connection, since it uses self.protocol.
        """
        def _logFailure(failure):
            log.warning("publish failed: %s", failure.getErrorMessage())
            return failure

        return self.protocol.publish(topic=topic.decode('utf-8'), qos=0,
                                     message=bytearray(msg)).addErrback(_logFailure)


class _LabeledEndpoint(object):
    """Client endpoint wrapper whose repr is a short human-readable label.

    Forwards connect(), the only method of Twisted's IStreamClientEndpoint.
    Wrapping avoids patching __repr__ onto the endpoint's class, which would
    change the repr of every endpoint of that class in the process.
    """

    def __init__(self, endpoint, label):
        self._endpoint = endpoint
        self._label = label

    def connect(self, protocolFactory):
        return self._endpoint.connect(protocolFactory)

    def __repr__(self):
        return self._label


class MqttClient(object):
    """MQTT publisher/subscriber on top of a reconnecting MQTTService.

    Subscriptions are kept in observersByTopic, which is shared with the
    service so they are re-established after every reconnect.
    """

    def __init__(self, brokerHost='bang', brokerPort=1883):

        self.observersByTopic = {} # bytes: Set(observer)

        factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER)
        myEndpoint = _LabeledEndpoint(
            clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort)),
            '%s:%s' % (brokerHost, brokerPort))
        self.serv = MQTTService(myEndpoint, factory, self.observersByTopic)
        self.serv.startService()

    def publish(self, topic: bytes, msg: bytes):
        """Publish *msg* on *topic*; returns whatever the service returns."""
        return self.serv.publish(topic, msg)

    def subscribe(self, topic: bytes):
        """Return an Rx Subject that emits each payload received on *topic*.

        The payload is passed through exactly as the MQTT protocol delivers it
        (NOTE(review): likely bytes/bytearray, not str — confirm).
        """
        ret = Subject()
        self.observersByTopic.setdefault(topic, set()).add(ret)
        self.serv.ensureSubscribed(topic)
        return ret