view mqtt_io.py @ 27:32cfefe3155b

try harder to crash if there's an mqtt error, so k8s does a full restart
author drewp@bigasterisk.com
date Sat, 23 Mar 2024 15:25:02 -0700
parents e3dbd04dab96
children
line wrap: on
line source

import asyncio
import inspect
import json
import logging
import os
import time
from typing import Callable, cast
import weakref

import aiomqtt  # v 2.0.0
from prometheus_client import Gauge

# Module logger shared by everything in this file.
log = logging.getLogger(name='Mqtt')

# Prometheus gauge: set to 1 once connected and subscribed, 0 otherwise.
MQTT_CONNECTED = Gauge(name='mqtt_connected', documentation='mqtt is connected')


class CurrentSubs:
    """The mqtt topics we're watching and the (still alive refs to) the
    callbacks that want them.

    This layer works before we're connected to the mqtt broker: topics
    requested while disconnected are kept in `pendingTopicSubs` and are
    subscribed for real in `onMqttConnected`.

    Callbacks are held only by weak reference, so each subscriber must keep
    its own strong reference to its callback. When the last callback for a
    topic is collected, the topic is dropped (and unsubscribed at the broker
    if it had been subscribed there).
    """

    def __init__(self, mqttSub, mqttUnsub):
        """mqttSub / mqttUnsub: async callables taking a topic string that
        perform the real broker subscribe / unsubscribe."""
        self._subs: dict[str, list[weakref.ref]] = {}  # {topic : [ref(cb), ...]}
        self.mqttSub = mqttSub
        self.mqttUnsub = mqttUnsub
        # Topics requested before the broker connection came up.
        self.pendingTopicSubs: set[str] = set()
        self.connected = False
        # Strong refs to fire-and-forget unsubscribe tasks. The event loop only
        # keeps weak refs to tasks, so an unreferenced task may be collected
        # before it runs to completion.
        self._backgroundTasks: set[asyncio.Task] = set()

    async def subscribe(self, topic: str, cb: Callable[[float, str], None]):
        """Register cb (held weakly) for topic; subscribe at the broker now if
        connected and this is the topic's first callback, else queue it."""

        topicIsNew = False
        if topic not in self._subs:
            self._subs[topic] = []
            topicIsNew = True

        # Bound methods need WeakMethod; a plain weakref to a bound method
        # would die immediately since the method object is created per access.
        if inspect.ismethod(cb):
            ref = weakref.WeakMethod(cb, lambda _: self._cbDeleted(topic))
        else:
            ref = weakref.ref(cb, lambda _: self._cbDeleted(topic))

        self._subs[topic].append(ref)

        self.dumpSubs()

        if topicIsNew:  # don't sub until our handler is added, in case retained msgs come (I don't know if they can or not)
            if self.connected:
                log.info(f'  we"re connected so lets add a real sub to {topic!r}')
                await self.mqttSub(topic)
            else:
                log.info(f"  connection wait, trying to subscribe to {topic!r}")
                self.pendingTopicSubs.add(topic)

    def dumpSubs(self):
        """Log the current subscriptions, pruning dead refs as a side effect."""
        log.info('  now _subs is')
        for k in self._subs:
            self._subs[k] = [v for v in self._subs[k] if v() is not None]
            log.info(f'    - {k} {self._subs[k]}')

    def _cbDeleted(self, topic: str):
        """Weakref callback: prune dead refs under topic and, if none remain,
        forget the topic (and unsubscribe it at the broker if needed)."""
        log.info(f'cb removed under {topic}')
        if topic not in self._subs:
            return
        self._subs[topic] = [v for v in self._subs[topic] if v() is not None]
        if self._subs[topic]:
            return
        log.info(f'sohuld unsub {topic}')
        del self._subs[topic]
        if topic in self.pendingTopicSubs:
            # Never subscribed at the broker, so there is nothing to undo there;
            # just make sure onMqttConnected won't subscribe it later.
            self.pendingTopicSubs.discard(topic)
            return
        task = asyncio.create_task(self._unsubIfUnused(topic))
        self._backgroundTasks.add(task)
        task.add_done_callback(self._backgroundTasks.discard)

    async def _unsubIfUnused(self, topic: str):
        """Unsubscribe topic at the broker unless a new callback registered
        for it between scheduling this task and the task running."""
        if topic in self._subs:
            # subscribe() re-added the topic (and already re-subscribed it);
            # unsubscribing now would silently drop the new subscriber's feed.
            return
        await self.mqttUnsub(topic)

    async def onMqttConnected(self):
        """Mark connected and make the real broker subs for queued topics."""
        log.info(f'mqtt connected. Make {len(self.pendingTopicSubs)}  pending subs')
        self.connected = True
        # Take ownership of the pending set before awaiting, so nothing that
        # runs during the awaits (e.g. _cbDeleted) mutates the set we iterate.
        pending, self.pendingTopicSubs = self.pendingTopicSubs, set()
        for p in pending:
            if p not in self._subs:
                continue  # every callback for this topic died while we waited
            await self.mqttSub(p)

        log.info('done with pending subs')

    def onMessage(self, message: aiomqtt.Message):
        """Deliver (time.time(), utf-8 payload) to every live callback on the
        message's exact topic. Callback exceptions are logged, not raised."""
        topic = message.topic.value
        # Iterate a copy. A callback may drop the last strong reference to
        # another subscriber's callback; _cbDeleted then prunes self._subs[topic],
        # but the list we are walking can still hold that now-dead ref. Skip it:
        # raising here would escape the caller's message-reading loop.
        for cbRef in list(self._subs.get(topic, ())):
            cb = cbRef()
            if cb is None:
                continue
            try:
                cb(time.time(), cast(bytes, message.payload).decode('utf-8'))
            except Exception:
                log.error(f"in callback for {topic=}", exc_info=True)


