Files @ 8fc5da221688
Branch filter:

Location: light9/light9/effect/settings.py

drewp@bigasterisk.com
checkpoint show data
"""
Data structure and convertors for a table of (device,attr,value)
rows. These might be effect attrs ('strength'), device attrs ('rx'),
or output attrs (dmx channel).

BareSettings means (attr,value), no device.
"""
from __future__ import annotations

import decimal
import hashlib
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Sequence, Set, Tuple, cast

import numpy
from rdfdb.syncedgraph.syncedgraph import SyncedGraph
from rdflib import Literal, URIRef

from light9.collector.device import resolve
from light9.localsyncedgraph import LocalSyncedGraph
from light9.namespaces import L9, RDF
from light9.newtypes import (DeviceAttr, DeviceUri, EffectAttr, HexColor, VTUnion)

# Module-wide logger; records are emitted under the logger name 'settings'.
log = logging.getLogger('settings')


def parseHex(h: str) -> List[int]:
    """
    Parse a '#rrggbb' color string into a list of three ints, [r, g, b].

    Raises ValueError if h does not start with '#' (this includes the
    empty string) or if any of the three 2-character fields is not
    valid base-16 text. Characters after the sixth hex digit are ignored.
    """
    # startswith() instead of h[0] so that '' is reported as a ValueError
    # like every other malformed input, not as an IndexError.
    if not h.startswith('#'):
        raise ValueError(h)
    return [int(h[i:i + 2], 16) for i in (1, 3, 5)]


def parseHexNorm(h):
    """
    Like parseHex, but each channel is divided by 255, so '#ff0000'
    gives [1.0, 0.0, 0.0]. Errors from parseHex propagate unchanged.
    """
    channels = parseHex(h)
    return [channel / 255 for channel in channels]


def toHex(rgbFloat: Sequence[float]) -> HexColor:
    """
    Format three float channels as a '#rrggbb' HexColor.

    Each channel is multiplied by 255, truncated with int(), and then
    clamped into 0..255, so inputs below 0 or above 1 saturate rather
    than fail.
    """
    assert len(rgbFloat) == 3
    channels = []
    for v in rgbFloat:
        byte = int(v * 255)  # truncates toward zero; does not round
        channels.append(max(0, min(255, byte)))
    return HexColor('#%02x%02x%02x' % tuple(channels))


def getVal(graph, subj):
    """
    Return the python value of subj's L9['value'] object, or of its
    L9['scaledValue'] object when there is no L9['value'].

    Decimal results are converted to float; anything else is returned as
    lit.toPython() gives it.

    If neither predicate yields an object, the lookup result has no
    toPython() and the call fails (AttributeError when graph.value
    returns None for a miss, which is rdflib's default).
    """
    # Compare against None explicitly. rdflib Literals wrapping 0, 0.0,
    # False or '' are falsy, so `value or scaledValue` would wrongly skip
    # a stored zero.
    lit = graph.value(subj, L9['value'])
    if lit is None:
        lit = graph.value(subj, L9['scaledValue'])
    ret = lit.toPython()
    if isinstance(ret, decimal.Decimal):
        ret = float(ret)
    return ret


# Annotation alias for the `graph` arguments in this module: either graph
# class is accepted. This is a runtime expression, so `|` between two
# classes needs Python 3.10+ (unless the classes define it themselves).
GraphType = SyncedGraph | LocalSyncedGraph


