Skip to content

Commit 78fcd4a

Browse files
committed
Clean up
* Differentiate between Generics and container type functions * Tie get_args and get_origin functions to Annotated import. * Rename function and add test to clarify forward ref use case
1 parent a494984 commit 78fcd4a

File tree

3 files changed

+51
-41
lines changed

3 files changed

+51
-41
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ class User:
6666
)
6767

6868
import marshmallow
69-
import typing_extensions
7069
import typing_inspect
7170

7271
from marshmallow_dataclass.generic_resolver import (
@@ -78,9 +77,9 @@ class User:
7877
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute
7978

8079
if sys.version_info >= (3, 9):
81-
from typing import Annotated
80+
from typing import Annotated, get_args, get_origin
8281
else:
83-
from typing_extensions import Annotated
82+
from typing_extensions import Annotated, get_args, get_origin
8483

8584
if sys.version_info >= (3, 11):
8685
from typing import dataclass_transform
@@ -540,7 +539,7 @@ def _internal_class_schema(
540539
) -> Type[marshmallow.Schema]:
541540
schema_ctx = _schema_ctx_stack.top
542541

543-
if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10):
542+
if get_origin(clazz) is Annotated and sys.version_info < (3, 10):
544543
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
545544
class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined]
546545
else:
@@ -597,7 +596,7 @@ def _internal_class_schema(
597596
(
598597
type_hints[field.name]
599598
if not is_generic_type(clazz)
600-
else _get_type_hint_of_generic_object(field.type, schema_ctx)
599+
else _resolve_forward_type_refs(field.type, schema_ctx)
601600
),
602601
_get_field_default(field),
603602
field.metadata,
@@ -659,8 +658,8 @@ def _field_by_supertype(
659658
)
660659

661660

662-
def _generic_type_add_any(typ: type) -> type:
663-
"""if typ is generic type without arguments, replace them by Any."""
661+
def _container_type_add_any(typ: type) -> type:
662+
"""if typ is container type without arguments, replace them by Any."""
664663
if typ is list or typ is List:
665664
typ = List[Any]
666665
elif typ is dict or typ is Dict:
@@ -676,18 +675,20 @@ def _generic_type_add_any(typ: type) -> type:
676675
return typ
677676

678677

679-
def _field_for_generic_type(
678+
def _field_for_container_type(
680679
typ: type,
681680
base_schema: Optional[Type[marshmallow.Schema]],
682681
**metadata: Any,
683682
) -> Optional[marshmallow.fields.Field]:
684683
"""
685-
If the type is a generic interface, resolve the arguments and construct the appropriate Field.
684+
If the type is a container interface, resolve the arguments and construct the appropriate Field.
685+
686+
We use the term 'container' to differentiate from the Generic support
686687
"""
687-
origin = typing_extensions.get_origin(typ)
688-
arguments = typing_extensions.get_args(typ)
688+
origin = get_origin(typ)
689+
arguments = get_args(typ)
689690
if origin:
690-
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
691+
# Override base_schema.TYPE_MAPPING to change the class used for container types below
691692
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
692693

693694
if origin in (list, List):
@@ -749,18 +750,15 @@ def _field_for_annotated_type(
749750
"""
750751
If the type is an Annotated interface, resolve the arguments and construct the appropriate Field.
751752
"""
752-
origin = typing_extensions.get_origin(typ)
753-
arguments = typing_extensions.get_args(typ)
753+
origin = get_origin(typ)
754+
arguments = get_args(typ)
754755
if origin and origin is Annotated:
755756
marshmallow_annotations = [
756757
arg
757758
for arg in arguments[1:]
758759
if _is_marshmallow_field(arg)
759760
# Support `CustomGenericField[mf.String]`
760-
or (
761-
is_generic_type(arg)
762-
and _is_marshmallow_field(typing_extensions.get_origin(arg))
763-
)
761+
or (is_generic_type(arg) and _is_marshmallow_field(get_origin(arg)))
764762
]
765763
if marshmallow_annotations:
766764
if len(marshmallow_annotations) > 1:
@@ -782,7 +780,7 @@ def _field_for_union_type(
782780
base_schema: Optional[Type[marshmallow.Schema]],
783781
**metadata: Any,
784782
) -> Optional[marshmallow.fields.Field]:
785-
arguments = typing_extensions.get_args(typ)
783+
arguments = get_args(typ)
786784
if typing_inspect.is_union_type(typ):
787785
if typing_inspect.is_optional_type(typ):
788786
metadata["allow_none"] = metadata.get("allow_none", True)
@@ -886,8 +884,8 @@ def _field_for_schema(
886884
if predefined_field:
887885
return predefined_field
888886

889-
# Generic types specified without type arguments
890-
typ = _generic_type_add_any(typ)
887+
# Container types (generics like List) specified without type arguments
888+
typ = _container_type_add_any(typ)
891889

892890
# Base types
893891
field = _field_by_type(typ, base_schema)
@@ -900,7 +898,7 @@ def _field_for_schema(
900898

901899
# i.e.: Literal['abc']
902900
if typing_inspect.is_literal_type(typ):
903-
arguments = typing_extensions.get_args(typ)
901+
arguments = get_args(typ)
904902
return marshmallow.fields.Raw(
905903
validate=(
906904
marshmallow.validate.Equal(arguments[0])
@@ -912,7 +910,7 @@ def _field_for_schema(
912910

913911
# i.e.: Final[str] = 'abc'
914912
if typing_inspect.is_final_type(typ):
915-
arguments = typing_extensions.get_args(typ)
913+
arguments = get_args(typ)
916914
if arguments:
917915
subtyp = arguments[0]
918916
elif default is not marshmallow.missing:
@@ -953,10 +951,10 @@ def _field_for_schema(
953951
if union_field:
954952
return union_field
955953

956-
# Generic types
957-
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
958-
if generic_field:
959-
return generic_field
954+
# Container types
955+
container_field = _field_for_container_type(typ, base_schema, **metadata)
956+
if container_field:
957+
return container_field
960958

961959
# typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a
962960
# __supertype__ attribute
@@ -1034,9 +1032,7 @@ def is_generic_alias_of_dataclass(clazz: type) -> bool:
10341032
Check if given class is a generic alias of a dataclass, if the dataclass is
10351033
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
10361034
"""
1037-
return is_generic_alias(clazz) and dataclasses.is_dataclass(
1038-
typing_extensions.get_origin(clazz)
1039-
)
1035+
return is_generic_alias(clazz) and dataclasses.is_dataclass(get_origin(clazz))
10401036

10411037

10421038
def _get_type_hints(
@@ -1058,11 +1054,13 @@ def _get_type_hints(
10581054
return type_hints
10591055

10601056

1061-
def _get_type_hint_of_generic_object(
1057+
def _resolve_forward_type_refs(
10621058
obj,
10631059
schema_ctx: _SchemaContext,
10641060
) -> type:
1065-
"""typing.get_type_hints doesn't work with generic aliases, i.e.: A[int]. But this 'hack' works."""
1061+
"""
1062+
Resolve forward references, mainly applies to Generics i.e.: `A["int"]` -> `A[int]`
1063+
"""
10661064

10671065
class X:
10681066
x: obj # type: ignore[name-defined]

marshmallow_dataclass/generic_resolver.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111
TypeVar,
1212
)
1313

14-
import typing_extensions
1514
import typing_inspect
1615

1716
if sys.version_info >= (3, 9):
18-
from typing import Annotated
17+
from typing import Annotated, get_args, get_origin
1918
else:
20-
from typing_extensions import Annotated
19+
from typing_extensions import Annotated, get_args, get_origin
2120

2221
_U = TypeVar("_U")
2322

@@ -81,15 +80,15 @@ class A(Generic[T]):
8180
``A[int]`` is a _generic alias_ (while ``A`` is a *generic type*, but not a *generic alias*).
8281
"""
8382
is_generic = is_generic_type(clazz)
84-
type_arguments = typing_extensions.get_args(clazz)
83+
type_arguments = get_args(clazz)
8584
return is_generic and len(type_arguments) > 0
8685

8786

8887
def is_generic_type(clazz: type) -> bool:
8988
"""
9089
typing_inspect.is_generic_type explicitly ignores Union and Tuple
9190
"""
92-
origin = typing_extensions.get_origin(clazz)
91+
origin = get_origin(clazz)
9392
return origin is not Annotated and (
9493
(isinstance(clazz, type) and issubclass(clazz, Generic)) # type: ignore[arg-type]
9594
or isinstance(clazz, typing_inspect.typingGenericAlias)
@@ -108,7 +107,7 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
108107
# Loop in reversed order and iteratively resolve types
109108
for subclass in reversed(clazz.mro()):
110109
if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type]
111-
args = typing_extensions.get_args(subclass.__orig_bases__[0])
110+
args = get_args(subclass.__orig_bases__[0])
112111

113112
if parent_class and args_by_class.get(parent_class):
114113
subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = []
@@ -129,8 +128,8 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
129128

130129
# clazz itself is a generic alias i.e.: A[int]. So it hold the last types.
131130
if is_generic_alias(clazz):
132-
origin = typing_extensions.get_origin(clazz)
133-
args = typing_extensions.get_args(clazz)
131+
origin = get_origin(clazz)
132+
args = get_args(clazz)
134133
for (_arg, future), potential_type in zip(args_by_class[origin], args): # type: ignore[index]
135134
if not isinstance(potential_type, TypeVar):
136135
future.set_result(potential_type)
@@ -154,7 +153,7 @@ def _replace_typevars(
154153
resolved_generics[arg].result() if arg in resolved_generics else arg
155154
)
156155
)
157-
for arg in typing_extensions.get_args(clazz)
156+
for arg in get_args(clazz)
158157
)
159158
)
160159

tests/test_generics.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,19 @@ class AnnotatedValue:
315315
with self.assertRaises(marshmallow.exceptions.ValidationError):
316316
schema.load({"emails": "notavalidemail"})
317317

318+
def test_generic_dataclass_with_forwardref(self):
319+
T = typing.TypeVar("T")
320+
321+
@dataclasses.dataclass
322+
class SimpleGeneric(typing.Generic[T]):
323+
data: T
324+
325+
schema_s = class_schema(SimpleGeneric["str"])()
326+
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
327+
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
328+
with self.assertRaises(ValidationError):
329+
schema_s.load({"data": 2})
330+
318331

319332
if __name__ == "__main__":
320333
unittest.main()

0 commit comments

Comments
 (0)