view service/mqtt_to_rdf/mqtt_to_rdf.py @ 1726:7d3797ed6681

rough port to starlette and reactivex
author drewp@bigasterisk.com
date Tue, 20 Jun 2023 23:14:28 -0700
parents 2085ed9cfcc4
children 23e6154e6c11
line wrap: on
line source

"""
Subscribe to mqtt topics; generate RDF statements.
"""
import asyncio
import json
import logging
import time
from dataclasses import dataclass
from typing import Callable, List, Set, Tuple, Union, cast
from mqttrx import MqttClient
from reactivex import Observable, empty, operators
import reactivex
from reactivex.scheduler.eventloop.asyncioscheduler import AsyncIOScheduler

from patchablegraph import PatchableGraph
from patchablegraph.handler import GraphEvents, StaticGraph
from prometheus_client import Counter
from rdflib import RDF, XSD, Graph, Literal, Namespace, URIRef
from rdflib.graph import ConjunctiveGraph
from rdflib.term import Node
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.staticfiles import StaticFiles
from reactivex.typing import Mapper
from starlette_exporter import handle_metrics
from starlette_exporter.middleware import PrometheusMiddleware

from button_events import button_events
from inference import Inference
from mqtt_message import graphFromMessage

# The root logger (no name is passed to getLogger).
log = logging.getLogger()
# Namespace used for the config vocabulary terms looked up throughout this module.
ROOM = Namespace('http://projects.bigasterisk.com/room/')
# Prometheus counter; incremented once per incoming mqtt message.
MESSAGES_SEEN = Counter('mqtt_messages_seen', '')
# Only referenced by the commented-out gauge code in MqttStatementSource.updateMasterGraph.
collectors = {}

def logGraph(debug: Callable, label: str, graph: Graph):
    """Write an n3 dump of `graph` through a logging callable.

    debug: a function taking one message string, e.g. `log.debug`.
    label: text placed on the line before the dump.
    graph: the graph to serialize.
    """
    n3 = graph.serialize(format="n3")
    # Older rdflib returns bytes from serialize() while rdflib 6+ returns str.
    # Accept both so this does not fail with AttributeError on `.decode`.
    if isinstance(n3, bytes):
        n3 = n3.decode('utf8')
    debug(label + ':\n' + n3)


def appendLimit(lst, elem, n=10):
    """Append `elem` to `lst` in place, keeping at most `n` items.

    The oldest items (front of the list) are dropped first to make room.
    """
    excess = len(lst) - n + 1
    # Only trim when there really is an excess. A negative slice bound would
    # count from the end of the list and delete entries that should be kept.
    if excess > 0:
        del lst[:excess]
    lst.append(elem)


def parseDurationLiteral(lit: Literal) -> float:
    """Convert a duration literal written in seconds (like '5s') to a float.

    Raises NotImplementedError for any text that does not end in 's'.
    """
    if not lit.endswith('s'):
        raise NotImplementedError(f'duration literal: {lit}')
    # Everything before the first 's' is taken as the number.
    number, *_ = lit.split('s')
    return float(number)


@dataclass
class StreamPipelineStep:
    """Base for the stream-processing steps in this module.

    Holds the source uri and the config graph that subclasses read their
    settings from. Subclasses override makeOutputStream.
    """
    uri: URIRef  # a :MqttStatementSource
    config: Graph

    def makeOutputStream(self, inStream: Observable) -> Observable:
        """Return the stream this step produces from `inStream`.

        The base implementation passes the input through unchanged.
        """
        return inStream


class Filters(StreamPipelineStep):
    """Step that optionally drops payloads not matching a configured JSON value."""

    def makeOutputStream(self, inStream: Observable) -> Observable:
        """Filter `inStream` by the source's :filterPayloadJsonEquals setting.

        When the setting is absent (or falsy) the input stream is returned
        as-is. Otherwise only payloads whose decoded JSON equals the
        configured JSON are passed on.
        """
        expectedJson = self.config.value(self.uri, ROOM['filterPayloadJsonEquals'])
        if not expectedJson:
            return inStream

        required = json.loads(expectedJson.toPython())

        def payloadMatches(jsonBytes):
            return json.loads(jsonBytes.decode('utf8')) == required

        keepMatching = operators.filter(payloadMatches)
        return keepMatching(inStream)