class _Settings:
    """
    Generic for DeviceUri/DeviceAttr/VTUnion or EffectClass/EffectAttr/VTUnion

    default values are 0 or '#000000'. Internal rep must not store zeros or some
    comparisons will break.

    The table is kept in self._compiled as {entity: {attr: value}}. Every
    constructor finishes with _delZeros() so that zero-valued rows and
    entities with no remaining rows are absent.
    """
    EntityType = DeviceUri
    AttrType = DeviceAttr

    def __init__(self, graph: GraphType, settingsList: List[Tuple[Any, Any, VTUnion]]):
        """
        settingsList is (entity, attr, value) rows. Repeated (entity, attr)
        pairs are combined with resolve().
        """
        self.graph = graph  # for looking up all possible attrs
        self._compiled: Dict[Any, Dict[Any, VTUnion]] = {}
        for e, a, v in settingsList:
            self._mergeRow(e, a, v)
        # self._compiled may not be final yet- see _fromCompiled
        self._delZeros()

    def _mergeRow(self, entity, attr, value: VTUnion) -> None:
        """Store one row, combining with any existing value for (entity, attr) via resolve()."""
        attrVals = self._compiled.setdefault(entity, {})
        if attr in attrVals:
            # Hey, resolve()'s first arg is supposed to be DeviceClass (which
            # is not convenient for us), but so far resolve() doesn't use
            # that arg
            value = resolve(entity, attr, [attrVals[attr], value])
        attrVals[attr] = value

    @classmethod
    def _fromCompiled(cls, graph: GraphType, compiled: Dict[EntityType, Dict[AttrType, VTUnion]]):
        """
        Build an instance directly from an {entity: {attr: value}} dict.

        The outer and inner dicts are copied, so the zero-stripping done
        here never edits the caller's data and the new object does not
        share row dicts with whatever `compiled` was taken from.
        """
        obj = cls(graph, [])
        obj._compiled = {entity: dict(attrVals) for entity, attrVals in compiled.items()}
        obj._delZeros()
        return obj

    @classmethod
    def fromList(cls, graph: GraphType, others: List[_Settings]):
        """
        note that others may have multiple values for an attr

        Rows of all `others` are merged in order; collisions go through
        resolve(). Unlike fromBlend, members of `others` are not
        type-checked against cls. Raises TypeError on a row whose entity
        is None.
        """
        out = cls(graph, [])
        for s in others:
            for row in s.asList():  # could work straight from s._compiled
                if row[0] is None:
                    raise TypeError('bad row %r' % (row,))
                dev, devAttr, value = row
                out._mergeRow(dev, devAttr, value)
        out._delZeros()
        return out

    @classmethod
    def _mult(cls, weight, row, dd) -> VTUnion:
        """
        Return dd's current value for row's attr plus weight * row's value.

        row is (entity, attr, value) and dd is the {attr: value} dict being
        accumulated. str values are treated as hex colors: both sides are
        normalized, summed per channel, and re-encoded with toHex (which
        clamps). Other values are added numerically, with 0 as the
        missing-value default.
        """
        attr, value = row[1], row[2]
        if isinstance(value, str):
            prev = parseHexNorm(dd.get(attr, '#000000'))
            # list + ndarray broadcasts to an elementwise sum
            return toHex(prev + weight * numpy.array(parseHexNorm(value)))
        return dd.get(attr, 0) + weight * value

    @classmethod
    def fromBlend(cls, graph: GraphType, others: List[Tuple[float, _Settings]]):
        """
        others is a list of (weight, Settings) pairs

        The result is the weighted sum of all rows. Raises TypeError if a
        Settings is not an instance of cls or if a row's entity is None.
        """
        out = cls(graph, [])
        for weight, s in others:
            if not isinstance(s, cls):
                raise TypeError(s)
            for row in s.asList():  # could work straight from s._compiled
                if row[0] is None:
                    raise TypeError('bad row %r' % (row,))
                dd = out._compiled.setdefault(row[0], {})
                dd[row[1]] = cls._mult(weight, row, dd)
        out._delZeros()
        return out

    def _zeroForAttr(self, attr: AttrType) -> VTUnion:
        """The value that counts as 'unset' for attr: black for L9['color'], else 0.0."""
        if attr == L9['color']:
            return HexColor('#000000')
        return 0.0

    def _delZeros(self):
        """Drop rows whose value equals the attr's zero, then drop entities left with no rows."""
        # list() snapshots so we can delete while walking
        for dev, av in list(self._compiled.items()):
            for attr, val in list(av.items()):
                if val == self._zeroForAttr(attr):
                    del av[attr]
            if not av:
                del self._compiled[dev]

    def __hash__(self):
        """Hash of the sorted (entity, ((attr, value), ...)) rows, so it is independent of insertion order."""
        itemed = tuple((d, tuple(sorted(av.items()))) for d, av in sorted(self._compiled.items()))
        return hash(itemed)

    def __eq__(self, other):
        """
        Equal when the compiled tables are equal. Comparing against
        something that isn't an instance of (a subclass of) our class is
        treated as a programming error and raises TypeError.
        """
        if not issubclass(other.__class__, self.__class__):
            raise TypeError("can't compare %r to %r" % (self.__class__, other.__class__))
        return self._compiled == other._compiled

    def __ne__(self, other):
        return not self == other

    def __bool__(self):
        """False when there are no (nonzero) settings."""
        return bool(self._compiled)

    def __repr__(self):
        """
        '<ClassName dev.attr=val ...>' using the last path segment of each
        URI. At most 6 rows are shown; '...' is appended only when more
        rows exist than were shown.
        """
        maxShown = 6

        def rowWords():
            for dev, av in self._compiled.items():
                for attr, val in sorted(av.items()):
                    yield '%s.%s=%s' % (dev.rsplit('/')[-1], attr.rsplit('/')[-1], val)

        words = []
        for word in rowWords():
            if len(words) == maxShown:
                # there really is a 7th row, so the ellipsis is truthful
                words.append('...')
                break
            words.append(word)

        if not words:
            words = ['(no settings)']
        return '<%s %s>' % (self.__class__.__name__, ' '.join(words))

    def getValue(self, dev: EntityType, attr: AttrType, defaultToZero=True):
        """
        Return the stored value for (dev, attr).

        When it's missing: return the attr's zero value if defaultToZero,
        otherwise raise KeyError.
        """
        attrVals = self._compiled.get(dev, {})
        if defaultToZero:
            return attrVals.get(attr, self._zeroForAttr(attr))
        return attrVals[attr]

    def _vectorKeys(self, deviceAttrFilter=None):
        """stable order of all the dev,attr pairs for this type of settings"""
        raise NotImplementedError

    def asList(self) -> List[Tuple[EntityType, AttrType, VTUnion]]:
        """old style list of (dev, attr, val) tuples"""
        return [(dev, attr, val) for dev, av in self._compiled.items() for attr, val in av.items()]

    def devices(self) -> List[EntityType]:
        """Entities that have at least one stored setting."""
        return list(self._compiled)

    def toVector(self, deviceAttrFilter=None) -> List[float]:
        """
        Flatten to a list of floats in _vectorKeys order. L9['color'] attrs
        contribute 3 normalized channels; every other attr contributes one
        value, which must be a float (TypeError otherwise). Unset attrs
        contribute their zero value.
        """
        out: List[float] = []
        for dev, attr in self._vectorKeys(deviceAttrFilter):
            v = self.getValue(dev, attr)
            if attr == L9['color']:
                out.extend(parseHexNorm(v))
            else:
                if not isinstance(v, float):
                    raise TypeError(f'{attr=} value was {v=}')
                out.append(v)
        return out

    def byDevice(self) -> Iterable[Tuple[EntityType, _Settings]]:
        """Yield (dev, settings-for-just-that-dev) for each entity with settings."""
        for dev, av in self._compiled.items():
            yield dev, self.__class__._fromCompiled(self.graph, {dev: av})

    def ofDevice(self, dev: EntityType) -> _Settings:
        """Settings object holding only dev's rows (empty if dev has none)."""
        return self.__class__._fromCompiled(self.graph, {dev: self._compiled.get(dev, {})})

    def distanceTo(self, other):
        """Euclidean norm of the difference between the two toVector() results; also logged at info level."""
        diff = numpy.array(self.toVector()) - other.toVector()
        d = numpy.linalg.norm(diff, ord=None)
        log.info('distanceTo %r - %r = %g', self, other, d)
        return d

    @staticmethod
    def _settingHash(dev, attr, val) -> int:
        """
        Number in 0..9999998 identifying a (dev, attr, val) row.

        Derived from a sha256 digest of the row's text so the same row
        gets the same number in every process. The built-in hash() of a
        str is salted per interpreter run, which would defeat the
        'repeated settings converge' goal of statements().
        """
        key = '\n'.join((str(dev), str(attr), str(val)))
        digest = hashlib.sha256(key.encode('utf8')).digest()
        return int.from_bytes(digest[:8], 'big') % 9999999

    def statements(self, subj: EntityType, ctx: URIRef, settingRoot: URIRef, settingsSubgraphCache: Set):
        """
        settingRoot can be shared across images (or even wider if you want)

        Returns a list of (s, p, o, ctx) quads. For every row there is a
        (subj, L9['setting'], settingNode) quad. The three quads describing
        settingNode itself (device, deviceAttr, value/scaledValue) are only
        emitted when settingNode is not already in settingsSubgraphCache,
        and the cache is updated in place with each node emitted.
        """
        # ported from live.coffee
        scaledAttributeTypes = (L9['color'], L9['brightness'], L9['uv'])
        add = []
        for dev, attr, val in self.asList():
            # hopefully a unique number for the setting so repeated settings converge
            settingHash = self._settingHash(dev, attr, val)
            setting = URIRef('%sset%s' % (settingRoot, settingHash))
            add.append((subj, L9['setting'], setting, ctx))
            if setting in settingsSubgraphCache:
                continue

            settingType = L9['scaledValue'] if attr in scaledAttributeTypes else L9['value']
            if not isinstance(val, URIRef):
                val = Literal(val)
            add.extend([
                (setting, L9['device'], dev, ctx),
                (setting, L9['deviceAttr'], attr, ctx),
                (setting, settingType, val, ctx),
            ])
            settingsSubgraphCache.add(setting)

        return add


