comparison lib/mqtt_client/mqtt_client.py @ 1382:f883166f7ca1

big rewrite. now probably works for multiple subscriptions and over reconnects Ignore-this: 301b82746e517d2a6ff212677f23ca8e darcs-hash:c3badf7258cd931afa0f9b0507482d0c6a702407
author drewp <drewp@bigasterisk.com>
date Wed, 08 May 2019 00:55:58 -0700
parents 3cf19717cb6f
children c887b1cc5e83
comparison
equal deleted inserted replaced
1381:c9aedf4c6e52 1382:f883166f7ca1
1 import logging 1 import logging
2 from mqtt.client.factory import MQTTFactory 2 from mqtt.client.factory import MQTTFactory
3 from rx import Observable 3 from rx import Observable
4 from rx.subjects import Subject
4 from rx.concurrency import TwistedScheduler 5 from rx.concurrency import TwistedScheduler
5 from twisted.application.internet import ClientService, backoffPolicy 6 from twisted.application.internet import ClientService
6 from twisted.internet import reactor 7 from twisted.internet import reactor
7 from twisted.internet.defer import inlineCallbacks, Deferred 8 from twisted.internet.defer import inlineCallbacks, Deferred
8 from twisted.internet.endpoints import clientFromString 9 from twisted.internet.endpoints import clientFromString
9 10
10 log = logging.getLogger('mqtt_client') 11 log = logging.getLogger('mqtt_client')
11 12
12 class MQTTService(ClientService): 13 class MQTTService(ClientService):
13 14
14 def __init__(self, endpoint, factory): 15 def __init__(self, endpoint, factory, observersByTopic):
15 self.endpoint = endpoint 16 self.endpoint = endpoint
16 ClientService.__init__(self, endpoint, factory, retryPolicy=backoffPolicy()) 17 self.observersByTopic = observersByTopic
18 ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5)
17 19
18 def startService(self): 20 def startService(self):
19 self.whenConnected().addCallback(self.connectToBroker) 21 self.whenConnected().addCallback(self.connectToBroker)
20 ClientService.startService(self) 22 ClientService.startService(self)
21 23
24 def ensureSubscribed(self, topic: bytes):
25 self.whenConnected().addCallback(self._subscribeToLatestTopic, topic)
26
27 def _subscribeToLatestTopic(self, protocol, topic: bytes):
28 if protocol.state == protocol.CONNECTED:
29 self.protocol.subscribe(topics=[(topic.decode('utf8'), 2)])
30 # else it'll get done in the next connectToBroker.
31
32 def _subscribeAll(self):
33 topics = list(self.observersByTopic)
34 log.info('subscribing %r', topics)
35 self.protocol.subscribe(topics=[(topic.decode('utf8'), 2) for topic in topics])
36
37
22 @inlineCallbacks 38 @inlineCallbacks
23 def connectToBroker(self, protocol): 39 def connectToBroker(self, protocol):
24 self.protocol = protocol 40 self.protocol = protocol
25 self.protocol.onDisconnection = self.onDisconnection 41 self.protocol.onDisconnection = self._onProtocolDisconnection
26 # We are issuing 3 publish in a row 42
27 # if order matters, then set window size to 1
28 # Publish requests beyond window size are enqueued 43 # Publish requests beyond window size are enqueued
29 self.protocol.setWindowSize(1) 44 self.protocol.setWindowSize(1)
30 45
31 try: 46 try:
32 yield self.protocol.connect("TwistedMQTT-pub", keepalive=60) 47 yield self.protocol.connect("TwistedMQTT-pub", keepalive=60)
33 except Exception as e: 48 except Exception as e:
34 log.error("Connecting to {broker} raised {excp!s}", 49 log.error(f"Connecting to {self.endpoint} raised {e!s}")
35 broker=self.endpoint, excp=e) 50 return
36 else: 51
37 log.info("Connected to {broker}".format(broker=self.endpoint)) 52 log.info(f"Connected to {self.endpoint}")
38 if getattr(self, 'onMqttConnectionMade', False):
39 self.onMqttConnectionMade()
40 53
41 def onDisconnection(self, reason): 54 self.protocol.onPublish = self._onProtocolMessage
55 self._subscribeAll()
56
57 def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId):
58 topic = topic.encode('ascii')
59 observers = self.observersByTopic.get(topic, [])
60 log.debug(f'received {topic} payload {payload} ({len(observers)} obs)')
61 for obs in observers:
62 obs.on_next(payload)
63
64 def _onProtocolDisconnection(self, reason):
42 log.warn("Connection to broker lost: %r", reason) 65 log.warn("Connection to broker lost: %r", reason)
43 self.whenConnected().addCallback(self.connectToBroker) 66 self.whenConnected().addCallback(self.connectToBroker)
44 67
45 def publish(self, topic: bytes, msg: bytes): 68 def publish(self, topic: bytes, msg: bytes):
46 def _logFailure(failure): 69 def _logFailure(failure):
52 75
53 76
54 class MqttClient(object): 77 class MqttClient(object):
55 def __init__(self, brokerHost='bang', brokerPort=1883): 78 def __init__(self, brokerHost='bang', brokerPort=1883):
56 79
57 #scheduler = TwistedScheduler(reactor) 80 self.observersByTopic = {} # bytes: Set(observer)
58 81
59 factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER) 82 factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER)
60 myEndpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort)) 83 myEndpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort))
61 myEndpoint.__class__.__repr__ = lambda self: repr('%s:%s' % (self._host, self._port)) 84 myEndpoint.__class__.__repr__ = lambda self: repr('%s:%s' % (self._host, self._port))
62 self.serv = MQTTService(myEndpoint, factory) 85 self.serv = MQTTService(myEndpoint, factory, self.observersByTopic)
63 self.serv.startService() 86 self.serv.startService()
64 87
65 def publish(self, topic, msg): 88 def publish(self, topic: bytes, msg: bytes):
66 return self.serv.publish(topic, msg) 89 return self.serv.publish(topic, msg)
67 90
68 def subscribe(self, topic: bytes): 91 def subscribe(self, topic: bytes):
69 """returns rx.Observable of payload strings""" 92 """returns rx.Observable of payload strings"""
70 # This is surely broken for multiple topics and subscriptions. Might not even 93 ret = Subject()
71 # work over a reconnect. 94 self.observersByTopic.setdefault(topic, set()).add(ret)
72 95 self.serv.ensureSubscribed(topic)
73 ret = Observable.create(self._observe_msgs)
74
75 self.serv.onMqttConnectionMade = lambda: self._resubscribe(topic)
76 if (hasattr(self.serv, 'protocol') and
77 self.serv.protocol.state ==self.serv.protocol.CONNECTED):
78 self._resubscribe(topic)
79 return ret 96 return ret
80
81 def _resubscribe(self, topic: bytes):
82 log.info('subscribing %r', topic)
83 self.serv.protocol.onPublish = self._onPublish
84 return self.serv.protocol.subscribe(topics=[(topic.decode('utf-8'), 2)])
85
86 def _observe_msgs(self, observer):
87 self.obs = observer
88
89 def _onPublish(self, topic, payload, qos, dup, retain, msgId):
90 log.debug('received payload %r', payload)
91 self.obs.on_next(payload)