class MqttIo:
    """Owns the aiomqtt client plus a background task that connects, makes
    the pending subscriptions, and feeds incoming messages to CurrentSubs.

    An MqttError in that task aborts the whole process (so the supervisor,
    e.g. k8s, does a full restart with a fresh connection).
    """

    def __init__(self, host: str = 'mqtt2.bigasterisk.com', clientId: str = 'light-bridge'):
        """Create the client and start the connect/read task.

        Calls asyncio.create_task, so it must run inside a running event loop.
        """
        self.devices = []  # not used in this class; NOTE(review): presumably used by callers
        log.info('starting mqtt task')
        MQTT_CONNECTED.set(0)
        self.client = aiomqtt.Client(host, identifier=clientId)
        self.subs = CurrentSubs(self.client.subscribe, self.client.unsubscribe)
        self._task = asyncio.create_task(self._run())
        log.info('started mqtt task')

    def assertRunning(self):
        """Raise ValueError if the background mqtt task has finished."""
        if self._task.done():
            raise ValueError("Mqtt task is not running")

    async def _run(self):
        try:
            await self._connectAndRead()
        except aiomqtt.MqttError as e:
            MQTT_CONNECTED.set(0)
            log.error(e, exc_info=True)
            os.abort()
        except Exception:
            # Not an mqtt error: let the task end with it so assertRunning()
            # notices, but log it now. self._task keeps the task alive, so
            # asyncio's "exception was never retrieved" report wouldn't fire.
            log.error('mqtt task failed', exc_info=True)
            raise
        finally:
            # Reached on a non-MqttError failure, cancellation, or a clean exit
            # of the read loop; don't keep reporting a live connection.
            MQTT_CONNECTED.set(0)
            self.subs.connected = False

    async def _connectAndRead(self):
        """Connect, make pending subs, then dispatch messages until the
        client's message iterator stops or raises."""
        async with self.client:
            await self.subs.onMqttConnected()
            MQTT_CONNECTED.set(1)
            async for message in self.client.messages:
                self.subs.onMessage(message)

    async def subscribe(self, topic: str, cb: Callable[[float, str], None]):
        """when a messages comes on this topic, call cb with the time and payload.

        cb is held only by weak reference; the caller must keep it alive.
        """
        await self.subs.subscribe(topic, cb)

    async def publish(self, topic: str, msg: str):
        '''best effort'''
        if not self.subs.connected:
            # No exception is active here, so exc_info=True would only log
            # "NoneType: None".
            log.error('publish ignored- not connected')
            return

        log.info(f'client.publish {topic=} {msg=}')
        await self.client.publish(topic, msg)