0
|
1 import logging
|
|
2 from mqtt.client.factory import MQTTFactory
|
|
3 import rx.subject
|
|
4 from twisted.application.internet import ClientService
|
|
5 from twisted.internet import reactor
|
|
6 from twisted.internet.defer import inlineCallbacks
|
|
7 from twisted.internet.endpoints import clientFromString
|
|
8
|
|
# Logger shared by everything in this module.
log = logging.getLogger('mqtt_client')

# MQTT quality-of-service levels, numbered as on the wire.
AT_MOST_ONCE = 0
AT_LEAST_ONCE = 1
EXACTLY_ONCE = 2
|
|
11
|
|
class MQTTService(ClientService):
    """Twisted ClientService that keeps an MQTT session up and fans incoming
    messages out to rx observers.

    ``observersByTopic`` is held by reference, not copied. It maps a topic
    (``bytes``) to a collection of observers whose ``on_next`` is called with
    each payload received on that topic. Every topic in it is subscribed
    again after each successful broker connection.
    """

    def __init__(self, endpoint, factory, observersByTopic, clientId):
        """
        Args:
            endpoint: Twisted client endpoint for the broker; also shown in
                log messages.
            factory: protocol factory handed to ClientService.
            observersByTopic: dict of topic bytes -> observers (see class
                docstring); shared with the owner.
            clientId: MQTT client identifier sent when connecting.
        """
        self.endpoint = endpoint
        self.observersByTopic = observersByTopic
        self.clientId = clientId
        # retryPolicy: wait a constant 5 seconds between reconnect attempts.
        ClientService.__init__(self, endpoint, factory, retryPolicy=lambda _: 5)

    def startService(self):
        """Begin connecting; run connectToBroker on the first connection."""
        self.whenConnected().addCallback(self.connectToBroker)
        ClientService.startService(self)

    def ensureSubscribed(self, topic: bytes):
        """Subscribe to ``topic`` once a connection is available.

        If the MQTT session is not fully connected at that point, nothing is
        sent here; the topic is expected to already be in
        ``observersByTopic`` so the next connectToBroker subscribes it.
        """
        self.whenConnected().addCallback(self._subscribeToLatestTopic, topic)

    def _subscribeToLatestTopic(self, protocol, topic: bytes):
        # Subscribe through the same protocol object whose state was just
        # checked, rather than whatever self.protocol currently points at.
        if protocol.state == protocol.CONNECTED:
            self._subscribeTopics(protocol, [topic])
        # else it'll get done in the next connectToBroker.

    def _subscribeAll(self):
        """Subscribe to every topic currently in ``observersByTopic``."""
        topics = list(self.observersByTopic)
        if not topics:
            return
        log.info('subscribing %r', topics)
        self._subscribeTopics(self.protocol, topics)

    def _subscribeTopics(self, protocol, topics):
        """Ask ``protocol`` to subscribe to ``topics`` (bytes) at AT_LEAST_ONCE.

        Returns a Deferred whose failures are logged and absorbed.
        """
        # Local import: twisted is already a dependency of this module.
        from twisted.internet.defer import maybeDeferred

        # maybeDeferred turns a synchronous raise (or a non-Deferred return)
        # into a Deferred, so the errback below always gets attached.
        d = maybeDeferred(
            protocol.subscribe,
            topics=[(topic.decode('utf8'), AT_LEAST_ONCE) for topic in topics])
        # No caller consumes this Deferred, so log the failure here instead
        # of leaving it to surface as an unhandled Deferred error.
        d.addErrback(self._logSubscribeFailure, topics)
        return d

    def _logSubscribeFailure(self, failure, topics):
        log.warning('subscribe %r failed: %s', topics, failure.getErrorMessage())

    @inlineCallbacks
    def connectToBroker(self, protocol):
        """Run the MQTT connect handshake on a freshly connected protocol,
        then install the message handler and subscribe all known topics.

        On handshake failure, logs and returns without subscribing.
        """
        self.protocol = protocol
        self.protocol.onDisconnection = self._onProtocolDisconnection

        # Publish requests beyond window size are enqueued
        self.protocol.setWindowSize(1)

        try:
            yield self.protocol.connect(self.clientId, keepalive=60)
        except Exception as e:
            log.error(f"Connecting to {self.endpoint} raised {e!s}")
            return

        log.info(f"Connected to {self.endpoint}")

        self.protocol.onPublish = self._onProtocolMessage
        self._subscribeAll()

    def _onProtocolMessage(self, topic, payload, qos, dup, retain, msgId):
        """Deliver an incoming message's payload to every observer of its topic."""
        # Subscriptions decode topics as UTF-8, so encode the same way here;
        # ASCII would raise on any non-ASCII topic and drop the message.
        topic = topic.encode('utf8')
        observers = self.observersByTopic.get(topic, [])
        log.debug('received %s payload %s (%s obs)', topic, payload, len(observers))
        # Iterate a snapshot: an observer's on_next may subscribe another
        # observer to this topic, which would mutate the set mid-iteration.
        for obs in list(observers):
            obs.on_next(payload)

    def _onProtocolDisconnection(self, reason):
        log.warning("Connection to broker lost: %r", reason)
        # ClientService reconnects by itself; queue the MQTT handshake for
        # whichever connection it establishes next.
        self.whenConnected().addCallback(self.connectToBroker)

    def publish(self, topic: bytes, msg: bytes):
        """Publish ``msg`` on ``topic`` at AT_MOST_ONCE QoS.

        Returns the protocol's publish Deferred. Failures are logged and
        still propagated to the caller. Requires that connectToBroker has
        already run, since that is what sets ``self.protocol``.
        """
        def _logFailure(failure):
            log.warning("publish failed: %s", failure.getErrorMessage())
            return failure

        return self.protocol.publish(topic=topic.decode('utf-8'), qos=AT_MOST_ONCE,
                                     message=bytearray(msg)).addErrback(_logFailure)
