Review notes (free text before the first "diff --git" line; git apply skips it):
- This patch had been collapsed onto a few long lines. Line breaks and 4-space
  Python indentation were reconstructed from syntax and from the hunk line
  counts. Verify the context lines (especially the test file's import block
  and the turtle indentation) against the base files before applying.
- The "@prefix : ." context line appears to have lost an <...> IRI during
  that collapse. Restore it from the base file.
- typedgraph.py changes relative to the collapsed patch:
  - _expandUnion detects unions via get_origin instead of any __args__, so
    generics such as Dict[str, int] are no longer mistaken for unions.
  - Dropped the debug print left in _typeIncludes.
  - Rewrote the _typeIncludes docstring: it no longer behaves like issubclass.
  - _convLiteral now accepts xsd:float when a float is requested.
  - The ConversionError re-raise in typedValue keeps its cause ("from err").
  - Removed commented-out code that referred to the no-longer-imported decimal
    module.

diff --git a/light9/typedgraph.py b/light9/typedgraph.py
--- a/light9/typedgraph.py
+++ b/light9/typedgraph.py
@@ -1,9 +1,8 @@
-import decimal
 from types import UnionType
-from typing import Type, TypeVar, cast, get_args
+from typing import List, Type, TypeVar, Union, cast, get_args, get_origin
 
 from rdfdb.syncedgraph.syncedgraph import SyncedGraph
-from rdflib import Graph
+from rdflib import XSD, Graph, Literal, URIRef
 from rdflib.term import Node
 
 # todo: this ought to just require a suitable graph.value method
@@ -16,32 +15,80 @@ class ConversionError(ValueError):
     """graph had a value, but it does not safely convert to any of the requested types"""
 
 
+def _expandUnion(t: Type) -> List[Type]:
+    """Members of a union type (X | Y, Union[...], Optional[...]); otherwise [t]."""
+    if get_origin(t) in (Union, UnionType):
+        return list(get_args(t))
+    return [t]
+
+
 def _typeIncludes(t1: Type, t2: Type) -> bool:
-    """same as issubclass but t1 can be a NewType"""
-    # if hasattr(t1, '__supertype__'):
-    #     t1 = t1.__supertype__
-    print(f'{isinstance(t1, UnionType)=}')
-    if isinstance(t1, UnionType):
-        print(f" i see {t1} is union")
-        return any(_typeIncludes(t, t2) for t in get_args(t1))
-    # print('iss', t1, t2, isinstance(t1,t2))
+    """True if t1 equals t2, is a NewType declared directly over t2, or is a
+    union with a member that includes t2. Pass t2=None to mean NoneType.
+
+    Unlike issubclass, subclasses do not count: the match must be exact.
+    """
+    if t2 is None:
+        t2 = type(None)
+    if t1 == t2:
+        return True
+
+    if getattr(t1, '__supertype__', None) == t2:
+        return True
+
+    ts = _expandUnion(t1)
+    if len(ts) > 1:
+        return any(_typeIncludes(t, t2) for t in ts)
-    # if t1 is float:
-    #     return float in get_args(t2)
-    return issubclass(t1, t2)
+    return False
 
 
-def typedValue(objType: Type[_ObjType], graph: EitherGraph, subj: Node, pred: Node) -> _ObjType:
+def _convLiteral(objType: Type[_ObjType], x: Literal) -> _ObjType:
+    """Convert x to the first python type in objType that accepts it.
+
+    Returns x unchanged if objType includes Literal. Otherwise float, int and
+    str are tried in that order against each member of objType; float takes
+    xsd:integer/double/decimal/float, int takes only xsd:integer, and str
+    takes any datatype. Raises ConversionError if nothing fits.
+    """
+    if _typeIncludes(objType, Literal):
+        return cast(objType, x)
+
+    for outType, dtypes in [
+        (float, (XSD['integer'], XSD['double'], XSD['decimal'], XSD['float'])),
+        (int, (XSD['integer'],)),
+        (str, ()),  # empty tuple: any datatype is acceptable
+    ]:
+        for t in _expandUnion(objType):
+            if _typeIncludes(t, outType) and (not dtypes or x.datatype in dtypes):
+                # e.g. user wants float and we have xsd:double
+                return cast(objType, outType(x.toPython()))
+    raise ConversionError
+
+
+def typedValue(objType: Type[_ObjType], graph: EitherGraph, subj: Node, pred: URIRef) -> _ObjType:
     """graph.value(subj, pred) with a given return type.
     If objType is not an rdflib.Node, we toPython() the value.
 
     Allow objType to include None if you want a None return for not-found.
     """
+    if objType is None:
+        raise TypeError('must allow non-None result type')
     obj = graph.value(subj, pred)
     if obj is None:
-        if type(None) in get_args(objType):
-            return None
+        if _typeIncludes(objType, None):
+            return cast(objType, None)
         raise ValueError(f'No obj for {subj=} {pred=}')
