changeset 2197:698858173947

rewrite typedValue to support 'T | None' and to test returned graph node types better
author drewp@bigasterisk.com
date Mon, 22 May 2023 00:54:22 -0700
parents 5ee5e17a3fd3
children ae38d21e6f6b
files light9/typedgraph.py light9/typedgraph_test.py
diffstat 2 files changed, 99 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/light9/typedgraph.py	Sun May 21 17:00:25 2023 -0700
+++ b/light9/typedgraph.py	Mon May 22 00:54:22 2023 -0700
@@ -1,9 +1,7 @@
-import decimal
-from types import UnionType
-from typing import Type, TypeVar, cast, get_args
+from typing import List, Type, TypeVar, cast, get_args
 
 from rdfdb.syncedgraph.syncedgraph import SyncedGraph
-from rdflib import Graph
+from rdflib import XSD, Graph, Literal, URIRef
 from rdflib.term import Node
 
 # todo: this ought to just require a suitable graph.value method
@@ -16,32 +14,74 @@
     """graph had a value, but it does not safely convert to any of the requested types"""
 
 
+def _expandUnion(t: Type) -> List[Type]:
+    if hasattr(t, '__args__'):
+        return list(get_args(t))
+    return [t]
+
+
 def _typeIncludes(t1: Type, t2: Type) -> bool:
     """same as issubclass but t1 can be a NewType"""
-    # if hasattr(t1, '__supertype__'):
-    #     t1 = t1.__supertype__
-    print(f'{isinstance(t1,  UnionType)=}')
-    if isinstance(t1, UnionType):
-        print(f" i see {t1} is union")
-        return any(_typeIncludes(t, t2) for t in get_args(t1))
-    # print('iss', t1, t2,     isinstance(t1,t2))
+    if t2 is None:
+        t2 = type(None)
+    if t1 == t2:
+        return True
+
+    if getattr(t1, '__supertype__', None) == t2:
+        return True
+
+    ts = _expandUnion(t1)
+    if len(ts) > 1:
+        return any(_typeIncludes(t, t2) for t in ts)
     # if t1 is float:
     #     return float in get_args(t2)
-    return issubclass(t1, t2)
+    print(f'down to {t1} {t2}')
+
+    return False
 
 
-def typedValue(objType: Type[_ObjType], graph: EitherGraph, subj: Node, pred: Node) -> _ObjType:
+def _convLiteral(objType: Type[_ObjType], x: Literal) -> _ObjType:
+    if _typeIncludes(objType, Literal):
+        return cast(objType, x)
+
+    for outType, dtypes in [
+        (float, (XSD['integer'], XSD['double'], XSD['decimal'])),
+        (int, (XSD['integer'],)),
+        (str, ()),
+    ]:
+        for t in _expandUnion(objType):
+            if _typeIncludes(t, outType) and (not dtypes or x.datatype in dtypes):
+                # e.g. user wants float and we have xsd:double
+                return cast(objType, outType(x.toPython()))
+    raise ConversionError
+
+
+def typedValue(objType: Type[_ObjType], graph: EitherGraph, subj: Node, pred: URIRef) -> _ObjType:
     """graph.value(subj, pred) with a given return type.
     If objType is not an rdflib.Node, we toPython() the value.
 
     Allow objType to include None if you want a None return for not-found.
     """
+    if objType is None:
+        raise TypeError('must allow non-None result type')
     obj = graph.value(subj, pred)
     if obj is None:
-        if type(None) in get_args(objType):
-            return None
+        if _typeIncludes(objType, None):
+            return cast(objType, None)
         raise ValueError(f'No obj for {subj=} {pred=}')
-    conv = obj  #if _typeIncludes(objType, Node) else obj.toPython()
+
+    ConvFrom: Type[Node] = type(obj)
+    ConvTo = objType
+    try:
+        if ConvFrom == URIRef and _typeIncludes(ConvTo, URIRef):
+            conv = obj
+        elif ConvFrom == Literal:
+            conv = _convLiteral(objType, cast(Literal, obj))
+        else:
+            # e.g. BNode is not handled yet
+            raise ConversionError
+    except ConversionError:
+        raise ConversionError(f'graph contains {type(obj)}, caller requesting {objType}')
     # if objType is float and isinstance(conv, decimal.Decimal):
     #     conv = float(conv)
     return cast(objType, conv)
