changeset 1:2a288d2cb88c

add unread_to_mqtt bridge
author drewp@bigasterisk.com
date Tue, 11 Feb 2025 19:20:47 -0800
parents 96f842f12121
children 6fc2c741f1a6
files bots/bigastbot.py bots/unread_to_mqtt.py
diffstat 2 files changed, 112 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/bots/bigastbot.py	Tue Jan 28 23:30:02 2025 -0800
+++ b/bots/bigastbot.py	Tue Feb 11 19:20:47 2025 -0800
@@ -1,3 +1,6 @@
+import asyncio
+from typing import AsyncGenerator, cast
+
 import zulip
 from kubernetes import client, config
 
@@ -35,3 +38,19 @@
                    topic=topic,
                    content=content)
         return self.zulip_client.send_message(msg)
+
+    async def get_registration_and_events(
+            self, **register_kw) -> AsyncGenerator[dict, None]:
+        """yields the registration response, then the events as they come"""
+        reg = self.zulip_client.register(**register_kw)
+        yield reg
+
+        last = reg['last_event_id']
+        while True:
+            update = self.zulip_client.get_events(queue_id=reg['queue_id'],
+                                                  last_event_id=last)
+            for ev in cast(list[dict], update['events']):
+                yield ev
+                last = max(last, ev['id'])
+
+            await asyncio.sleep(1)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/bots/unread_to_mqtt.py	Tue Feb 11 19:20:47 2025 -0800
@@ -0,0 +1,93 @@
+import asyncio
+import json
+import logging
+import sys
+from dataclasses import dataclass
+from typing import cast
+
+import aiomqtt
+
+from bigastbot import BigAstBot
+
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(name)s %(message)s',
+                    datefmt='%Y-%m-%d %H:%M:%S')
+log = logging.getLogger()
+
+
+@dataclass
+class UnreadToMqtt:
+    email: str
+    mqtt: aiomqtt.Client
+
+    async def run(self):
+        while True:
+            try:
+                log.info(f'connecting to zulip as {self.email}')
+                bot = BigAstBot(email=self.email)
+                self.unread_msg_ids = set()
+                self.last_sent: int | None = None
+
+                async for ev in bot.get_registration_and_events(event_types=[
+                        'message',
+                        'update_message_flags',
+                ]):
+                    await self._update_unreads_with_event(ev)
+            except aiomqtt.MqttError:
+                raise
+            except Exception as e:
+                log.error(e)
+                await asyncio.sleep(1)
+                continue
+
+    async def _update_unreads_with_event(self, ev):
+        if 'unread_msgs' in ev:
+            # looks like registration response
+            self._on_registration_response(ev)
+        elif ev['type'] == 'message':
+            self._on_message_event(ev)
+        elif ev['type'] == 'update_message_flags':
+            self._on_flag_change_event(ev)
+
+        if self.last_sent != len(self.unread_msg_ids):
+            await self._send_to_mqtt(len(self.unread_msg_ids))
+
+    def _on_flag_change_event(self, ev):
+        log.debug("_on_flag_change_event: %s", ev)
+        if ev['flag'] == 'read':
+            for msg_id in ev['messages']:
+                self.unread_msg_ids.discard(msg_id)
+
+    def _on_message_event(self, ev):
+        log.debug("_on_message_event: %s", ev)
+        if 'read' not in ev['flags']:
+            self.unread_msg_ids.add(ev['message']['id'])
+
+    def _on_registration_response(self, ev):
+        log.debug("_on_registration_response: %s", ev)
+        for msg_type in ['pms', 'streams', 'huddles']:  # mentions?
+            for group in ev['unread_msgs'][msg_type]:
+                self.unread_msg_ids.update(group['unread_message_ids'])
+
+    async def _send_to_mqtt(self, num_unread):
+        await self.mqtt.publish(f'/zulip/unread/{self.email}',
+                                json.dumps({'all': num_unread}),
+                                retain=True)
+        self.last_sent = num_unread
+
+
+async def main():
+    user_emails = sys.argv[1:]
+    while True:
+        try:
+            log.info('connecting to mqtt')
+            async with aiomqtt.Client("mqtt2.bigasterisk.com") as client:
+                await asyncio.gather(*[
+                    UnreadToMqtt(email=user_email, mqtt=client).run()
+                    for user_email in user_emails
+                ])
+        except aiomqtt.MqttError:
+            continue
+
+
+asyncio.run(main())