-    conv = obj #if _typeIncludes(objType, Node) else obj.toPython()
+
+    ConvFrom: Type[Node] = type(obj)
+    ConvTo = objType
+    try:
+        if ConvFrom == URIRef and _typeIncludes(ConvTo, URIRef):
+            conv = obj
+        elif ConvFrom == Literal:
+            conv = _convLiteral(objType, cast(Literal, obj))
+        else:
+            # e.g. BNode is not handled yet
+            raise ConversionError
+    except ConversionError as err:
+        raise ConversionError(f'graph contains {type(obj)}, caller requesting {objType}') from err
-    # if objType is float and isinstance(conv, decimal.Decimal):
-    #     conv = float(conv)
     return cast(objType, conv)
\ No newline at end of file
diff --git a/light9/typedgraph_test.py b/light9/typedgraph_test.py
--- a/light9/typedgraph_test.py
+++ b/light9/typedgraph_test.py
@@ -4,8 +4,8 @@ import pytest
 from rdflib import Graph, Literal, URIRef
 
 from light9.mock_syncedgraph import MockSyncedGraph
-from light9.namespaces import L9
-from light9.typedgraph import ConversionError, typedValue
+from light9.namespaces import L9, XSD
+from light9.typedgraph import ConversionError, _typeIncludes, typedValue
 
 g = cast(
     Graph,
@@ -13,10 +13,11 @@ g = cast(
 @prefix : .
 
 :subj :uri :c;
-    # see https://w3c.github.io/N3/spec/#literals for syntaxes
-    :float1 0;
-    :float2 0.0;
-    :float3 1.0e1;
+    # see https://w3c.github.io/N3/spec/#literals for syntaxes.
+    :int 0;
+    :float1 0.0;
+    :float2 1.0e1;
+    :float3 0.5;
     :color "#ffffff"^^:hexColor;
     :definitelyAString "hello" .
 '''))
@@ -24,15 +25,44 @@ g = cast(
 subj = L9['subj']
 
 
+class TestTypeIncludes:
+
+    def test_includesItself(self):
+        assert _typeIncludes(str, str)
+
+    def test_includesUnionMember(self):
+        assert _typeIncludes(int | str, str)
+
+    def test_notIncludes(self):
+        assert not _typeIncludes(int | str, None)
+
+    def test_explicitOptionalWorks(self):
+        assert _typeIncludes(Optional[int], None)
+
+    def test_3WayUnionWorks(self):
+        assert _typeIncludes(int | str | float, int)
+
+
 class TestTypedValueReturnsBasicTypes:
 
     def test_getsUri(self):
         assert typedValue(URIRef, g, subj, L9['uri']) == L9['c']
 
-    def test_getsFloats(self):
+    def test_getsNumerics(self):
+        assert typedValue(float, g, subj, L9['int']) == 0
         assert typedValue(float, g, subj, L9['float1']) == 0
-        assert typedValue(float, g, subj, L9['float2']) == 0
-        assert typedValue(float, g, subj, L9['float3']) == 10
+        assert typedValue(float, g, subj, L9['float2']) == 10
+        assert typedValue(float, g, subj, L9['float3']) == 0.5
+
+        assert typedValue(int, g, subj, L9['int']) == 0
+        # These retrieve rdf floats that happen to equal
+        # ints, but no one should be relying on that.
+        with pytest.raises(ConversionError):
+            typedValue(int, g, subj, L9['float1'])
+        with pytest.raises(ConversionError):
+            typedValue(int, g, subj, L9['float2'])
+        with pytest.raises(ConversionError):
+            typedValue(int, g, subj, L9['float3'])
 
     def test_getsString(self):
         tv = typedValue(str, g, subj, L9['color'])
@@ -41,7 +71,7 @@ class TestTypedValueReturnsBasicTypes:
     def test_getsLiteral(self):
         tv = typedValue(Literal, g, subj, L9['float2'])
         assert type(tv) == Literal
-        assert tv.datatype == 'todo'
+        assert tv.datatype == XSD['double']
 
         tv = typedValue(Literal, g, subj, L9['color'])
         assert type(tv) == Literal
@@ -84,15 +114,11 @@ class TestTypedValueConvertsToNewTypes:
 
     def test_castsUri(self):
         DeviceUri = NewType('DeviceUri', URIRef)
-        tv = typedValue(DeviceUri, g, subj, L9['uri'])
-        assert type(tv) == DeviceUri
-        assert tv == DeviceUri(L9['c'])
+        assert typedValue(DeviceUri, g, subj, L9['uri']) == DeviceUri(L9['c'])
 
     def test_castsLiteralToNewType(self):
         HexColor = NewType('HexColor', str)
-        tv = typedValue(HexColor, g, subj, L9['color'])
-        assert type(tv) == HexColor
-        assert tv == HexColor('#ffffff')
+        assert typedValue(HexColor, g, subj, L9['color']) == HexColor('#ffffff')
 
 
 class TestTypedValueAcceptsUnionTypes:
@@ -111,7 +137,7 @@ class TestTypedValueAcceptsUnionTypes:
             typedValue(float | URIRef, g, subj, L9['color'])
 
     def test_combinesWithNone(self):
-        assert typedValue(float | str | None, g, subj, L9['uri']) == L9['c']
+        assert typedValue(float | URIRef | None, g, subj, L9['uri']) == L9['c']
 
     def test_combinedWithNewType(self):
         HexColor = NewType('HexColor', str)