class Parser(StreamPipelineStep):
    """Step that maps raw mqtt payload bytes to RDF nodes using the source's :parser."""

    def makeOutputStream(self, inStream: Observable) -> Observable:
        """Apply the configured parser to every item of `inStream`."""
        return self.getParser()(inStream)

    def getParser(self) -> Callable[[Observable], Observable]:
        """Build a map operator from the :parser value configured for this source."""
        parserType = cast(URIRef, self.config.value(self.uri, ROOM['parser']))
        return operators.map(cast(Mapper, self.getParserFunc(parserType)))

    def getParserFunc(self, parserType: URIRef) -> Callable[[bytes], Node]:
        """Pick the bytes -> Node function for `parserType`.

        Raises NotImplementedError for a parser type this class doesn't know.
        """
        if parserType == XSD.double:

            def parseDouble(v):
                return Literal(float(v))

            return parseDouble

        if parserType == ROOM['tagIdToUri']:
            return self.tagIdToUri

        if parserType == ROOM['onOffBrightness']:

            def parseOnOff(v):
                # Anything other than the exact payload b'OFF' counts as on.
                return Literal(0.0 if v == b'OFF' else 1.0)

            return parseOnOff

        if parserType == ROOM['jsonBrightness']:
            return self.parseJsonBrightness

        # A parser that is declared as a :ValueMap in the config is a lookup table.
        if ROOM['ValueMap'] in self.config.objects(parserType, RDF.type):

            def parseViaMap(v):
                return self.remap(parserType, v.decode('utf8'))

            return parseViaMap

        if parserType == ROOM['rfCode']:
            return self.parseJsonRfCode

        if parserType == ROOM['tradfri']:
            return self.parseTradfriMessage

        raise NotImplementedError(parserType)

    def tagIdToUri(self, value: bytes) -> URIRef:
        """Turn a hex tag id (dashes allowed) into an rfidCard URI.

        Raises ValueError if the id is not hexadecimal once dashes are removed.
        """
        text = value.decode('utf8')
        justHex = text.replace('-', '').lower()
        int(justHex, 16)  # validate
        return URIRef(f'http://bigasterisk.com/rfidCard/{justHex}')

    def parseJsonBrightness(self, mqttValue: bytes):
        """Read a JSON payload with 'state' and 'brightness'; return brightness/255, or 0.0 unless state is 'ON'."""
        msg = json.loads(mqttValue.decode('utf8'))
        if msg['state'] == 'ON':
            return Literal(float(msg['brightness'] / 255))
        return Literal(0.0)

    def remap(self, parser, valueStr: str) -> Node:
        """Look up `valueStr` among the :map entries of `parser` and return the matching :to node.

        Raises KeyError when no entry's :from equals the value, and TypeError
        when the matching entry has no usable :to node.
        """
        g = self.config
        value = Literal(valueStr)
        for entry in g.objects(parser, ROOM['map']):
            if value != g.value(entry, ROOM['from']):
                continue
            to_ = g.value(entry, ROOM['to'])
            if isinstance(to_, Node):
                return to_
            raise TypeError(f'{to_=}')
        raise KeyError(value)

    def parseJsonRfCode(self, mqttValue: bytes):
        """Read 'code0' and 'code1' from a JSON payload and join them as two 8-digit hex fields."""
        msg = json.loads(mqttValue.decode('utf8'))
        codes = (msg['code0'], msg['code1'])
        return Literal('%08x%08x' % codes)

    def parseTradfriMessage(self, mqttValue: bytes) -> Node:
        """Placeholder: logs the payload and returns a fixed literal."""
        log.info(f'trad {mqttValue}')
        return Literal('todo')


class Converters(StreamPipelineStep):
    """Step that applies the source's :conversions list, in order, to the stream."""

    def makeOutputStream(self, inStream: Observable) -> Observable:
        """Chain one operator per configured conversion onto `inStream`."""
        graph = self.config
        stream = inStream
        for conv in graph.items(graph.value(self.uri, ROOM['conversions'])):
            step = self.conversionStep(conv)
            stream = step(stream)
        return stream

    def conversionStep(self, conv: Node) -> Callable[[Observable], Observable]:
        """Return the stream operator for one conversion node.

        Raises NotImplementedError for a conversion this class doesn't know.
        """
        if conv == ROOM['celsiusToFarenheit']:
            return operators.map(cast(Mapper, self.c2f))

        minimum = self.config.value(conv, ROOM['ignoreValueBelow'], default=None)
        if minimum is not None:
            threshold = cast(Literal, minimum).toPython()

            def atOrAboveThreshold(value):
                return cast(Literal, value).toPython() >= threshold

            return operators.filter(atOrAboveThreshold)

        if conv == ROOM['buttonPress']:
            return button_events(min_hold_sec=1.0, release_after_sec=1.0)

        raise NotImplementedError(conv)

    def c2f(self, value: Literal) -> Node:
        """Convert a Celsius literal to Fahrenheit, rounded to 2 decimal places."""
        celsius = cast(float, value.toPython())
        fahrenheit = celsius * 1.8 + 32
        return Literal(round(fahrenheit, 2))


