Mercurial > code > home > repos > homeauto
comparison lib/mqtt_client/mqtt_client.py @ 1382:f883166f7ca1
big rewrite. now probably works for multiple subscriptions and over reconnects
Ignore-this: 301b82746e517d2a6ff212677f23ca8e
darcs-hash:c3badf7258cd931afa0f9b0507482d0c6a702407
author | drewp <drewp@bigasterisk.com> |
---|---|
date | Wed, 08 May 2019 00:55:58 -0700 |
parents | 3cf19717cb6f |
children | c887b1cc5e83 |
comparison
equal
deleted
inserted
replaced
1381:c9aedf4c6e52 | 1382:f883166f7ca1 |
---|---|
1 import logging | 1 import logging |
2 from mqtt.client.factory import MQTTFactory | 2 from mqtt.client.factory import MQTTFactory |
3 from rx import Observable | 3 from rx import Observable |
4 from rx.subjects import Subject | |
4 from rx.concurrency import TwistedScheduler | 5 from rx.concurrency import TwistedScheduler |
5 from twisted.application.internet import ClientService, backoffPolicy | 6 from twisted.application.internet import ClientService |
6 from twisted.internet import reactor | 7 from twisted.internet import reactor |
7 from twisted.internet.defer import inlineCallbacks, Deferred | 8 from twisted.internet.defer import inlineCallbacks, Deferred |
8 from twisted.internet.endpoints import clientFromString | 9 from twisted.internet.endpoints import clientFromString |
9 | 10 |
10 log = logging.getLogger('mqtt_client') | 11 log = logging.getLogger('mqtt_client') |
11 | 12 |
12 class MQTTService(ClientService): | 13 class MQTTService(ClientService): |
13 | 14 |
14 def __init__(self, endpoint, factory): | 15 def __init__(self, endpoint, factory, observersByTopic): |
15 self.endpoint = endpoint | 16 self.endpoint = endpoint |
16 ClientService.__init__(self, endpoint, factory, retryPolicy=backoffPolicy()) | 17 self.observersByTopic = observersByTopic |
18 ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5) | |
17 | 19 |
18 def startService(self): | 20 def startService(self): |
19 self.whenConnected().addCallback(self.connectToBroker) | 21 self.whenConnected().addCallback(self.connectToBroker) |
20 ClientService.startService(self) | 22 ClientService.startService(self) |
21 | 23 |
24 def ensureSubscribed(self, topic: bytes): | |
25 self.whenConnected().addCallback(self._subscribeToLatestTopic, topic) | |
26 | |
27 def _subscribeToLatestTopic(self, protocol, topic: bytes): | |
28 if protocol.state == protocol.CONNECTED: | |
29 self.protocol.subscribe(topics=[(topic.decode('utf8'), 2)]) | |
30 # else it'll get done in the next connectToBroker. | |
31 | |
32 def _subscribeAll(self): | |
33 topics = list(self.observersByTopic) | |
34 log.info('subscribing %r', topics) | |
35 self.protocol.subscribe(topics=[(topic.decode('utf8'), 2) for topic in topics]) | |
36 | |
37 | |
22 @inlineCallbacks | 38 @inlineCallbacks |
23 def connectToBroker(self, protocol): | 39 def connectToBroker(self, protocol): |
24 self.protocol = protocol | 40 self.protocol = protocol |
25 self.protocol.onDisconnection = self.onDisconnection | 41 self.protocol.onDisconnection = self._onProtocolDisconnection |
26 # We are issuing 3 publish in a row | 42 |
27 # if order matters, then set window size to 1 | |
28 # Publish requests beyond window size are enqueued | 43 # Publish requests beyond window size are enqueued |
29 self.protocol.setWindowSize(1) | 44 self.protocol.setWindowSize(1) |
30 | 45 |
31 try: | 46 try: |
32 yield self.protocol.connect("TwistedMQTT-pub", keepalive=60) | 47 yield self.protocol.connect("TwistedMQTT-pub", keepalive=60) |
33 except Exception as e: | 48 except Exception as e: |
34 log.error("Connecting to {broker} raised {excp!s}", | 49 log.error(f"Connecting to {self.endpoint} raised {e!s}") |
35 broker=self.endpoint, excp=e) | 50 return |
36 else: | 51 |
37 log.info("Connected to {broker}".format(broker=self.endpoint)) | 52 log.info(f"Connected to {self.endpoint}") |
38 if getattr(self, 'onMqttConnectionMade', False): | |
39 self.onMqttConnectionMade() | |
40 | 53 |
41 def onDisconnection(self, reason): | 54 self.protocol.onPublish = self._onProtocolMessage |
55 self._subscribeAll() | |
56 | |
57 def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId): | |
58 topic = topic.encode('ascii') | |
59 observers = self.observersByTopic.get(topic, []) | |
60 log.debug(f'received {topic} payload {payload} ({len(observers)} obs)') | |
61 for obs in observers: | |
62 obs.on_next(payload) | |
63 | |
64 def _onProtocolDisconnection(self, reason): | |
42 log.warn("Connection to broker lost: %r", reason) | 65 log.warn("Connection to broker lost: %r", reason) |
43 self.whenConnected().addCallback(self.connectToBroker) | 66 self.whenConnected().addCallback(self.connectToBroker) |
44 | 67 |
45 def publish(self, topic: bytes, msg: bytes): | 68 def publish(self, topic: bytes, msg: bytes): |
46 def _logFailure(failure): | 69 def _logFailure(failure): |
52 | 75 |
53 | 76 |
54 class MqttClient(object): | 77 class MqttClient(object): |
55 def __init__(self, brokerHost='bang', brokerPort=1883): | 78 def __init__(self, brokerHost='bang', brokerPort=1883): |
56 | 79 |
57 #scheduler = TwistedScheduler(reactor) | 80 self.observersByTopic = {} # bytes: Set(observer) |
58 | 81 |
59 factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER) | 82 factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER) |
60 myEndpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort)) | 83 myEndpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort)) |
61 myEndpoint.__class__.__repr__ = lambda self: repr('%s:%s' % (self._host, self._port)) | 84 myEndpoint.__class__.__repr__ = lambda self: repr('%s:%s' % (self._host, self._port)) |
62 self.serv = MQTTService(myEndpoint, factory) | 85 self.serv = MQTTService(myEndpoint, factory, self.observersByTopic) |
63 self.serv.startService() | 86 self.serv.startService() |
64 | 87 |
65 def publish(self, topic, msg): | 88 def publish(self, topic: bytes, msg: bytes): |
66 return self.serv.publish(topic, msg) | 89 return self.serv.publish(topic, msg) |
67 | 90 |
68 def subscribe(self, topic: bytes): | 91 def subscribe(self, topic: bytes): |
69 """returns rx.Observable of payload strings""" | 92 """returns rx.Observable of payload strings""" |
70 # This is surely broken for multiple topics and subscriptions. Might not even | 93 ret = Subject() |
71 # work over a reconnect. | 94 self.observersByTopic.setdefault(topic, set()).add(ret) |
72 | 95 self.serv.ensureSubscribed(topic) |
73 ret = Observable.create(self._observe_msgs) | |
74 | |
75 self.serv.onMqttConnectionMade = lambda: self._resubscribe(topic) | |
76 if (hasattr(self.serv, 'protocol') and | |
77 self.serv.protocol.state ==self.serv.protocol.CONNECTED): | |
78 self._resubscribe(topic) | |
79 return ret | 96 return ret |
80 | |
81 def _resubscribe(self, topic: bytes): | |
82 log.info('subscribing %r', topic) | |
83 self.serv.protocol.onPublish = self._onPublish | |
84 return self.serv.protocol.subscribe(topics=[(topic.decode('utf-8'), 2)]) | |
85 | |
86 def _observe_msgs(self, observer): | |
87 self.obs = observer | |
88 | |
89 def _onPublish(self, topic, payload, qos, dup, retain, msgId): | |
90 log.debug('received payload %r', payload) | |
91 self.obs.on_next(payload) |