|
|
77
|
|
78
|
|
class MqttClient(object):
    """Facade that owns an MQTTService and exposes publish/subscribe in
    terms of ``bytes`` topics.
    """

    def __init__(self, clientId, brokerHost='bang', brokerPort=1883):
        """Create the service for ``tcp:<brokerHost>:<brokerPort>`` and start it.

        Args:
            clientId: MQTT client identifier handed to MQTTService.
            brokerHost: broker hostname.
            brokerPort: broker TCP port.
        """
        # topic (bytes) -> set of rx Subjects. The service keeps a reference
        # to this same dict for dispatching and resubscribing.
        self.observersByTopic = {}

        profile = MQTTFactory.PUBLISHER | MQTTFactory.SUBSCRIBER
        mqttFactory = MQTTFactory(profile=profile)

        description = 'tcp:%s:%s' % (brokerHost, brokerPort)
        endpoint = clientFromString(reactor, description)
        # NOTE(review): this assigns __repr__ on the endpoint's *class*, so it
        # changes the repr of every endpoint of that type in the process, not
        # just this one, and it relies on private _host/_port attributes.
        type(endpoint).__repr__ = lambda ep: repr('%s:%s' % (ep._host, ep._port))

        self.serv = MQTTService(endpoint, mqttFactory, self.observersByTopic, clientId)
        self.serv.startService()

    def publish(self, topic: bytes, msg: bytes):
        """Forward to the service's publish and return its Deferred."""
        return self.serv.publish(topic, msg)

    def subscribe(self, topic: bytes):
        """Return a new rx Subject that emits each payload received on ``topic``.

        Payloads are passed through exactly as the MQTT protocol delivers them.
        """
        subject = rx.subject.Subject()
        observers = self.observersByTopic.setdefault(topic, set())
        observers.add(subject)
        self.serv.ensureSubscribed(topic)
        return subject
|