class Rdfizer(StreamPipelineStep):
    """Step that turns value nodes into sets of quads, one stream per :graphStatements plan."""

    def makeOutputStream(self, inStream: Observable) -> Observable:
        """Combine the latest quad set of every configured plan.

        Returns an empty observable when the source has no :graphStatements.
        """
        plans = list(self.config.objects(self.uri, ROOM['graphStatements']))
        log.debug(f'{self.uri=} has {len(plans)=}')
        if not plans:
            return empty()
        perPlanStreams = [self.makeQuads(inStream, plan) for plan in plans]
        return reactivex.combine_latest(*perPlanStreams)

    def makeQuads(self, inStream: Observable, plan: Node) -> Observable:
        """Build the quad-set stream for one plan.

        Every input value becomes a one-quad set that uses the plan's
        :outputPredicate. If the plan has a :statementLifetime, an empty set
        is also emitted through a debounce of that many seconds.
        """

        def quadsFromValue(valueNode):
            predicate = self.config.value(plan, ROOM['outputPredicate'])
            return {(self.uri, predicate, valueNode, self.uri)}

        def emptyQuads(element) -> Set[Tuple]:
            return set()

        toQuads = operators.map(cast(Mapper, quadsFromValue))
        quads = toQuads(inStream)

        dur = self.config.value(plan, ROOM['statementLifetime'])
        if dur is None:
            return quads

        sec = parseDurationLiteral(dur)
        loop = AsyncIOScheduler(asyncio.get_event_loop())
        # The debounced copy of the stream supplies the "statement expired" events.
        expirations = quads.pipe(
            operators.debounce(sec, loop),
            operators.map(cast(Mapper, emptyQuads)),
        )
        return expirations.pipe(operators.merge(quads))


def truncTime():
    """Return the current epoch time in seconds, rounded to 3 decimal places."""
    now = time.time()
    return round(now, 3)


def tightN3(node: Union[URIRef, Literal]) -> str:
    """Return the node's n3 text with the full XML Schema namespace replaced by 'xsd:'."""
    xsdNamespace = 'http://www.w3.org/2001/XMLSchema#'
    fullText = node.n3()
    return fullText.replace(xsdNamespace, 'xsd:')


def serializeWithNs(graph: Graph, hidePrefixes=False) -> str:
    """Serialize `graph` to n3 text with ROOM bound as the default ('') prefix.

    Note that the prefix binding is made on the passed graph itself.

    hidePrefixes: when true, lines starting with '@prefix' are removed from
        the result.
    """
    graph.bind('', ROOM)
    n3 = graph.serialize(format='n3')
    # Older rdflib returns bytes from serialize() while rdflib 6+ returns str.
    # Accept both so this does not fail with AttributeError on `.decode`.
    if isinstance(n3, bytes):
        n3 = n3.decode('utf8')
    if hidePrefixes:
        n3 = ''.join(line for line in n3.splitlines(keepends=True) if not line.strip().startswith('@prefix'))
    return n3


class EmptyTopicError(ValueError):
    """Raised when a source is configured with an empty mqtt topic."""