class DeviceSettings(_Settings):
    """
    Settings keyed by (device, deviceAttr), with the graph-backed pieces
    that the base class leaves out: vector key ordering and constructors
    from a graph resource or a flat vector.
    """
    EntityType = DeviceUri
    AttrType = DeviceAttr

    def _vectorKeys(self, deviceAttrFilter=None):
        """
        Sorted (dev, attr) pairs for every device of every L9['DeviceClass']
        in the graph, using the attrs listed on the device's class. If
        deviceAttrFilter is given (and non-empty), only pairs contained in
        it are kept.
        """
        with self.graph.currentState() as g:
            # (devclass, dev) pairs
            classAndDev = {(dc, dev) for dc in g.subjects(RDF.type, L9['DeviceClass']) for dev in g.subjects(RDF.type, dc)}

            keys = []
            for dc, dev in sorted(classAndDev):
                for attr in sorted(g.objects(dc, L9['deviceAttr'])):
                    key = (dev, attr)
                    if not deviceAttrFilter or key in deviceAttrFilter:
                        keys.append(key)
        return keys

    @classmethod
    def fromResource(cls, graph: GraphType, subj: EntityType):
        """Build from the L9['setting'] nodes attached to subj in the graph."""
        with graph.currentState() as g:
            settingsList = [(g.value(s, L9['device']), g.value(s, L9['deviceAttr']), getVal(g, s)) for s in g.objects(subj, L9['setting'])]
        return cls(graph, settingsList)

    @classmethod
    def fromVector(cls, graph, vector, deviceAttrFilter=None):
        """
        Inverse of toVector: walk _vectorKeys in order, consuming 3 vector
        elements for an L9['color'] attr (re-encoded with toHex) and 1
        element for anything else.
        """
        compiled: Dict[DeviceSettings.EntityType, Dict[DeviceSettings.AttrType, VTUnion]] = {}
        pos = 0
        for d, a in cls(graph, [])._vectorKeys(deviceAttrFilter):
            isColor = a == L9['color']
            v = toHex(vector[pos:pos + 3]) if isColor else vector[pos]
            pos += 3 if isColor else 1
            compiled.setdefault(d, {})[a] = v
        return cls._fromCompiled(graph, compiled)

    @classmethod
    def merge(cls, graph: SyncedGraph, others: List[DeviceSettings]) -> DeviceSettings:
        """Typed wrapper around fromList for lists of DeviceSettings."""
        asBase = cast(List[_Settings], others)
        return cls.fromList(graph, asBase)


@dataclass
class BareEffectSettings:
    # settings for an already-selected EffectClass
    s: Dict[EffectAttr, VTUnion]

    def withStrength(self, strength: float) -> BareEffectSettings:
        """
        Return a new BareEffectSettings whose L9['strength'] entry is
        `strength`; self.s is copied, not modified.
        """
        strengthAttr = EffectAttr(L9['strength'])
        updated = self.s.copy()
        updated[strengthAttr] = strength
        return BareEffectSettings(s=updated)


class EffectSettings(_Settings):
    """
    _Settings subclass that adds nothing of its own: EntityType, AttrType
    and every method are inherited as-is.

    NOTE(review): as written, EntityType/AttrType are still the device
    types assigned on _Settings, and _vectorKeys is the base version that
    raises NotImplementedError (so toVector and distanceTo will raise
    too). Confirm whether effect-specific overrides are intended here.
    """
    pass