changeset 14:e3dbd04dab96

add mqtt; talk to first light (no throttling)
author drewp@bigasterisk.com
date Sun, 28 Jan 2024 20:49:42 -0800
parents 1c865af058e7
children 61d4ccecfed8
files light.py light_bridge.py mqtt_io.py pdm.lock pyproject.toml
diffstat 5 files changed, 196 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/light.py	Sun Jan 28 20:03:20 2024 -0800
+++ b/light.py	Sun Jan 28 20:49:42 2024 -0800
@@ -1,9 +1,11 @@
 import asyncio
+import json
 import logging
 from dataclasses import dataclass
 from typing import Callable
 
 from color import Color
+from mqtt_io import MqttIo
 
 log = logging.getLogger('light')
 
@@ -22,26 +24,42 @@
         return dict([(k, round(v, 3)) for k, v in self.__dict__.items() if v > 0])
 
 
-class Address:
+class Transport:
 
     def linked(self):
         return {'label': str(self)}
 
+    async def send(self, dc: DeviceColor):
+        raise TypeError
 
-class ZigbeeAddress(Address):
 
-    def __init__(self, name: str, ieee: str):
+def zigbeeHexMessage(color: DeviceColor, bw=False) -> dict:
+    bright = max(color.r, color.g, color.b)
+    msg: dict = {"transition": 0, "brightness": int(255 * bright)}
+    if not bw:
+        c = "#%02x%02x%02x" % (int(color.r * 255), int(color.g * 255), int(color.b * 255))
+        msg["color"] = {"hex": c}
+    return msg
+
+
+class ZigbeeTransport(Transport):
+
+    def __init__(self, mqtt: MqttIo, name: str, ieee: str):
+        self.mqtt = mqtt
         self.name = name
         self.ieee = ieee
 
     def linked(self):
         return {'url': f'https://bigasterisk.com/zigbee/console/#/device/{self.ieee}/info', 'label': 'do-bar'}
 
+    async def send(self, dc: DeviceColor):
+        await self.mqtt.publish(f'zigbee/{self.name}/set', json.dumps(zigbeeHexMessage(dc, bw=False)))
+
 
 @dataclass
 class Light:
     name: str
-    address: Address
+    address: Transport
 
     requestingColor: Color = Color.fromHex('#000000')
     requestingDeviceColor: DeviceColor = DeviceColor()
@@ -78,19 +96,29 @@
             return
         self.requestingColor = c
         self.requestingDeviceColor = self.deviceColor(self.requestingColor)
+
+        if self.notifyChanged:
+            self.notifyChanged()
+
+        # waits for the relevant round-trip
+        log.info(f'transport  send {self.requestingDeviceColor}')
+        await self.address.send(self.requestingDeviceColor)
+
+        self.emittingColor = self.requestingColor
         if self.notifyChanged:
             self.notifyChanged()
 
 
-def makeZbBar(name: str, ieee: str) -> Light:
-    return Light(name=name, address=ZigbeeAddress(name, ieee))
+def makeZbBar(mqtt: MqttIo, name: str, ieee: str) -> Light:
+    return Light(name=name, address=ZigbeeTransport(mqtt, name, ieee))
 
 
 class Lights:
     _d: dict[str, Light] = {}
 
-    def __init__(self):
-        self.add(makeZbBar('do-bar', '0xa4c13844948d2da4'))
+    def __init__(self, mqtt: MqttIo):
+        self.mqtt = mqtt
+        self.add(makeZbBar(mqtt, 'do-bar', '0xa4c13844948d2da4'))
 
     def add(self, d: Light):
         d.notifyChanged = self.notifyChanged
--- a/light_bridge.py	Sun Jan 28 20:03:20 2024 -0800
+++ b/light_bridge.py	Sun Jan 28 20:49:42 2024 -0800
@@ -18,7 +18,7 @@
 from starlette_exporter import PrometheusMiddleware, handle_metrics
 
 from light import Lights