class MqttStatementSource:
    """One mqtt topic whose messages are turned into RDF statements.

    Constructing an instance subscribes to the topic. Each incoming message
    is counted, converted to a message graph, run through inference, and the
    inferred statements replace this source's subgraph in `masterGraph`.
    """

    def __init__(self, uri: URIRef, topic: bytes, masterGraph: PatchableGraph, mqtt, internalMqtt, debugPageData,
                 # influxExport: InfluxExporter,
                 inference: Inference):
        """Record the collaborators, register debug info, and subscribe.

        uri: the source's uri; also used as the context of its output quads.
        topic: mqtt topic to subscribe to.
        masterGraph: graph that receives this source's current statements.
        mqtt, internalMqtt: the two broker clients (see subscribeMqtt).
        debugPageData: dict of debug info; this source appends its own entry
            under 'subscribed' and bumps 'messagesSeen'.
        inference: used to derive the output statements from each message.

        Raises EmptyTopicError if `topic` is empty.
        """
        self.uri = uri

        self.masterGraph = masterGraph
        self.debugPageData = debugPageData
        self.mqtt = mqtt  # deprecated
        self.internalMqtt = internalMqtt
        # self.influxExport = influxExport
        self.inference = inference

        self.mqttTopic = topic
        if self.mqttTopic == b'':
            raise EmptyTopicError(f"empty topic for {uri=}")
        log.debug(f'new mqttTopic {self.mqttTopic}')

        self.debugSub = {
            'topic': self.mqttTopic.decode('ascii'),
            'recentMessageGraphs': [],
            'recentMetrics': [],
            'currentOutputGraph': {
                't': 1,
                'n3': "(n3)"
            },
        }
        # The caller may pass a dict that lacks the page-level keys (for
        # example an empty dict), so create the list on demand rather than
        # failing construction with KeyError.
        self.debugPageData.setdefault('subscribed', []).append(self.debugSub)

        rawBytes: Observable = self.subscribeMqtt(self.mqttTopic)
        rawBytes.subscribe(on_next=self.countIncomingMessage)

        rawBytes.subscribe(self.onMessage)

    def onMessage(self, raw: bytes):
        """Handle one raw payload: record it for debugging, infer, and publish."""
        g = graphFromMessage(self.uri, self.mqttTopic, raw)
        logGraph(log.debug, 'message graph', g)
        appendLimit(
            self.debugSub['recentMessageGraphs'],
            {  #
                't': truncTime(),
                'n3': serializeWithNs(g, hidePrefixes=True)
            })

        implied = self.inference.infer(g)
        self.updateMasterGraph(implied)

    def subscribeMqtt(self, topic: bytes):
        """Subscribe to `topic` on the broker that carries it and return the result."""
        # goal is to get everyone on the internal broker and eliminate this
        mqtt = self.internalMqtt if topic.startswith(b'frontdoorlock') else self.mqtt
        return mqtt.subscribe(topic)

    def countIncomingMessage(self, msg: bytes):
        """Bump the debug-page and prometheus message counters."""
        # Tolerate a debugPageData dict that was created without the counter.
        self.debugPageData['messagesSeen'] = self.debugPageData.get('messagesSeen', 0) + 1
        MESSAGES_SEEN.inc()

    def updateMasterGraph(self, newGraph):
        """Replace this source's subgraph in masterGraph with `newGraph`'s statements."""
        log.debug(f'{self.uri} update to {len(newGraph)} statements')

        cg = ConjunctiveGraph()
        for stmt in newGraph:
            # Each triple becomes a quad whose context is this source's uri.
            cg.add(stmt + (self.uri,))
            # meas = stmt[0].split('/')[-1]
            # if meas.startswith('airQuality'):
            #     where_prefix, type_ = meas[len('airQuality'):].split('door')
            #     where = where_prefix + 'door'
            #     metric = 'air'
            #     tags = {'loc': where.lower(), 'type': type_.lower()}
            #     val = stmt[2].toPython()
            #     if metric not in collectors:
            #         collectors[metric] = Gauge(metric, 'measurement', labelnames=tags.keys())

            #     collectors[metric].labels(**tags).set(val)

        self.masterGraph.patchSubgraph(self.uri, cg)
        self.debugSub['currentOutputGraph']['n3'] = serializeWithNs(cg, hidePrefixes=True)



# class DebugPageData(cyclone.sse.SSEHandler):

#     def __init__(self, application, request):
#         cyclone.sse.SSEHandler.__init__(self, application, request)
#         self.lastSent = None

#     def watch(self):
#         try:
#             dpd = self.settings.debugPageData
#             js = json.dumps(dpd, sort_keys=True)
#             if js != self.lastSent:
#                 log.debug('sending dpd update')
#                 self.sendEvent(message=js.encode('utf8'))
#                 self.lastSent = js
#         except Exception:
#             import traceback
#             traceback.print_exc()

#     def bind(self):
#         self.loop = task.LoopingCall(self.watch)
#         self.loop.start(1, now=True)

#     def unbind(self):
#         self.loop.stop()



