Mercurial > code > home > repos > mqtt_client
comparison mqtt_client.py @ 0:834594523aa4
move from homeauto repo
author | drewp@bigasterisk.com |
---|---|
date | Wed, 24 Nov 2021 10:06:04 -0800 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:834594523aa4 |
---|---|
1 import logging | |
2 from mqtt.client.factory import MQTTFactory | |
3 import rx.subject | |
4 from twisted.application.internet import ClientService | |
5 from twisted.internet import reactor | |
6 from twisted.internet.defer import inlineCallbacks | |
7 from twisted.internet.endpoints import clientFromString | |
8 | |
log = logging.getLogger('mqtt_client')

# MQTT quality-of-service levels, used as the `qos` value in subscribe/publish calls.
AT_MOST_ONCE = 0
AT_LEAST_ONCE = 1
EXACTLY_ONCE = 2
11 | |
class MQTTService(ClientService):
    """Twisted ClientService that keeps an MQTT session alive and routes messages.

    Every received message is forwarded to the observers registered for its
    topic in ``observersByTopic``. That dict maps topic bytes to an iterable of
    observers and is shared with the caller. Every topic in the dict is
    (re)subscribed each time a broker connection completes its CONNECT
    handshake.
    """

    def __init__(self, endpoint, factory, observersByTopic, clientId):
        """
        :param endpoint: Twisted client endpoint for the broker; also shown in log messages.
        :param factory: protocol factory passed through to ClientService.
        :param observersByTopic: dict of topic (bytes) -> observers. Shared with
            the caller; read on every connect and on every received message.
        :param clientId: MQTT client identifier sent with CONNECT.
        """
        self.endpoint = endpoint
        self.observersByTopic = observersByTopic
        self.clientId = clientId
        # Wait 5 seconds between connection attempts, regardless of how many failed.
        ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5)

    def startService(self):
        """Arrange for connectToBroker to run on the first connection, then start."""
        self.whenConnected().addCallback(self.connectToBroker)
        ClientService.startService(self)

    def ensureSubscribed(self, topic: bytes):
        """Subscribe to ``topic`` as soon as a connection is available."""
        self.whenConnected().addCallback(self._subscribeToLatestTopic, topic)

    def _subscribeToLatestTopic(self, protocol, topic: bytes):
        """Subscribe to one topic (QoS 1) if ``protocol`` has finished CONNECT.

        If the session is not CONNECTED yet, nothing is sent here.
        connectToBroker calls _subscribeAll after CONNECT succeeds, and that
        covers the topic, presumably because the caller has already added it
        to observersByTopic.
        """
        # Check and subscribe on the same protocol object (the one this
        # whenConnected() callback was given).
        if protocol.state == protocol.CONNECTED:
            protocol.subscribe(topics=[(topic.decode('utf8'), AT_LEAST_ONCE)])

    def _subscribeAll(self):
        """Subscribe (QoS 1) to every topic currently in observersByTopic, if any."""
        topics = list(self.observersByTopic)
        if not topics:
            return
        log.info('subscribing %r', topics)
        self.protocol.subscribe(
            topics=[(topic.decode('utf8'), AT_LEAST_ONCE) for topic in topics])

    @inlineCallbacks
    def connectToBroker(self, protocol):
        """Run the MQTT CONNECT handshake on a newly established connection.

        Runs as a whenConnected() callback. On success it installs the message
        handler and subscribes to all known topics. On failure it only logs.
        Recovery then relies on _onProtocolDisconnection (installed below)
        firing when the connection drops.
        """
        self.protocol = protocol
        self.protocol.onDisconnection = self._onProtocolDisconnection

        # Publish requests beyond window size are enqueued
        self.protocol.setWindowSize(1)

        try:
            yield self.protocol.connect(self.clientId, keepalive=60)
        except Exception as e:
            log.error("Connecting to %s raised %s", self.endpoint, e)
            return

        log.info("Connected to %s", self.endpoint)

        self.protocol.onPublish = self._onProtocolMessage
        self._subscribeAll()

    def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId):
        """Forward one received message to every observer of its topic.

        ``topic`` arrives as str. It is UTF-8 encoded to match the bytes keys
        of observersByTopic, which are UTF-8 decoded when subscribing. Lookup
        is by exact topic, so wildcard subscriptions ('+', '#') receive nothing
        here.
        """
        topic = topic.encode('utf8')
        # Snapshot: an observer's on_next may subscribe to this same topic,
        # which would grow the underlying set while we iterate it.
        observers = list(self.observersByTopic.get(topic, ()))
        log.debug('received %s payload %s (%d obs)', topic, payload, len(observers))
        for obs in observers:
            obs.on_next(payload)

    def _onProtocolDisconnection(self, reason):
        """Log the loss and re-run connectToBroker on the next connection."""
        log.warning("Connection to broker lost: %r", reason)
        self.whenConnected().addCallback(self.connectToBroker)

    def publish(self, topic: bytes, msg: bytes, qos: int = AT_MOST_ONCE):
        """Publish ``msg`` on ``topic`` and return the protocol's Deferred.

        :param qos: MQTT QoS level; defaults to AT_MOST_ONCE (0), the
            previously hard-coded value.

        Failures are logged and still propagated to the caller. self.protocol
        is only set by connectToBroker, so calling this before the first
        connection raises AttributeError.
        """
        def _logFailure(failure):
            log.warning("publish failed: %s", failure.getErrorMessage())
            return failure

        d = self.protocol.publish(topic=topic.decode('utf-8'), qos=qos,
                                  message=bytearray(msg))
        return d.addErrback(_logFailure)
77 | |
78 | |
class MqttClient:
    """One MQTT broker connection exposing per-topic rx Subjects.

    The constructor builds an MQTTService for ``tcp:<brokerHost>:<brokerPort>``
    and starts it immediately.
    """

    def __init__(self, clientId, brokerHost='bang', brokerPort=1883):
        # topic (bytes) -> set of rx Subjects. This exact dict object is handed
        # to MQTTService, which reads it on connect and on every message.
        self.observersByTopic = {}

        factory = MQTTFactory(profile=MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER)
        endpoint = clientFromString(reactor, 'tcp:%s:%s' % (brokerHost, brokerPort))
        # NOTE(review): this assigns __repr__ on the endpoint's *class*, so it
        # changes repr() for every instance of that class in the process, not
        # just this one. It also relies on the private _host/_port attributes.
        endpoint.__class__.__repr__ = lambda ep: repr('%s:%s' % (ep._host, ep._port))

        self.serv = MQTTService(endpoint, factory, self.observersByTopic, clientId)
        self.serv.startService()

    def publish(self, topic: bytes, msg: bytes):
        """Publish ``msg`` on ``topic``; returns whatever MQTTService.publish returns."""
        return self.serv.publish(topic, msg)

    def subscribe(self, topic: bytes):
        """Return an rx Subject that emits every payload received on ``topic``.

        Payloads are passed through exactly as the MQTT protocol's onPublish
        callback delivers them. Delivery uses an exact topic match.
        """
        subject = rx.subject.Subject()
        # Register the subject first so a (re)connect subscribes to this topic too.
        topicObservers = self.observersByTopic.setdefault(topic, set())
        topicObservers.add(subject)
        self.serv.ensureSubscribed(topic)
        return subject