\ No newline at end of file
--- a/light9/typedgraph_test.py	Sun May 21 17:00:25 2023 -0700
+++ b/light9/typedgraph_test.py	Mon May 22 00:54:22 2023 -0700
@@ -4,8 +4,8 @@
 from rdflib import Graph, Literal, URIRef
 
 from light9.mock_syncedgraph import MockSyncedGraph
-from light9.namespaces import L9
-from light9.typedgraph import ConversionError, typedValue
+from light9.namespaces import L9, XSD
+from light9.typedgraph import ConversionError, _typeIncludes, typedValue
 
 g = cast(
     Graph,
@@ -13,10 +13,11 @@
     @prefix : <http://light9.bigasterisk.com/> .
     :subj
         :uri :c;
-        # see https://w3c.github.io/N3/spec/#literals for syntaxes
-        :float1 0;
-        :float2 0.0;
-        :float3 1.0e1;
+        # see https://w3c.github.io/N3/spec/#literals for syntaxes.
+        :int 0;
+        :float1 0.0;
+        :float2 1.0e1;
+        :float3 0.5;
         :color "#ffffff"^^:hexColor;
         :definitelyAString "hello" .
 '''))
@@ -24,15 +25,44 @@
 subj = L9['subj']
 
 
+class TestTypeIncludes:
+
+    def test_includesItself(self):
+        assert _typeIncludes(str, str)
+
+    def test_includesUnionMember(self):
+        assert _typeIncludes(int | str, str)
+
+    def test_notIncludes(self):
+        assert not _typeIncludes(int | str, None)
+
+    def test_explicitOptionalWorks(self):
+        assert _typeIncludes(Optional[int], None)
+
+    def test_3WayUnionWorks(self):
+        assert _typeIncludes(int | str | float, int)
+
+
 class TestTypedValueReturnsBasicTypes:
 
     def test_getsUri(self):
         assert typedValue(URIRef, g, subj, L9['uri']) == L9['c']
 
-    def test_getsFloats(self):
+    def test_getsNumerics(self):
+        assert typedValue(float, g, subj, L9['int']) == 0
         assert typedValue(float, g, subj, L9['float1']) == 0
-        assert typedValue(float, g, subj, L9['float2']) == 0
-        assert typedValue(float, g, subj, L9['float3']) == 10
+        assert typedValue(float, g, subj, L9['float2']) == 10
+        assert typedValue(float, g, subj, L9['float3']) == 0.5
+
+        assert typedValue(int, g, subj, L9['int']) == 0
+        # These retrieve rdf floats that happen to equal
+        # ints, but no one should be relying on that.
+        with pytest.raises(ConversionError):
+            typedValue(int, g, subj, L9['float1'])
+        with pytest.raises(ConversionError):
+            typedValue(int, g, subj, L9['float2'])
+        with pytest.raises(ConversionError):
+            typedValue(int, g, subj, L9['float3'])
 
     def test_getsString(self):
         tv = typedValue(str, g, subj, L9['color'])
@@ -41,7 +71,7 @@
     def test_getsLiteral(self):
         tv = typedValue(Literal, g, subj, L9['float2'])
         assert type(tv) == Literal
-        assert tv.datatype == 'todo'
+        assert tv.datatype == XSD['double']
 
         tv = typedValue(Literal, g, subj, L9['color'])
         assert type(tv) == Literal
@@ -84,15 +114,11 @@
 
     def test_castsUri(self):
         DeviceUri = NewType('DeviceUri', URIRef)
-        tv = typedValue(DeviceUri, g, subj, L9['uri'])
-        assert type(tv) == DeviceUri
-        assert tv == DeviceUri(L9['c'])
+        assert typedValue(DeviceUri, g, subj, L9['uri']) == DeviceUri(L9['c'])
 
     def test_castsLiteralToNewType(self):
         HexColor = NewType('HexColor', str)
-        tv = typedValue(HexColor, g, subj, L9['color'])
-        assert type(tv) == HexColor
-        assert tv == HexColor('#ffffff')
+        assert typedValue(HexColor, g, subj, L9['color']) == HexColor('#ffffff')
 
 
 class TestTypedValueAcceptsUnionTypes:
@@ -111,7 +137,7 @@
             typedValue(float | URIRef, g, subj, L9['color'])
 
     def test_combinesWithNone(self):
-        assert typedValue(float | str | None, g, subj, L9['uri']) == L9['c']
+        assert typedValue(float | URIRef | None, g, subj, L9['uri']) == L9['c']
 
     def test_combinedWithNewType(self):
         HexColor = NewType('HexColor', str)