class RunState:
    """this is rebuilt upon every config reload

    Loads conf/rules.n3, expands it with inference, publishes the expanded
    config, and creates one MqttStatementSource per configured source.
    """

    def __init__(self,
                 expandedConfigPatchableCopy: PatchableGraph,  # for output and display
                 masterGraph: PatchableGraph,  # current sensor outputs
                 mqtt: MqttClient,
                 internalMqtt: MqttClient,
                 #  influxExport: InfluxExporter,
                 inference: Inference,
                 debugPageData: Union[dict, None] = None):
        """Build the run state.

        debugPageData: optional dict shared by all sources for debug output.
            When omitted, a minimal one is created. Either way it is
            guaranteed to have the 'messagesSeen' and 'subscribed' entries
            that MqttStatementSource writes to.
        """
        loadedConfig = ConjunctiveGraph()
        loadedConfig.parse('conf/rules.n3', format='n3')

        inference.setRules(loadedConfig)
        self.expandedConfig = inference.infer(loadedConfig)
        self.expandedConfig += inference.nonRuleStatements()

        ecWithQuads = ConjunctiveGraph()
        for s, p, o in self.expandedConfig:
            ecWithQuads.add((s, p, o, URIRef('/config')))
        expandedConfigPatchableCopy.setToGraph(ecWithQuads)

        # MqttStatementSource indexes these keys directly. Handing it a bare {}
        # made its constructor raise KeyError, which the EmptyTopicError
        # handler below does not catch.
        if debugPageData is None:
            debugPageData = {}
        debugPageData.setdefault('messagesSeen', 0)
        debugPageData.setdefault('subscribed', [])
        self.debugPageData = debugPageData

        self.srcs = []
        srcs = cast(List[URIRef], list(self.expandedConfig.subjects(RDF.type, ROOM['MqttStatementSource'])))
        srcs.sort(key=str)
        for src in srcs:
            log.info(f'setup source {src=}')
            try:
                self.srcs.append(
                    MqttStatementSource(src, self.topicFromConfig(self.expandedConfig, src),
                                        masterGraph, mqtt=mqtt, internalMqtt=internalMqtt,
                                        debugPageData=self.debugPageData,
                                        # influxExport=influxExport,
                                        inference=inference))
            except EmptyTopicError:
                # Sources without a topic are skipped.
                continue
        log.info(f'set up {len(self.srcs)} sources')

    def topicFromConfig(self, config, src) -> bytes:
        """Join the parts of the source's :mqttTopic list with '/' into a bytes topic."""
        topicParts = list(config.items(config.value(src, ROOM['mqttTopic'])))
        return b'/'.join(t.encode('ascii') for t in topicParts)

def main():
    """Set up logging, the mqtt clients, and the run state; return the Starlette app."""
    for loggerName in ('mqtt', 'mqtt_client', 'infer', 'cbind'):
        logging.getLogger(loggerName).setLevel(logging.DEBUG)
    log.setLevel(logging.DEBUG)
    log.info('log start')

    masterGraph = PatchableGraph()
    inference = Inference()

    brokerHost = 'mosquitto-frontdoor.default.svc.cluster.local'
    brokerPort = 10210

    mqtt = MqttClient(
        clientId='mqtt_to_rdf',
        brokerHost='mosquitto-ext.default.svc.cluster.local',
        brokerPort=1883,
    )  # deprecated
    internalMqtt = MqttClient(clientId='mqtt_to_rdf', brokerHost=brokerHost, brokerPort=brokerPort)

    # Built here but not passed to anything below.
    debugPageData = {
        # schema in index.ts
        'server': f'{brokerHost}:{brokerPort}',
        'messagesSeen': 0,
        'subscribed': [],
        "rules": "",
        "rulesInferred": "",
    }

    expandedConfigPatchableCopy = PatchableGraph()

    runState = RunState(expandedConfigPatchableCopy, masterGraph, mqtt, internalMqtt, inference)

    routes = [
        Route("/", StaticFiles(directory='.'), name='index.html'),
        #   Route("/build/(bundle.js)", cyclone.web.StaticFileHandler, {"path": "build"}),
        Route("/graph/config", StaticGraph(expandedConfigPatchableCopy)),
        Route("/graph/mqtt", StaticGraph(masterGraph)),
        Route("/graph/mqtt/events", GraphEvents(masterGraph)),
        #   Route('/debugPageData', DebugPageData),
    ]
    app = Starlette(routes=routes)

    app.add_middleware(PrometheusMiddleware, app_name='environment')
    app.add_route("/metrics", handle_metrics)
    return app
# The app is built when this module is imported.
app = main()