-
+from mqtt_io import MqttIo
 logging.basicConfig(level=logging.INFO)
 log = logging.getLogger()
 
@@ -46,7 +46,8 @@
 
 
 def main():
-    lights = Lights()
+    mqtt = MqttIo()
+    lights = Lights(mqtt)
     graph = PatchableGraph()
     app = Starlette(debug=True,
                     routes=[
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mqtt_io.py	Sun Jan 28 20:49:42 2024 -0800
@@ -0,0 +1,132 @@
+import asyncio
+import inspect
+import json
+import logging
+import time
+from typing import Callable, cast
+import weakref
+
+import aiomqtt  # v 2.0.0
+from prometheus_client import Gauge
+
+log = logging.getLogger('Mqtt')
+
+MQTT_CONNECTED = Gauge('mqtt_connected', 'mqtt is connected')
+
+
+class CurrentSubs:
+    """the mqtt topics we're watching and the (still alive refs to) the
+    callbacks who want them. This layer works before we're connected to the mqtt
+    broker"""
+
+    def __init__(self, mqttSub, mqttUnsub):
+        self._subs: dict[str, list[weakref.ref]] = {}  # {topic : set(ref(cb))}
+        self.mqttSub = mqttSub
+        self.mqttUnsub = mqttUnsub
+        self.pendingTopicSubs = set()
+        self.connected = False
+
+    async def subscribe(self, topic: str, cb: Callable[[float, str], None]):
+
+        topicIsNew = False
+        if topic not in self._subs:
+            self._subs[topic] = []
+            topicIsNew = True
+
+        if inspect.ismethod(cb):
+            ref = weakref.WeakMethod(cb, lambda _: self._cbDeleted(topic))
+        else:
+            ref = weakref.ref(cb, lambda _: self._cbDeleted(topic))
+
+        self._subs[topic].append(ref)
+
+        self.dumpSubs()
+
+        if topicIsNew:  # don't sub until our handler is added, in case retained msgs come (I don't know if they can or not)
+            if self.connected:
+                log.info(f'  we"re connected so lets add a real sub to {topic!r}')
+                await self.mqttSub(topic)
+            else:
+                log.info(f"  connection wait, trying to subscribe to {topic!r}")
+                self.pendingTopicSubs.add(topic)
+
+    def dumpSubs(self):
+        log.info('  now _subs is')
+        for k in self._subs:
+            self._subs[k] = [v for v in self._subs[k] if v() is not None]
+            log.info(f'    - {k} {self._subs[k]}')
+
+    def _cbDeleted(self, topic: str):
+        log.info(f'cb removed under {topic}')
+        if topic not in self._subs:
+            return
+        self._subs[topic] = [v for v in self._subs[topic] if v() is not None]
+        if not self._subs[topic]:
+            log.info(f'sohuld unsub {topic}')
+            asyncio.create_task(self.mqttUnsub(topic))
+            del self._subs[topic]
+
+    async def onMqttConnected(self):
+        log.info(f'mqtt connected. Make {len(self.pendingTopicSubs)}  pending subs')
+        self.connected = True
+        for p in self.pendingTopicSubs:
+            await self.mqttSub(p)
+
+        log.info('done with pending subs')
+        self.pendingTopicSubs = set()
+
+    def onMessage(self, message: aiomqtt.Message):
+        topic = message.topic.value
+        for cbRef in self._subs.get(topic, set()):
+            cb = cbRef()
+            if cb is None:
+                raise ValueError("we should have pruned this sub already")
+            try:
+                cb(time.time(), cast(bytes, message.payload).decode('utf-8'))
+            except Exception:
+                log.error(f"in callback for {topic=}", exc_info=True)
+
+
+class MqttIo:
+
+    def __init__(self):
+        self.devices = []
+        client_id = "light-bridge"
+        log.info('starting mqtt task')
+        MQTT_CONNECTED.set(0)
+        self.client = aiomqtt.Client('mqtt2.bigasterisk.com', identifier=client_id)
+        self.subs = CurrentSubs(self.client.subscribe, self.client.unsubscribe)
+        self._task = asyncio.create_task(self._run())
+        log.info('started mqtt task')
+
+    def assertRunning(self):
+        if self._task.done():
+            raise ValueError("Mqtt task is not running")
+
+    async def _run(self):
+        try:
+            await self._connectAndRead()
+        except aiomqtt.MqttError as e:
+            MQTT_CONNECTED.set(0)
+            log.error(e, exc_info=True)
+
+    async def _connectAndRead(self):
+        async with self.client:
+            await self.subs.onMqttConnected()
+            MQTT_CONNECTED.set(1)
+            async for message in self.client.messages:
+                self.subs.onMessage(message)
+
+    async def subscribe(self, topic: str, cb: Callable[[float, str], None]):
+        """when a messages comes on this topic, call cb with the time and payload.
+        """
+        await self.subs.subscribe(topic, cb)
+
+    async def publish(self, topic: str, msg: str):
+        '''best effort'''
+        if not self.subs.connected:
+            log.error('publish ignored- not connected', exc_info=True)
+            return
+
+        log.info(f'client.publish {topic=} {msg=}')
+        await self.client.publish(topic, msg)
--- a/pdm.lock	Sun Jan 28 20:03:20 2024 -0800
+++ b/pdm.lock	Sun Jan 28 20:49:42 2024 -0800
@@ -5,7 +5,7 @@
 groups = ["default"]
 strategy = ["cross_platform", "inherit_metadata"]
 lock_version = "4.4.1"
-content_hash = "sha256:238744667aa5bb7e379e16e2b9f424307a9a90b29eafd1734283956bf2b8fdec"
+content_hash = "sha256:b3fe693267977a2e72f6da23fbd5096fce62a9d204550479e044796d39d8043f"
 
 [[package]]
 name = "aiohttp"
@@ -55,6 +55,20 @@
 ]
 
 [[package]]
+name = "aiomqtt"
+version = "2.0.0"
+requires_python = ">=3.8,<4.0"
+summary = "The idiomatic asyncio MQTT client, wrapped around paho-mqtt"
+groups = ["default"]
+dependencies = [
+    "paho-mqtt<2.0.0,>=1.6.0",
+]
+files = [
+    {file = "aiomqtt-2.0.0-py3-none-any.whl", hash = "sha256:f3b97eca4a5a2c40769ed14f660520f733be1d2ec383a9976153fe49141e2fa2"},
+    {file = "aiomqtt-2.0.0.tar.gz", hash = "sha256:3d480429334bdba4e4b9936c6cc198ea4f76a94d36cf294e0f713ec59f6a2120"},
+]
+
+[[package]]
 name = "aiosignal"
 version = "1.3.1"
 requires_python = ">=3.7"
@@ -430,6 +444,15 @@
 ]
 
 [[package]]
+name = "paho-mqtt"
+version = "1.6.1"
+summary = "MQTT version 5.0/3.1.1 client class"
+groups = ["default"]
+files = [
+    {file = "paho-mqtt-1.6.1.tar.gz", hash = "sha256:2a8291c81623aec00372b5a85558a372c747cbca8e9934dfe218638b8eefc26f"},
+]
+
+[[package]]
 name = "patchablegraph"
 version = "1.5.0"
 requires_python = ">=3.9"
--- a/pyproject.toml	Sun Jan 28 20:03:20 2024 -0800
+++ b/pyproject.toml	Sun Jan 28 20:49:42 2024 -0800
@@ -14,6 +14,7 @@
     "patchablegraph>=1.5.0",
     "rdfdb==0.24.0",
     "dataclasses-json>=0.6.3",
+    "aiomqtt>=2.0.0",
 ]
 requires-python = ">=3.11"
 license = {text = "MIT"}