From bc46a233ea5c9040b27addb8fe8d98b6728bb647 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Tue, 11 Jan 2022 18:49:38 +0000 Subject: [PATCH 01/25] support generic dataclasses --- marshmallow_dataclass/__init__.py | 50 ++++++++++++++++++++++++++++--- tests/test_class_schema.py | 27 +++++++++++++++++ 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index f82b13a..fc32c7c 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -453,7 +453,9 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz): + if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass( + clazz + ): clazz = dataclasses.dataclass(clazz) if localns is None: if clazz_frame is None: @@ -523,8 +525,7 @@ def _internal_class_schema( schema_ctx.seen_classes[clazz] = class_name try: - # noinspection PyDataclass - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) + class_name, fields = _dataclass_name_and_fields(clazz) except TypeError: # Not a dataclass try: warnings.warn( @@ -582,7 +583,7 @@ def _internal_class_schema( if field.init or include_non_init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -996,6 +997,47 @@ def _get_field_default(field: dataclasses.Field): return field.default +def _is_generic_alias_of_dataclass(clazz: type) -> bool: + """ + Check if given class is a generic alias of a dataclass, if the dataclass is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass( + typing_inspect.get_origin(clazz) + ) + + +# noinspection PyDataclass +def _dataclass_name_and_fields( + clazz: type, +) -> Tuple[str, Tuple[dataclasses.Field, ...]]: + if not _is_generic_alias_of_dataclass(clazz): + return clazz.__name__, dataclasses.fields(clazz) + + base_dataclass = typing_inspect.get_origin(clazz) + base_parameters = typing_inspect.get_parameters(base_dataclass) + type_arguments = typing_inspect.get_args(clazz) + params_to_args = dict(zip(base_parameters, type_arguments)) + non_generic_fields = [ # swap generic typed fields with types in given type arguments + ( + f.name, + params_to_args.get(f.type, f.type), + dataclasses.field( + default=f.default, + # ignoring mypy: https://github.com/python/mypy/issues/6910 + default_factory=f.default_factory, # type: ignore + init=f.init, + metadata=f.metadata, + ), + ) + for f in dataclasses.fields(base_dataclass) + ] + non_generic_dataclass = dataclasses.make_dataclass( + cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields + ) + return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass) + + def NewType( name: str, typ: Type[_U], diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 28185a4..19d4e99 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -457,6 +457,33 @@ class Meta: self.assertNotIn("no_init", class_schema(NoInit)().fields) self.assertIn("no_init", class_schema(Init)().fields) + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class Nested: + data: SimpleGeneric[int] + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_n = class_schema(Nested)() + self.assertEqual( + Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}}) + ) + self.assertEqual( + schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}} + ) + with self.assertRaises(ValidationError): + schema_n.load({"data": {"data": "str"}}) + if __name__ == "__main__": unittest.main() From db32163e5ce6f3494b0dbe3f175b253a1580b46a Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Mon, 28 Nov 2022 17:05:40 +0800 Subject: [PATCH 02/25] support nested generic dataclasses --- marshmallow_dataclass/__init__.py | 109 ++++++++++++++++-------------- tests/test_class_schema.py | 2 +- 2 files changed, 58 insertions(+), 53 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index fc32c7c..d1751c7 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -44,15 +44,9 @@ class User: import warnings from enum import Enum from functools import lru_cache, partial +from typing import Any, Callable, Dict, FrozenSet, Generic, List, Mapping +from typing import NewType as typing_NewType from typing import ( - Any, - Callable, - Dict, - FrozenSet, - Generic, - List, - Mapping, - NewType as typing_NewType, Optional, Sequence, Set, @@ -150,8 +144,7 @@ def dataclass( frozen: bool = False, base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, -) -> Type[_U]: - ... +) -> Type[_U]: ... @overload @@ -164,8 +157,7 @@ def dataclass( frozen: bool = False, base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, -) -> Callable[[Type[_U]], Type[_U]]: - ... +) -> Callable[[Type[_U]], Type[_U]]: ... # _cls should never be specified by keyword, so start it with an @@ -224,15 +216,13 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: @overload -def add_schema(_cls: Type[_U]) -> Type[_U]: - ... +def add_schema(_cls: Type[_U]) -> Type[_U]: ... @overload def add_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, -) -> Callable[[Type[_U]], Type[_U]]: - ... +) -> Callable[[Type[_U]], Type[_U]]: ... @overload @@ -241,8 +231,7 @@ def add_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, stacklevel: int = 1, -) -> Type[_U]: - ... +) -> Type[_U]: ... def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): @@ -293,8 +282,7 @@ def class_schema( *, globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, -) -> Type[marshmallow.Schema]: - ... +) -> Type[marshmallow.Schema]: ... @overload @@ -304,8 +292,7 @@ def class_schema( clazz_frame: Optional[types.FrameType] = None, *, globalns: Optional[Dict[str, Any]] = None, -) -> Type[marshmallow.Schema]: - ... +) -> Type[marshmallow.Schema]: ... def class_schema( @@ -463,7 +450,7 @@ def class_schema( if clazz_frame is not None: localns = clazz_frame.f_locals with _SchemaContext(globalns, localns): - return _internal_class_schema(clazz, base_schema) + return _internal_class_schema(clazz, base_schema, None) class _SchemaContext: @@ -509,10 +496,17 @@ def top(self) -> _U: _schema_ctx_stack = _LocalStack[_SchemaContext]() +def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: + if _is_generic_alias_of_dataclass(clazz): + clazz = typing_inspect.get_origin(clazz) + return dataclasses.fields(clazz) + + @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> Type[marshmallow.Schema]: schema_ctx = _schema_ctx_stack.top @@ -525,7 +519,7 @@ def _internal_class_schema( schema_ctx.seen_classes[clazz] = class_name try: - class_name, fields = _dataclass_name_and_fields(clazz) + fields = _dataclass_fields(clazz) except TypeError: # Not a dataclass try: warnings.warn( @@ -540,7 +534,9 @@ def _internal_class_schema( "****** WARNING ******" ) created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema(created_dataclass, base_schema) + return _internal_class_schema( + created_dataclass, base_schema, generic_params_to_args + ) except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -556,6 +552,10 @@ def _internal_class_schema( # Determine whether we should include non-init fields include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False) + if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None: + generic_params_to_args = _generic_params_to_args(clazz) + + type_hints = _dataclass_type_hints(clazz, schema_ctx, generic_params_to_args) # Update the schema members to contain marshmallow fields instead of dataclass fields if sys.version_info >= (3, 9): @@ -577,13 +577,14 @@ def _internal_class_schema( _get_field_default(field), field.metadata, base_schema, + generic_params_to_args, ), ) for field in fields if field.init or include_non_init ) - schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -706,7 +707,7 @@ def _field_for_generic_type( ), ) return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): + if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( keys=_field_for_schema(arguments[0], base_schema=base_schema), @@ -794,6 +795,7 @@ def field_for_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, # FIXME: delete typ_frame from API? typ_frame: Optional[types.FrameType] = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Get a marshmallow Field corresponding to the given python type. @@ -953,7 +955,7 @@ def _field_for_schema( nested_schema or forward_reference or _schema_ctx_stack.top.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME + or _internal_class_schema(typ, base_schema, generic_params_to_args) # type: ignore [arg-type] ) return marshmallow.fields.Nested(nested, **metadata) @@ -1007,35 +1009,38 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool: ) -# noinspection PyDataclass -def _dataclass_name_and_fields( - clazz: type, -) -> Tuple[str, Tuple[dataclasses.Field, ...]]: - if not _is_generic_alias_of_dataclass(clazz): - return clazz.__name__, dataclasses.fields(clazz) - +def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]: base_dataclass = typing_inspect.get_origin(clazz) base_parameters = typing_inspect.get_parameters(base_dataclass) type_arguments = typing_inspect.get_args(clazz) - params_to_args = dict(zip(base_parameters, type_arguments)) - non_generic_fields = [ # swap generic typed fields with types in given type arguments - ( - f.name, - params_to_args.get(f.type, f.type), - dataclasses.field( - default=f.default, - # ignoring mypy: https://github.com/python/mypy/issues/6910 - default_factory=f.default_factory, # type: ignore - init=f.init, - metadata=f.metadata, - ), + return tuple(zip(base_parameters, type_arguments)) + + +def _dataclass_type_hints( + clazz: type, + schema_ctx: _SchemaContext = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, +) -> Mapping[str, type]: + if not _is_generic_alias_of_dataclass(clazz): + return get_type_hints( + clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns ) - for f in dataclasses.fields(base_dataclass) - ] - non_generic_dataclass = dataclasses.make_dataclass( - cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields + # dataclass is generic + generic_type_hints = get_type_hints( + typing_inspect.get_origin(clazz), + globalns=schema_ctx.globalns, + localns=schema_ctx.localns, ) - return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass) + generic_params_map = dict(generic_params_to_args if generic_params_to_args else {}) + + def _get_hint(_t: type) -> type: + if isinstance(_t, TypeVar): + return generic_params_map[_t] + return _t + + return { + field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items() + } def NewType( diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 19d4e99..c5b9537 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -14,7 +14,7 @@ from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType +from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass class TestClassSchema(unittest.TestCase): From 8b82276e7f81539318831457b318e58a4034da49 Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Fri, 9 Dec 2022 22:19:06 +1100 Subject: [PATCH 03/25] add test for repeated fields, fix __name__ attr for py<3.10 --- marshmallow_dataclass/__init__.py | 60 +++++++++++++++++++++++----- tests/test_class_schema.py | 65 +++++++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 18 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index d1751c7..4aa811e 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -584,7 +584,7 @@ def _internal_class_schema( if field.init or include_non_init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -602,6 +602,7 @@ def _field_by_supertype( newtype_supertype: Type, metadata: dict, base_schema: Optional[Type[marshmallow.Schema]], + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. (Usually spawned from NewType) @@ -632,6 +633,7 @@ def _field_by_supertype( metadata=metadata, default=default, base_schema=base_schema, + generic_params_to_args=generic_params_to_args, ) @@ -655,6 +657,7 @@ def _generic_type_add_any(typ: type) -> type: def _field_for_generic_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -667,7 +670,11 @@ def _field_for_generic_type( type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = field_for_schema( + arguments[0], + base_schema=base_schema, + generic_params_to_args=generic_params_to_args, + ) list_type = cast( Type[marshmallow.fields.List], type_mapping.get(List, marshmallow.fields.List), @@ -680,25 +687,42 @@ def _field_for_generic_type( ): from . import collection_field - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = field_for_schema( + arguments[0], + base_schema=base_schema, + generic_params_to_args=generic_params_to_args, + ) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = field_for_schema( + arguments[0], + base_schema=base_schema, + generic_params_to_args=generic_params_to_args, + ) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata ) if origin in (frozenset, FrozenSet): from . import collection_field - child_type = _field_for_schema(arguments[0], base_schema=base_schema) + child_type = field_for_schema( + arguments[0], + base_schema=base_schema, + generic_params_to_args=generic_params_to_args, + ) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): children = tuple( - _field_for_schema(arg, base_schema=base_schema) for arg in arguments + field_for_schema( + arg, + base_schema=base_schema, + generic_params_to_args=generic_params_to_args, + ) + for arg in arguments ) tuple_type = cast( Type[marshmallow.fields.Tuple], @@ -710,8 +734,16 @@ def _field_for_generic_type( if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( - keys=_field_for_schema(arguments[0], base_schema=base_schema), - values=_field_for_schema(arguments[1], base_schema=base_schema), + keys=field_for_schema( + arguments[0], + base_schema=base_schema, + generic_params_to_args=generic_params_to_args, + ), + values=field_for_schema( + arguments[1], + base_schema=base_schema, + generic_params_to_args=generic_params_to_args, + ), **metadata, ) @@ -768,6 +800,7 @@ def _field_for_union_type( subtypes[0], metadata=metadata, base_schema=base_schema, + generic_params_to_args=generic_params_to_args, ) from . import union_field @@ -779,6 +812,7 @@ def _field_for_union_type( subtyp, metadata={"required": True}, base_schema=base_schema, + generic_params_to_args=generic_params_to_args, ), ) for subtyp in subtypes @@ -818,7 +852,9 @@ def field_for_schema( """ with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None): - return _field_for_schema(typ, default, metadata, base_schema) + return _field_for_schema( + typ, default, metadata, base_schema, generic_params_to_args + ) def _field_for_schema( @@ -826,6 +862,7 @@ def _field_for_schema( default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, + generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Get a marshmallow Field corresponding to the given python type. @@ -913,7 +950,9 @@ def _field_for_schema( ) else: subtyp = Any - return _field_for_schema(subtyp, default, metadata, base_schema) + return field_for_schema( + subtyp, default, metadata, base_schema, generic_params_to_args + ) annotated_field = _field_for_annotated_type(typ, **metadata) if annotated_field: @@ -938,6 +977,7 @@ def _field_for_schema( newtype_supertype=newtype_supertype, metadata=metadata, base_schema=base_schema, + generic_params_to_args=generic_params_to_args, ) # enumerations diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index c5b9537..60319d3 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -1,7 +1,7 @@ import inspect import typing import unittest -from typing import Any, cast, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from uuid import UUID try: @@ -10,11 +10,14 @@ from typing_extensions import Final, Literal # type: ignore[assignment] import dataclasses + from marshmallow import Schema, ValidationError -from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer +from marshmallow.fields import UUID as UUIDField +from marshmallow.fields import Field, Integer +from marshmallow.fields import List as ListField from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass +from marshmallow_dataclass import NewType, _is_generic_alias_of_dataclass, class_schema class TestClassSchema(unittest.TestCase): @@ -465,24 +468,70 @@ class SimpleGeneric(typing.Generic[T]): data: T @dataclasses.dataclass - class Nested: + class NestedFixed: data: SimpleGeneric[int] + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) + schema_s = class_schema(SimpleGeneric[str])() self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) with self.assertRaises(ValidationError): schema_s.load({"data": 2}) - schema_n = class_schema(Nested)() + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() self.assertEqual( - Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}}) + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), ) self.assertEqual( - schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}} + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, ) with self.assertRaises(ValidationError): - schema_n.load({"data": {"data": "str"}}) + schema_nested_generic.load({"data": {"data": "str"}}) + + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + y: BB[AA] + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) if __name__ == "__main__": From f2734ccbc8319727039102a933b6165c1e9d028d Mon Sep 17 00:00:00 2001 From: Onur Satici Date: Thu, 15 Dec 2022 15:28:57 +1100 Subject: [PATCH 04/25] support py3.6 --- marshmallow_dataclass/__init__.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 4aa811e..8a7c96e 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -1044,8 +1044,13 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool: Check if given class is a generic alias of a dataclass, if the dataclass is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed """ - return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass( - typing_inspect.get_origin(clazz) + is_generic = typing_inspect.is_generic_type(clazz) + type_arguments = typing_inspect.get_args(clazz) + origin_class = typing_inspect.get_origin(clazz) + return ( + is_generic + and len(type_arguments) > 0 + and dataclasses.is_dataclass(origin_class) ) From 9a5dc091e3622fdfd74c06fa1aad6e0793a8dc08 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:30:47 +0100 Subject: [PATCH 05/25] Split generic tests into it's own file. --- tests/test_class_schema.py | 75 +-------------------------------- tests/test_generics.py | 86 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 74 deletions(-) create mode 100644 tests/test_generics.py diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 60319d3..8a1beb7 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -17,7 +17,7 @@ from marshmallow.fields import List as ListField from marshmallow.validate import Validator -from marshmallow_dataclass import NewType, _is_generic_alias_of_dataclass, class_schema +from marshmallow_dataclass import NewType, class_schema class TestClassSchema(unittest.TestCase): @@ -460,79 +460,6 @@ class Meta: self.assertNotIn("no_init", class_schema(NoInit)().fields) self.assertIn("no_init", class_schema(Init)().fields) - def test_generic_dataclass(self): - T = typing.TypeVar("T") - - @dataclasses.dataclass - class SimpleGeneric(typing.Generic[T]): - data: T - - @dataclasses.dataclass - class NestedFixed: - data: SimpleGeneric[int] - - @dataclasses.dataclass - class NestedGeneric(typing.Generic[T]): - data: SimpleGeneric[T] - - self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) - self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) - - schema_s = class_schema(SimpleGeneric[str])() - self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) - self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) - with self.assertRaises(ValidationError): - schema_s.load({"data": 2}) - - schema_nested = class_schema(NestedFixed)() - self.assertEqual( - NestedFixed(data=SimpleGeneric(1)), - schema_nested.load({"data": {"data": 1}}), - ) - self.assertEqual( - schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), - {"data": {"data": 1}}, - ) - with self.assertRaises(ValidationError): - schema_nested.load({"data": {"data": "str"}}) - - schema_nested_generic = class_schema(NestedGeneric[int])() - self.assertEqual( - NestedGeneric(data=SimpleGeneric(1)), - schema_nested_generic.load({"data": {"data": 1}}), - ) - self.assertEqual( - schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), - {"data": {"data": 1}}, - ) - with self.assertRaises(ValidationError): - schema_nested_generic.load({"data": {"data": "str"}}) - - def test_generic_dataclass_repeated_fields(self): - T = typing.TypeVar("T") - - @dataclasses.dataclass - class AA: - a: int - - @dataclasses.dataclass - class BB(typing.Generic[T]): - b: T - - @dataclasses.dataclass - class Nested: - x: BB[float] - z: BB[float] - # if y is the first field in this class, deserialisation will fail. - # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 - y: BB[AA] - - schema_nested = class_schema(Nested)() - self.assertEqual( - Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), - schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), - ) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000..0904a14 --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,86 @@ +import dataclasses +import typing +import unittest + +from marshmallow import ValidationError + +from marshmallow_dataclass import _is_generic_alias_of_dataclass, class_schema + + +class TestGenerics(unittest.TestCase): + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + y: BB[AA] + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) + + +if __name__ == "__main__": + unittest.main() From 844333630d741f974ae9e1cff035b9216864fd41 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Tue, 12 Mar 2024 00:05:46 +0100 Subject: [PATCH 06/25] Add support for deep generics with swapped TypeVars. * Raise descriptive error for unbound fields. --- marshmallow_dataclass/__init__.py | 309 ++++++++++++++++++++++-------- tests/test_generics.py | 131 ++++++++++++- 2 files changed, 359 insertions(+), 81 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 8a7c96e..1be01b2 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -36,6 +36,7 @@ class User: """ import collections.abc +import copy import dataclasses import inspect import sys @@ -133,6 +134,63 @@ def _maybe_get_callers_frame( del frame +class UnboundTypeVarError(TypeError): + """TypeVar instance can not be resolved to a type spec. + + This exception is raised when an unbound TypeVar is encountered. + + """ + + +class InvalidStateError(Exception): + """Raised when an operation is performed on a future that is not + allowed in the current state. + """ + + +class _Future(Generic[_U]): + """The _Future class allows deferred access to a result that is not + yet available. + """ + + _done: bool + _result: _U + + def __init__(self) -> None: + self._done = False + + def done(self) -> bool: + """Return ``True`` if the value is available""" + return self._done + + def result(self) -> _U: + """Return the deferred value. + + Raises ``InvalidStateError`` if the value has not been set. + """ + if self.done(): + return self._result + raise InvalidStateError("result has not been set") + + def set_result(self, result: _U) -> None: + if self.done(): + raise InvalidStateError("result has already been set") + self._result = result + self._done = True + + +def _check_decorated_type(cls: object) -> None: + if not isinstance(cls, type): + raise TypeError(f"expected a class not {cls!r}") + if _is_generic_alias(cls): + # A .Schema attribute doesn't make sense on a generic alias — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic aliasses " + "(hint: use class_schema directly instead)" + ) + + @overload def dataclass( _cls: Type[_U], @@ -206,12 +264,18 @@ def dataclass( ) def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + if cls is not None: + _check_decorated_type(cls) + return add_schema( dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 ) if _cls is None: return decorator + + if _cls is not None: + _check_decorated_type(_cls) return decorator(_cls, stacklevel=stacklevel + 1) @@ -257,6 +321,8 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): """ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: + _check_decorated_type(clazz) + if cls_frame is not None: frame = cls_frame else: @@ -450,7 +516,7 @@ def class_schema( if clazz_frame is not None: localns = clazz_frame.f_locals with _SchemaContext(globalns, localns): - return _internal_class_schema(clazz, base_schema, None) + return _internal_class_schema(clazz, base_schema) class _SchemaContext: @@ -496,17 +562,10 @@ def top(self) -> _U: _schema_ctx_stack = _LocalStack[_SchemaContext]() -def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: - if _is_generic_alias_of_dataclass(clazz): - clazz = typing_inspect.get_origin(clazz) - return dataclasses.fields(clazz) - - @lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE) def _internal_class_schema( clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None, - generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> Type[marshmallow.Schema]: schema_ctx = _schema_ctx_stack.top @@ -520,6 +579,8 @@ def _internal_class_schema( try: fields = _dataclass_fields(clazz) + except UnboundTypeVarError: + raise except TypeError: # Not a dataclass try: warnings.warn( @@ -534,9 +595,9 @@ def _internal_class_schema( "****** WARNING ******" ) created_dataclass: type = dataclasses.dataclass(clazz) - return _internal_class_schema( - created_dataclass, base_schema, generic_params_to_args - ) + return _internal_class_schema(created_dataclass, base_schema) + except UnboundTypeVarError: + raise except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -552,39 +613,21 @@ def _internal_class_schema( # Determine whether we should include non-init fields include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False) - if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None: - generic_params_to_args = _generic_params_to_args(clazz) - - type_hints = _dataclass_type_hints(clazz, schema_ctx, generic_params_to_args) - # Update the schema members to contain marshmallow fields instead of dataclass fields - - if sys.version_info >= (3, 9): - type_hints = get_type_hints( - clazz, - globalns=schema_ctx.globalns, - localns=schema_ctx.localns, - include_extras=True, - ) - else: - type_hints = get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) attributes.update( ( field.name, - _field_for_schema( - type_hints[field.name], + field_for_schema( + _get_field_type_hints(field, schema_ctx), _get_field_default(field), field.metadata, base_schema, - generic_params_to_args, ), ) for field in fields if field.init or include_non_init ) - schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], schema_class) @@ -602,7 +645,6 @@ def _field_by_supertype( newtype_supertype: Type, metadata: dict, base_schema: Optional[Type[marshmallow.Schema]], - generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Return a new field for fields based on a super field. (Usually spawned from NewType) @@ -633,7 +675,6 @@ def _field_by_supertype( metadata=metadata, default=default, base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) @@ -657,7 +698,6 @@ def _generic_type_add_any(typ: type) -> type: def _field_for_generic_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], - generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ @@ -673,7 +713,6 @@ def _field_for_generic_type( child_type = field_for_schema( arguments[0], base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) list_type = cast( Type[marshmallow.fields.List], @@ -690,7 +729,6 @@ def _field_for_generic_type( child_type = field_for_schema( arguments[0], base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): @@ -699,7 +737,6 @@ def _field_for_generic_type( child_type = field_for_schema( arguments[0], base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata @@ -710,7 +747,6 @@ def _field_for_generic_type( child_type = field_for_schema( arguments[0], base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata @@ -720,7 +756,6 @@ def _field_for_generic_type( field_for_schema( arg, base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) for arg in arguments ) @@ -737,12 +772,10 @@ def _field_for_generic_type( keys=field_for_schema( arguments[0], base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ), values=field_for_schema( arguments[1], base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ), **metadata, ) @@ -800,7 +833,6 @@ def _field_for_union_type( subtypes[0], metadata=metadata, base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) from . import union_field @@ -812,7 +844,6 @@ def _field_for_union_type( subtyp, metadata={"required": True}, base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ), ) for subtyp in subtypes @@ -829,7 +860,6 @@ def field_for_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, # FIXME: delete typ_frame from API? typ_frame: Optional[types.FrameType] = None, - generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Get a marshmallow Field corresponding to the given python type. @@ -852,9 +882,7 @@ def field_for_schema( """ with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None): - return _field_for_schema( - typ, default, metadata, base_schema, generic_params_to_args - ) + return _field_for_schema(typ, default, metadata, base_schema) def _field_for_schema( @@ -862,7 +890,6 @@ def _field_for_schema( default: Any = marshmallow.missing, metadata: Optional[Mapping[str, Any]] = None, base_schema: Optional[Type[marshmallow.Schema]] = None, - generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, ) -> marshmallow.fields.Field: """ Get a marshmallow Field corresponding to the given python type. @@ -878,6 +905,9 @@ def _field_for_schema( """ + if isinstance(typ, TypeVar): + raise UnboundTypeVarError(f"can not resolve type variable {typ.__name__}") + metadata = {} if metadata is None else dict(metadata) if default is not marshmallow.missing: @@ -950,9 +980,7 @@ def _field_for_schema( ) else: subtyp = Any - return field_for_schema( - subtyp, default, metadata, base_schema, generic_params_to_args - ) + return field_for_schema(subtyp, default, metadata, base_schema) annotated_field = _field_for_annotated_type(typ, **metadata) if annotated_field: @@ -977,7 +1005,6 @@ def _field_for_schema( newtype_supertype=newtype_supertype, metadata=metadata, base_schema=base_schema, - generic_params_to_args=generic_params_to_args, ) # enumerations @@ -995,7 +1022,7 @@ def _field_for_schema( nested_schema or forward_reference or _schema_ctx_stack.top.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema, generic_params_to_args) # type: ignore [arg-type] + or _internal_class_schema(typ, base_schema) # type: ignore [arg-type] ) return marshmallow.fields.Nested(nested, **metadata) @@ -1044,7 +1071,7 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool: Check if given class is a generic alias of a dataclass, if the dataclass is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed """ - is_generic = typing_inspect.is_generic_type(clazz) + is_generic = is_generic_type(clazz) type_arguments = typing_inspect.get_args(clazz) origin_class = typing_inspect.get_origin(clazz) return ( @@ -1054,38 +1081,160 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool: ) -def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]: - base_dataclass = typing_inspect.get_origin(clazz) - base_parameters = typing_inspect.get_parameters(base_dataclass) +def _get_field_type_hints( + field: dataclasses.Field, + schema_ctx: Optional[_SchemaContext] = None, +) -> type: + """typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works.""" + + class X: + x: field.type # type: ignore[name-defined] + + if sys.version_info >= (3, 9): + type_hints = get_type_hints( + X, + globalns=schema_ctx.globalns, + localns=schema_ctx.localns, + include_extras=True, + )["x"] + else: + type_hints = get_type_hints( + X, globalns=schema_ctx.globalns, localns=schema_ctx.localns + )["x"] + + return type_hints + + +def _is_generic_alias(clazz: type) -> bool: + """ + Check if given class is a generic alias of a class is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + is_generic = is_generic_type(clazz) type_arguments = typing_inspect.get_args(clazz) - return tuple(zip(base_parameters, type_arguments)) + return is_generic and len(type_arguments) > 0 -def _dataclass_type_hints( - clazz: type, - schema_ctx: _SchemaContext = None, - generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None, -) -> Mapping[str, type]: - if not _is_generic_alias_of_dataclass(clazz): - return get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns +def is_generic_type(clazz: type) -> bool: + """ + typing_inspect.is_generic_type explicitly ignores Union, Tuple, Callable, ClassVar + """ + return ( + isinstance(clazz, type) + and issubclass(clazz, Generic) # type: ignore[arg-type] + or isinstance(clazz, typing_inspect.typingGenericAlias) + ) + + +def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: + """ + Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics. + + Returns a dict of each base class and the resolved generics. + """ + # Use Tuples so can zip (order matters) + args_by_class: Dict[type, Tuple[Tuple[TypeVar, _Future], ...]] = {} + parent_class: Optional[type] = None + # Loop in reversed order and iteratively resolve types + for subclass in reversed(clazz.mro()): + if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type] + args = typing_inspect.get_args(subclass.__orig_bases__[0]) + + if parent_class and args_by_class.get(parent_class): + subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = [] + for (_arg, future), potential_type in zip( + args_by_class[parent_class], args + ): + if isinstance(potential_type, TypeVar): + subclass_generic_params_to_args.append((potential_type, future)) + else: + future.set_result(potential_type) + + args_by_class[subclass] = tuple(subclass_generic_params_to_args) + + else: + args_by_class[subclass] = tuple((arg, _Future()) for arg in args) + + parent_class = subclass + + # clazz itself is a generic alias i.e.: A[int]. So it hold the last types. + if _is_generic_alias(clazz): + origin = typing_inspect.get_origin(clazz) + args = typing_inspect.get_args(clazz) + for (_arg, future), potential_type in zip(args_by_class[origin], args): + if not isinstance(potential_type, TypeVar): + future.set_result(potential_type) + + # Convert to nested dict for easier lookup + return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()} + + +def _replace_typevars( + clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None +) -> type: + if not resolved_generics or inspect.isclass(clazz) or not is_generic_type(clazz): + return clazz + + return clazz.copy_with( # type: ignore + tuple( + ( + _replace_typevars(arg, resolved_generics) + if is_generic_type(arg) + else ( + resolved_generics[arg].result() if arg in resolved_generics else arg + ) + ) + for arg in typing_inspect.get_args(clazz) ) - # dataclass is generic - generic_type_hints = get_type_hints( - typing_inspect.get_origin(clazz), - globalns=schema_ctx.globalns, - localns=schema_ctx.localns, ) - generic_params_map = dict(generic_params_to_args if generic_params_to_args else {}) - def _get_hint(_t: type) -> type: - if isinstance(_t, TypeVar): - return generic_params_map[_t] - return _t - return { - field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items() - } +def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: + if not is_generic_type(clazz): + return dataclasses.fields(clazz) + + else: + unbound_fields = set() + # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and + # looses the source class. Thus I don't know how to resolve this at later on. + # Instead we recreate the type but with all known TypeVars resolved to their actual types. + resolved_typevars = _resolve_typevars(clazz) + # Dict[field_name, Tuple[original_field, resolved_field]] + fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {} + + for subclass in reversed(clazz.mro()): + if not dataclasses.is_dataclass(subclass): + continue + + for field in dataclasses.fields(subclass): + try: + if field.name in fields and fields[field.name][0] == field: + continue # identical, so already resolved. + + # Either the first time we see this field, or it got overridden + # If it's a class we handle it later as a Nested. Nothing to resolve now. + new_field = field + if not inspect.isclass(field.type) and is_generic_type(field.type): + new_field = copy.copy(field) + new_field.type = _replace_typevars( + field.type, resolved_typevars[subclass] + ) + elif isinstance(field.type, TypeVar): + new_field = copy.copy(field) + new_field.type = resolved_typevars[subclass][ + field.type + ].result() + + fields[field.name] = (field, new_field) + except InvalidStateError: + unbound_fields.add(field.name) + + if unbound_fields: + raise UnboundTypeVarError( + f"{clazz.__name__} has unbound fields: {', '.join(unbound_fields)}" + ) + + return tuple(v[1] for v in fields.values()) def NewType( diff --git a/tests/test_generics.py b/tests/test_generics.py index 0904a14..3e65fc2 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -4,7 +4,12 @@ from marshmallow import ValidationError -from marshmallow_dataclass import _is_generic_alias_of_dataclass, class_schema +from marshmallow_dataclass import ( + UnboundTypeVarError, + _is_generic_alias_of_dataclass, + add_schema, + class_schema, +) class TestGenerics(unittest.TestCase): @@ -81,6 +86,130 @@ class Nested: schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), ) + def test_marshmallow_dataclass_decorator_raises_on_generic_alias(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. + This is a Python limitation, not a marshmallow_dataclass limitation. + """ + import marshmallow_dataclass + + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + marshmallow_dataclass.dataclass(GenClass[int]) + + def test_add_schema_raises_on_generic_alias(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. + This is a Python limitation, not a marshmallow_dataclass limitation. + """ + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + add_schema(GenClass[int]) + + def test_deep_generic(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + pairs: typing.List[typing.Tuple[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) + ) + + def test_deep_generic_with_overrides(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + V = typing.TypeVar("V") + W = typing.TypeVar("W") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U, V]): + pairs: typing.List[typing.Tuple[T, U]] + gen: V + override: int + + # Don't only override typevar, but switch order to further confuse things + @dataclasses.dataclass + class TestClass2(TestClass[str, W, U]): + override: str # type: ignore # Want to test that it works, even if incompatible types + + TestAlias = TestClass2[int, T] + + # inherit from alias + @dataclasses.dataclass + class TestClass3(TestAlias[typing.List[int]]): + pass + + test_schema = class_schema(TestClass3)() + + self.assertEqual( + test_schema.load( + {"pairs": [("first", "1")], "gen": ["1", 2], "override": "overridden"} + ), + TestClass3([("first", 1)], [1, 2], "overridden"), + ) + + def test_generic_bases(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[T]): + pass + + test_schema = class_schema(TestClass[int])() + + self.assertEqual(test_schema.load({"answer": "1"}), TestClass(1)) + + def test_bound_generic_base(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[int]): + pass + + with self.assertRaisesRegex( + UnboundTypeVarError, "Base1 has unbound fields: answer" + ): + class_schema(Base1) + + test_schema = class_schema(TestClass)() + self.assertEqual(test_schema.load({"answer": "1"}), TestClass(1)) + + def test_unbound_type_var(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base: + answer: T # type: ignore[valid-type] + + with self.assertRaises(UnboundTypeVarError): + class_schema(Base) + + with self.assertRaises(TypeError): + class_schema(Base) + if __name__ == "__main__": unittest.main() From 8a0f837e0486939ead38344a700f672e73c00cda Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Tue, 25 Jun 2024 00:51:43 +0200 Subject: [PATCH 07/25] Fix tests after rebase --- marshmallow_dataclass/__init__.py | 82 ++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 29 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 1be01b2..ffda08d 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -202,7 +202,8 @@ def dataclass( frozen: bool = False, base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, -) -> Type[_U]: ... +) -> Type[_U]: + ... @overload @@ -215,7 +216,8 @@ def dataclass( frozen: bool = False, base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, -) -> Callable[[Type[_U]], Type[_U]]: ... +) -> Callable[[Type[_U]], Type[_U]]: + ... # _cls should never be specified by keyword, so start it with an @@ -280,13 +282,15 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: @overload -def add_schema(_cls: Type[_U]) -> Type[_U]: ... +def add_schema(_cls: Type[_U]) -> Type[_U]: + ... @overload def add_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, -) -> Callable[[Type[_U]], Type[_U]]: ... +) -> Callable[[Type[_U]], Type[_U]]: + ... @overload @@ -295,7 +299,8 @@ def add_schema( base_schema: Optional[Type[marshmallow.Schema]] = None, cls_frame: Optional[types.FrameType] = None, stacklevel: int = 1, -) -> Type[_U]: ... +) -> Type[_U]: + ... def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): @@ -348,7 +353,8 @@ def class_schema( *, globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, -) -> Type[marshmallow.Schema]: ... +) -> Type[marshmallow.Schema]: + ... @overload @@ -358,7 +364,8 @@ def class_schema( clazz_frame: Optional[types.FrameType] = None, *, globalns: Optional[Dict[str, Any]] = None, -) -> Type[marshmallow.Schema]: ... +) -> Type[marshmallow.Schema]: + ... def class_schema( @@ -573,7 +580,8 @@ def _internal_class_schema( # https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977 class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined] else: - class_name = clazz.__name__ + # generic aliases do not have a __name__ prior python 3.10 + class_name = getattr(clazz, "__name__", repr(clazz)) schema_ctx.seen_classes[clazz] = class_name @@ -613,11 +621,20 @@ def _internal_class_schema( # Determine whether we should include non-init fields include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False) + # Update the schema members to contain marshmallow fields instead of dataclass fields + type_hints = {} + if not is_generic_type(clazz): + type_hints = _get_type_hints(clazz, schema_ctx) + attributes.update( ( field.name, - field_for_schema( - _get_field_type_hints(field, schema_ctx), + _field_for_schema( + ( + type_hints[field.name] + if not is_generic_type(clazz) + else _get_generic_type_hints(field.type, schema_ctx) + ), _get_field_default(field), field.metadata, base_schema, @@ -710,7 +727,7 @@ def _field_for_generic_type( type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): - child_type = field_for_schema( + child_type = _field_for_schema( arguments[0], base_schema=base_schema, ) @@ -726,7 +743,7 @@ def _field_for_generic_type( ): from . import collection_field - child_type = field_for_schema( + child_type = _field_for_schema( arguments[0], base_schema=base_schema, ) @@ -734,7 +751,7 @@ def _field_for_generic_type( if origin in (set, Set): from . import collection_field - child_type = field_for_schema( + child_type = _field_for_schema( arguments[0], base_schema=base_schema, ) @@ -744,7 +761,7 @@ def _field_for_generic_type( if origin in (frozenset, FrozenSet): from . import collection_field - child_type = field_for_schema( + child_type = _field_for_schema( arguments[0], base_schema=base_schema, ) @@ -753,7 +770,7 @@ def _field_for_generic_type( ) if origin in (tuple, Tuple): children = tuple( - field_for_schema( + _field_for_schema( arg, base_schema=base_schema, ) @@ -980,7 +997,7 @@ def _field_for_schema( ) else: subtyp = Any - return field_for_schema(subtyp, default, metadata, base_schema) + return _field_for_schema(subtyp, default, metadata, base_schema) annotated_field = _field_for_annotated_type(typ, **metadata) if annotated_field: @@ -1081,30 +1098,37 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool: ) -def _get_field_type_hints( - field: dataclasses.Field, - schema_ctx: Optional[_SchemaContext] = None, -) -> type: - """typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works.""" - - class X: - x: field.type # type: ignore[name-defined] - +def _get_type_hints( + obj, + schema_ctx: _SchemaContext, +): if sys.version_info >= (3, 9): type_hints = get_type_hints( - X, + obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns, include_extras=True, - )["x"] + ) else: type_hints = get_type_hints( - X, globalns=schema_ctx.globalns, localns=schema_ctx.localns - )["x"] + obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) return type_hints +def _get_generic_type_hints( + obj, + schema_ctx: _SchemaContext, +) -> type: + """typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works.""" + + class X: + x: obj # type: ignore[name-defined] + + return _get_type_hints(X, schema_ctx)["x"] + + def _is_generic_alias(clazz: type) -> bool: """ Check if given class is a generic alias of a class is From f484596416294df7aa5ecfd305c3367d25ab0717 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Tue, 25 Jun 2024 01:11:03 +0200 Subject: [PATCH 08/25] Remove unnecessary whitespace --- marshmallow_dataclass/__init__.py | 36 ++++++------------------------- 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index ffda08d..f4fbef5 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -727,10 +727,7 @@ def _field_for_generic_type( type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): - child_type = _field_for_schema( - arguments[0], - base_schema=base_schema, - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) list_type = cast( Type[marshmallow.fields.List], type_mapping.get(List, marshmallow.fields.List), @@ -743,38 +740,25 @@ def _field_for_generic_type( ): from . import collection_field - child_type = _field_for_schema( - arguments[0], - base_schema=base_schema, - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Sequence(cls_or_instance=child_type, **metadata) if origin in (set, Set): from . import collection_field - child_type = _field_for_schema( - arguments[0], - base_schema=base_schema, - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=False, **metadata ) if origin in (frozenset, FrozenSet): from . import collection_field - child_type = _field_for_schema( - arguments[0], - base_schema=base_schema, - ) + child_type = _field_for_schema(arguments[0], base_schema=base_schema) return collection_field.Set( cls_or_instance=child_type, frozen=True, **metadata ) if origin in (tuple, Tuple): children = tuple( - _field_for_schema( - arg, - base_schema=base_schema, - ) - for arg in arguments + _field_for_schema(arg, base_schema=base_schema) for arg in arguments ) tuple_type = cast( Type[marshmallow.fields.Tuple], @@ -786,14 +770,8 @@ def _field_for_generic_type( if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( - keys=field_for_schema( - arguments[0], - base_schema=base_schema, - ), - values=field_for_schema( - arguments[1], - base_schema=base_schema, - ), + keys=field_for_schema(arguments[0], base_schema=base_schema), + values=field_for_schema(arguments[1], base_schema=base_schema), **metadata, ) From 80dab91998672e897ce9057886661690bbe59594 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Tue, 25 Jun 2024 01:12:34 +0200 Subject: [PATCH 09/25] fix call correct _field_for_schema function --- marshmallow_dataclass/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index f4fbef5..938b3a8 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -770,8 +770,8 @@ def _field_for_generic_type( if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( - keys=field_for_schema(arguments[0], base_schema=base_schema), - values=field_for_schema(arguments[1], base_schema=base_schema), + keys=_field_for_schema(arguments[0], base_schema=base_schema), + values=_field_for_schema(arguments[1], base_schema=base_schema), **metadata, ) From 4531c35cf5f1b50e1dbc793395377a380aae4bf6 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Tue, 25 Jun 2024 13:26:45 +0200 Subject: [PATCH 10/25] Break generic functions out into it's own file and add support for annotated generics, partials, and callables --- marshmallow_dataclass/__init__.py | 222 ++++------------------ marshmallow_dataclass/generic_resolver.py | 193 +++++++++++++++++++ tests/test_annotated.py | 65 ++++++- tests/test_generics.py | 97 +++++++++- 4 files changed, 392 insertions(+), 185 deletions(-) create mode 100644 marshmallow_dataclass/generic_resolver.py diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 938b3a8..3b4fdab 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -36,7 +36,6 @@ class User: """ import collections.abc -import copy import dataclasses import inspect import sys @@ -64,6 +63,12 @@ class User: import typing_extensions import typing_inspect +from marshmallow_dataclass.generic_resolver import ( + UnboundTypeVarError, + get_generic_dataclass_fields, + is_generic_alias, + is_generic_type, +) from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute if sys.version_info >= (3, 9): @@ -134,55 +139,10 @@ def _maybe_get_callers_frame( del frame -class UnboundTypeVarError(TypeError): - """TypeVar instance can not be resolved to a type spec. - - This exception is raised when an unbound TypeVar is encountered. - - """ - - -class InvalidStateError(Exception): - """Raised when an operation is performed on a future that is not - allowed in the current state. - """ - - -class _Future(Generic[_U]): - """The _Future class allows deferred access to a result that is not - yet available. - """ - - _done: bool - _result: _U - - def __init__(self) -> None: - self._done = False - - def done(self) -> bool: - """Return ``True`` if the value is available""" - return self._done - - def result(self) -> _U: - """Return the deferred value. - - Raises ``InvalidStateError`` if the value has not been set. - """ - if self.done(): - return self._result - raise InvalidStateError("result has not been set") - - def set_result(self, result: _U) -> None: - if self.done(): - raise InvalidStateError("result has already been set") - self._result = result - self._done = True - - def _check_decorated_type(cls: object) -> None: if not isinstance(cls, type): raise TypeError(f"expected a class not {cls!r}") - if _is_generic_alias(cls): + if is_generic_alias(cls): # A .Schema attribute doesn't make sense on a generic alias — there's # no way for it to know the generic parameters at run time. raise TypeError( @@ -513,9 +473,7 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass( - clazz - ): + if not dataclasses.is_dataclass(clazz) and not is_generic_alias_of_dataclass(clazz): clazz = dataclasses.dataclass(clazz) if localns is None: if clazz_frame is None: @@ -791,8 +749,16 @@ def _field_for_annotated_type( marshmallow_annotations = [ arg for arg in arguments[1:] - if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field)) - or isinstance(arg, marshmallow.fields.Field) + if _is_marshmallow_field(arg) + # Support `CustomGenericField[mf.String]` + or ( + is_generic_type(arg) + and _is_marshmallow_field(typing_extensions.get_origin(arg)) + ) + # Support `partial(mf.List, mf.String)` + or (isinstance(arg, partial) and _is_marshmallow_field(arg.func)) + # Support `lambda *args, **kwargs: mf.List(mf.String, *args, **kwargs)` + or (_is_callable_marshmallow_field(arg)) ] if marshmallow_annotations: if len(marshmallow_annotations) > 1: @@ -932,7 +898,7 @@ def _field_for_schema( # i.e.: Literal['abc'] if typing_inspect.is_literal_type(typ): - arguments = typing_inspect.get_args(typ) + arguments = typing_extensions.get_args(typ) return marshmallow.fields.Raw( validate=( marshmallow.validate.Equal(arguments[0]) @@ -944,7 +910,7 @@ def _field_for_schema( # i.e.: Final[str] = 'abc' if typing_inspect.is_final_type(typ): - arguments = typing_inspect.get_args(typ) + arguments = typing_extensions.get_args(typ) if arguments: subtyp = arguments[0] elif default is not marshmallow.missing: @@ -1061,14 +1027,14 @@ def _get_field_default(field: dataclasses.Field): return field.default -def _is_generic_alias_of_dataclass(clazz: type) -> bool: +def is_generic_alias_of_dataclass(clazz: type) -> bool: """ Check if given class is a generic alias of a dataclass, if the dataclass is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed """ is_generic = is_generic_type(clazz) - type_arguments = typing_inspect.get_args(clazz) - origin_class = typing_inspect.get_origin(clazz) + type_arguments = typing_extensions.get_args(clazz) + origin_class = typing_extensions.get_origin(clazz) return ( is_generic and len(type_arguments) > 0 @@ -1107,136 +1073,30 @@ class X: return _get_type_hints(X, schema_ctx)["x"] -def _is_generic_alias(clazz: type) -> bool: - """ - Check if given class is a generic alias of a class is - defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed - """ - is_generic = is_generic_type(clazz) - type_arguments = typing_inspect.get_args(clazz) - return is_generic and len(type_arguments) > 0 - - -def is_generic_type(clazz: type) -> bool: - """ - typing_inspect.is_generic_type explicitly ignores Union, Tuple, Callable, ClassVar - """ - return ( - isinstance(clazz, type) - and issubclass(clazz, Generic) # type: ignore[arg-type] - or isinstance(clazz, typing_inspect.typingGenericAlias) - ) - - -def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: - """ - Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics. - - Returns a dict of each base class and the resolved generics. - """ - # Use Tuples so can zip (order matters) - args_by_class: Dict[type, Tuple[Tuple[TypeVar, _Future], ...]] = {} - parent_class: Optional[type] = None - # Loop in reversed order and iteratively resolve types - for subclass in reversed(clazz.mro()): - if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type] - args = typing_inspect.get_args(subclass.__orig_bases__[0]) - - if parent_class and args_by_class.get(parent_class): - subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = [] - for (_arg, future), potential_type in zip( - args_by_class[parent_class], args - ): - if isinstance(potential_type, TypeVar): - subclass_generic_params_to_args.append((potential_type, future)) - else: - future.set_result(potential_type) - - args_by_class[subclass] = tuple(subclass_generic_params_to_args) - - else: - args_by_class[subclass] = tuple((arg, _Future()) for arg in args) - - parent_class = subclass - - # clazz itself is a generic alias i.e.: A[int]. So it hold the last types. - if _is_generic_alias(clazz): - origin = typing_inspect.get_origin(clazz) - args = typing_inspect.get_args(clazz) - for (_arg, future), potential_type in zip(args_by_class[origin], args): - if not isinstance(potential_type, TypeVar): - future.set_result(potential_type) - - # Convert to nested dict for easier lookup - return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()} - - -def _replace_typevars( - clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None -) -> type: - if not resolved_generics or inspect.isclass(clazz) or not is_generic_type(clazz): - return clazz - - return clazz.copy_with( # type: ignore - tuple( - ( - _replace_typevars(arg, resolved_generics) - if is_generic_type(arg) - else ( - resolved_generics[arg].result() if arg in resolved_generics else arg - ) - ) - for arg in typing_inspect.get_args(clazz) - ) - ) - - def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: if not is_generic_type(clazz): return dataclasses.fields(clazz) else: - unbound_fields = set() - # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and - # looses the source class. Thus I don't know how to resolve this at later on. - # Instead we recreate the type but with all known TypeVars resolved to their actual types. - resolved_typevars = _resolve_typevars(clazz) - # Dict[field_name, Tuple[original_field, resolved_field]] - fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {} - - for subclass in reversed(clazz.mro()): - if not dataclasses.is_dataclass(subclass): - continue - - for field in dataclasses.fields(subclass): - try: - if field.name in fields and fields[field.name][0] == field: - continue # identical, so already resolved. - - # Either the first time we see this field, or it got overridden - # If it's a class we handle it later as a Nested. Nothing to resolve now. - new_field = field - if not inspect.isclass(field.type) and is_generic_type(field.type): - new_field = copy.copy(field) - new_field.type = _replace_typevars( - field.type, resolved_typevars[subclass] - ) - elif isinstance(field.type, TypeVar): - new_field = copy.copy(field) - new_field.type = resolved_typevars[subclass][ - field.type - ].result() - - fields[field.name] = (field, new_field) - except InvalidStateError: - unbound_fields.add(field.name) - - if unbound_fields: - raise UnboundTypeVarError( - f"{clazz.__name__} has unbound fields: {', '.join(unbound_fields)}" - ) + return get_generic_dataclass_fields(clazz) + + +def _is_marshmallow_field(obj) -> bool: + return ( + inspect.isclass(obj) and issubclass(obj, marshmallow.fields.Field) + ) or isinstance(obj, marshmallow.fields.Field) + + +def _is_callable_marshmallow_field(obj) -> bool: + """Checks if the object is a callable and if the callable returns a marshmallow field""" + if callable(obj): + try: + potential_field = obj() + return _is_marshmallow_field(potential_field) + except Exception: + return False - return tuple(v[1] for v in fields.values()) + return False def NewType( diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py new file mode 100644 index 0000000..56287ab --- /dev/null +++ b/marshmallow_dataclass/generic_resolver.py @@ -0,0 +1,193 @@ +import copy +import dataclasses +import inspect +import sys +from typing import ( + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, +) + +import typing_extensions +import typing_inspect + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + +_U = TypeVar("_U") + + +class UnboundTypeVarError(TypeError): + """TypeVar instance can not be resolved to a type spec. + + This exception is raised when an unbound TypeVar is encountered. + + """ + + +class InvalidStateError(Exception): + """Raised when an operation is performed on a future that is not + allowed in the current state. + """ + + +class _Future(Generic[_U]): + """The _Future class allows deferred access to a result that is not + yet available. + """ + + _done: bool + _result: _U + + def __init__(self) -> None: + self._done = False + + def done(self) -> bool: + """Return ``True`` if the value is available""" + return self._done + + def result(self) -> _U: + """Return the deferred value. + + Raises ``InvalidStateError`` if the value has not been set. + """ + if self.done(): + return self._result + raise InvalidStateError("result has not been set") + + def set_result(self, result: _U) -> None: + if self.done(): + raise InvalidStateError("result has already been set") + self._result = result + self._done = True + + +def is_generic_alias(clazz: type) -> bool: + """ + Check if given class is a generic alias of a class is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + is_generic = is_generic_type(clazz) + type_arguments = typing_extensions.get_args(clazz) + return is_generic and len(type_arguments) > 0 + + +def is_generic_type(clazz: type) -> bool: + """ + typing_inspect.is_generic_type explicitly ignores Union, Tuple, Callable, ClassVar + """ + origin = typing_extensions.get_origin(clazz) + return origin is not Annotated and ( + (isinstance(clazz, type) and issubclass(clazz, Generic)) # type: ignore[arg-type] + or isinstance(clazz, typing_inspect.typingGenericAlias) + ) + + +def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: + """ + Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics. + + Returns a dict of each base class and the resolved generics. + """ + # Use Tuples so can zip (order matters) + args_by_class: Dict[type, Tuple[Tuple[TypeVar, _Future], ...]] = {} + parent_class: Optional[type] = None + # Loop in reversed order and iteratively resolve types + for subclass in reversed(clazz.mro()): + if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type] + args = typing_extensions.get_args(subclass.__orig_bases__[0]) + + if parent_class and args_by_class.get(parent_class): + subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = [] + for (_arg, future), potential_type in zip( + args_by_class[parent_class], args + ): + if isinstance(potential_type, TypeVar): + subclass_generic_params_to_args.append((potential_type, future)) + else: + future.set_result(potential_type) + + args_by_class[subclass] = tuple(subclass_generic_params_to_args) + + else: + args_by_class[subclass] = tuple((arg, _Future()) for arg in args) + + parent_class = subclass + + # clazz itself is a generic alias i.e.: A[int]. So it hold the last types. + if is_generic_alias(clazz): + origin = typing_extensions.get_origin(clazz) + args = typing_extensions.get_args(clazz) + for (_arg, future), potential_type in zip(args_by_class[origin], args): # type: ignore[index] + if not isinstance(potential_type, TypeVar): + future.set_result(potential_type) + + # Convert to nested dict for easier lookup + return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()} + + +def _replace_typevars( + clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None +) -> type: + if not resolved_generics or inspect.isclass(clazz) or not is_generic_type(clazz): + return clazz + + return clazz.copy_with( # type: ignore + tuple( + ( + _replace_typevars(arg, resolved_generics) + if is_generic_type(arg) + else ( + resolved_generics[arg].result() if arg in resolved_generics else arg + ) + ) + for arg in typing_extensions.get_args(clazz) + ) + ) + + +def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: + unbound_fields = set() + # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and + # looses the source class. Thus I don't know how to resolve this at later on. + # Instead we recreate the type but with all known TypeVars resolved to their actual types. + resolved_typevars = _resolve_typevars(clazz) + # Dict[field_name, Tuple[original_field, resolved_field]] + fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {} + + for subclass in reversed(clazz.mro()): + if not dataclasses.is_dataclass(subclass): + continue + + for field in dataclasses.fields(subclass): + try: + if field.name in fields and fields[field.name][0] == field: + continue # identical, so already resolved. + + # Either the first time we see this field, or it got overridden + # If it's a class we handle it later as a Nested. Nothing to resolve now. + new_field = field + if not inspect.isclass(field.type) and is_generic_type(field.type): + new_field = copy.copy(field) + new_field.type = _replace_typevars( + field.type, resolved_typevars[subclass] + ) + elif isinstance(field.type, TypeVar): + new_field = copy.copy(field) + new_field.type = resolved_typevars[subclass][field.type].result() + + fields[field.name] = (field, new_field) + except InvalidStateError: + unbound_fields.add(field.name) + + if unbound_fields: + raise UnboundTypeVarError( + f"{clazz.__name__} has unbound fields: {', '.join(unbound_fields)}" + ) + + return tuple(v[1] for v in fields.values()) diff --git a/tests/test_annotated.py b/tests/test_annotated.py index e9105a6..b0f6ae1 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -1,6 +1,8 @@ +import dataclasses +import functools import sys import unittest -from typing import Optional +from typing import List, Optional import marshmallow import marshmallow.fields @@ -35,3 +37,64 @@ class AnnotatedValue: with self.assertRaises(marshmallow.exceptions.ValidationError): schema.load({"value": "notavalidemail"}) + + def test_annotated_partial_field(self) -> None: + """ + NewType allowed us to specify a lambda or partial because there was no type inspection. + But with Annotated we do type inspection. Partial still allows us to to type inspection. + """ + + @dataclass + class AnnotatedValue: + emails: Annotated[ + List[str], + functools.partial(marshmallow.fields.List, marshmallow.fields.Email), + ] = dataclasses.field(default_factory=lambda: ["default@email.com"]) + + schema = AnnotatedValue.Schema() # type: ignore[attr-defined] + + self.assertEqual( + schema.load({}), + AnnotatedValue(emails=["default@email.com"]), + ) + self.assertEqual( + schema.load({"emails": ["test@test.com"]}), + AnnotatedValue( + emails=["test@test.com"], + ), + ) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"emails": "notavalidemail"}) + + def test_annotated_callable_field(self) -> None: + """ + NewType allowed us to specify a lambda or partial because there was no type inspection. + But with Annotated we do type inspection. While we can't reliably do type inspection on a callable, + i.e.: lambda, we can call it and check if it returns a Field. + """ + + @dataclass + class AnnotatedValue: + emails: Annotated[ + List[str], + lambda *args, **kwargs: marshmallow.fields.List( + marshmallow.fields.Email, *args, **kwargs + ), + ] = dataclasses.field(default_factory=lambda: ["default@email.com"]) + + schema = AnnotatedValue.Schema() # type: ignore[attr-defined] + + self.assertEqual( + schema.load({}), + AnnotatedValue(emails=["default@email.com"]), + ) + self.assertEqual( + schema.load({"emails": ["test@test.com"]}), + AnnotatedValue( + emails=["test@test.com"], + ), + ) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"emails": "notavalidemail"}) diff --git a/tests/test_generics.py b/tests/test_generics.py index 3e65fc2..1aa4eda 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,15 +1,67 @@ import dataclasses +import inspect +import sys import typing import unittest +import marshmallow.fields from marshmallow import ValidationError from marshmallow_dataclass import ( UnboundTypeVarError, - _is_generic_alias_of_dataclass, add_schema, class_schema, + dataclass, + is_generic_alias_of_dataclass, ) +from marshmallow_dataclass.generic_resolver import is_generic_type + +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + + +def get_orig_class(obj): + """ + Allows you got get the runtime origin class inside __init__ + + Near duplicate of https://github.com/Stewori/pytypes/blob/master/pytypes/type_util.py#L182 + """ + try: + # See https://github.com/Stewori/pytypes/pull/53: + # Returns `obj.__orig_class__` protecting from infinite recursion in `__getattr[ibute]__` + # wrapped in a `checker_tp`. + # (See `checker_tp` in `typechecker._typeinspect_func for context) + # Necessary if: + # - we're wrapping a method (`obj` is `self`/`cls`) and either + # - the object's class defines __getattribute__ + # or + # - the object doesn't have an `__orig_class__` attribute + # and the object's class defines __getattr__. + # In such a situation, `parent_class = obj.__orig_class__` + # would call `__getattr[ibute]__`. But that method is wrapped in a `checker_tp` too, + # so then we'd go into the wrapped `__getattr[ibute]__` and do + # `parent_class = obj.__orig_class__`, which would call `__getattr[ibute]__` + # again, and so on. So to bypass `__getattr[ibute]__` we do this: + return object.__getattribute__(obj, "__orig_class__") + except AttributeError: + cls = object.__getattribute__(obj, "__class__") + if is_generic_type(cls): + # Searching from index 1 is sufficient: At 0 is get_orig_class, at 1 is the caller. + frame = inspect.currentframe().f_back + try: + while frame: + try: + res = frame.f_locals["self"] + if res.__origin__ is cls: + return res + except (KeyError, AttributeError): + frame = frame.f_back + finally: + del frame + + raise class TestGenerics(unittest.TestCase): @@ -28,8 +80,8 @@ class NestedFixed: class NestedGeneric(typing.Generic[T]): data: SimpleGeneric[T] - self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int])) - self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric)) + self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric)) schema_s = class_schema(SimpleGeneric[str])() self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) @@ -210,6 +262,45 @@ class Base: with self.assertRaises(TypeError): class_schema(Base) + def test_annotated_generic_mf_field(self) -> None: + T = typing.TypeVar("T") + + class GenericList(marshmallow.fields.List, typing.Generic[T]): + """ + Generic Marshmallow List Field that can be used in Annotated and still get all kwargs + from marshmallow_dataclass. + """ + + def __init__( + self, + **kwargs, + ): + cls_or_instance = get_orig_class(self).__args__[0] + + super().__init__(cls_or_instance, **kwargs) + + @dataclass + class AnnotatedValue: + emails: Annotated[ + typing.List[str], GenericList[marshmallow.fields.Email] + ] = dataclasses.field(default_factory=lambda: ["default@email.com"]) + + schema = AnnotatedValue.Schema() # type: ignore[attr-defined] + + self.assertEqual( + schema.load({}), + AnnotatedValue(emails=["default@email.com"]), + ) + self.assertEqual( + schema.load({"emails": ["test@test.com"]}), + AnnotatedValue( + emails=["test@test.com"], + ), + ) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"emails": "notavalidemail"}) + if __name__ == "__main__": unittest.main() From dd34efc958b110046a9de5815c91f459498563cd Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:34:59 +0200 Subject: [PATCH 11/25] Remove support for callable annotations This approach was unsafe. See PR #259 for more details --- marshmallow_dataclass/__init__.py | 14 -------------- tests/test_annotated.py | 32 ------------------------------- 2 files changed, 46 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 3b4fdab..06664d8 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -757,8 +757,6 @@ def _field_for_annotated_type( ) # Support `partial(mf.List, mf.String)` or (isinstance(arg, partial) and _is_marshmallow_field(arg.func)) - # Support `lambda *args, **kwargs: mf.List(mf.String, *args, **kwargs)` - or (_is_callable_marshmallow_field(arg)) ] if marshmallow_annotations: if len(marshmallow_annotations) > 1: @@ -1087,18 +1085,6 @@ def _is_marshmallow_field(obj) -> bool: ) or isinstance(obj, marshmallow.fields.Field) -def _is_callable_marshmallow_field(obj) -> bool: - """Checks if the object is a callable and if the callable returns a marshmallow field""" - if callable(obj): - try: - potential_field = obj() - return _is_marshmallow_field(potential_field) - except Exception: - return False - - return False - - def NewType( name: str, typ: Type[_U], diff --git a/tests/test_annotated.py b/tests/test_annotated.py index b0f6ae1..1e386f6 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -66,35 +66,3 @@ class AnnotatedValue: with self.assertRaises(marshmallow.exceptions.ValidationError): schema.load({"emails": "notavalidemail"}) - - def test_annotated_callable_field(self) -> None: - """ - NewType allowed us to specify a lambda or partial because there was no type inspection. - But with Annotated we do type inspection. While we can't reliably do type inspection on a callable, - i.e.: lambda, we can call it and check if it returns a Field. - """ - - @dataclass - class AnnotatedValue: - emails: Annotated[ - List[str], - lambda *args, **kwargs: marshmallow.fields.List( - marshmallow.fields.Email, *args, **kwargs - ), - ] = dataclasses.field(default_factory=lambda: ["default@email.com"]) - - schema = AnnotatedValue.Schema() # type: ignore[attr-defined] - - self.assertEqual( - schema.load({}), - AnnotatedValue(emails=["default@email.com"]), - ) - self.assertEqual( - schema.load({"emails": ["test@test.com"]}), - AnnotatedValue( - emails=["test@test.com"], - ), - ) - - with self.assertRaises(marshmallow.exceptions.ValidationError): - schema.load({"emails": "notavalidemail"}) From 7ac088d6086c5e4e2313bced99f1454204834563 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Wed, 26 Jun 2024 23:40:48 +0200 Subject: [PATCH 12/25] Remove support for annotated partials --- marshmallow_dataclass/__init__.py | 2 -- tests/test_annotated.py | 33 +------------------------------ 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 06664d8..c6e88d4 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -755,8 +755,6 @@ def _field_for_annotated_type( is_generic_type(arg) and _is_marshmallow_field(typing_extensions.get_origin(arg)) ) - # Support `partial(mf.List, mf.String)` - or (isinstance(arg, partial) and _is_marshmallow_field(arg.func)) ] if marshmallow_annotations: if len(marshmallow_annotations) > 1: diff --git a/tests/test_annotated.py b/tests/test_annotated.py index 1e386f6..e9105a6 100644 --- a/tests/test_annotated.py +++ b/tests/test_annotated.py @@ -1,8 +1,6 @@ -import dataclasses -import functools import sys import unittest -from typing import List, Optional +from typing import Optional import marshmallow import marshmallow.fields @@ -37,32 +35,3 @@ class AnnotatedValue: with self.assertRaises(marshmallow.exceptions.ValidationError): schema.load({"value": "notavalidemail"}) - - def test_annotated_partial_field(self) -> None: - """ - NewType allowed us to specify a lambda or partial because there was no type inspection. - But with Annotated we do type inspection. Partial still allows us to to type inspection. - """ - - @dataclass - class AnnotatedValue: - emails: Annotated[ - List[str], - functools.partial(marshmallow.fields.List, marshmallow.fields.Email), - ] = dataclasses.field(default_factory=lambda: ["default@email.com"]) - - schema = AnnotatedValue.Schema() # type: ignore[attr-defined] - - self.assertEqual( - schema.load({}), - AnnotatedValue(emails=["default@email.com"]), - ) - self.assertEqual( - schema.load({"emails": ["test@test.com"]}), - AnnotatedValue( - emails=["test@test.com"], - ), - ) - - with self.assertRaises(marshmallow.exceptions.ValidationError): - schema.load({"emails": "notavalidemail"}) From b3362bab89449127ec1f02f649f4013bedefc291 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Thu, 27 Jun 2024 22:31:17 +0200 Subject: [PATCH 13/25] Fix import style and some docstrings, and reuse is_generic_alias instead of duplicating logic --- marshmallow_dataclass/__init__.py | 23 ++++++++++++----------- marshmallow_dataclass/generic_resolver.py | 6 +++--- tests/test_generics.py | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index c6e88d4..5f2fab0 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -44,9 +44,15 @@ class User: import warnings from enum import Enum from functools import lru_cache, partial -from typing import Any, Callable, Dict, FrozenSet, Generic, List, Mapping -from typing import NewType as typing_NewType from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Generic, + List, + Mapping, + NewType as typing_NewType, Optional, Sequence, Set, @@ -146,7 +152,7 @@ def _check_decorated_type(cls: object) -> None: # A .Schema attribute doesn't make sense on a generic alias — there's # no way for it to know the generic parameters at run time. raise TypeError( - "decorator does not support generic aliasses " + "decorator does not support generic aliases " "(hint: use class_schema directly instead)" ) @@ -1028,13 +1034,8 @@ def is_generic_alias_of_dataclass(clazz: type) -> bool: Check if given class is a generic alias of a dataclass, if the dataclass is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed """ - is_generic = is_generic_type(clazz) - type_arguments = typing_extensions.get_args(clazz) - origin_class = typing_extensions.get_origin(clazz) - return ( - is_generic - and len(type_arguments) > 0 - and dataclasses.is_dataclass(origin_class) + return is_generic_alias(clazz) and dataclasses.is_dataclass( + typing_extensions.get_origin(clazz) ) @@ -1061,7 +1062,7 @@ def _get_generic_type_hints( obj, schema_ctx: _SchemaContext, ) -> type: - """typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works.""" + """typing.get_type_hints doesn't work with generic aliases. But this 'hack' works.""" class X: x: obj # type: ignore[name-defined] diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index 56287ab..6276f31 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -69,8 +69,8 @@ def set_result(self, result: _U) -> None: def is_generic_alias(clazz: type) -> bool: """ - Check if given class is a generic alias of a class is - defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + Check if given class is a generic alias of a class. + If a class is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed """ is_generic = is_generic_type(clazz) type_arguments = typing_extensions.get_args(clazz) @@ -79,7 +79,7 @@ def is_generic_alias(clazz: type) -> bool: def is_generic_type(clazz: type) -> bool: """ - typing_inspect.is_generic_type explicitly ignores Union, Tuple, Callable, ClassVar + typing_inspect.is_generic_type explicitly ignores Union and Tuple """ origin = typing_extensions.get_origin(clazz) return origin is not Annotated and ( diff --git a/tests/test_generics.py b/tests/test_generics.py index 1aa4eda..8b222ee 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -182,6 +182,20 @@ class TestClass(typing.Generic[T, U]): test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) ) + def test_deep_generic_with_union(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + either: typing.List[typing.Union[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"either": ["first", 1]}), TestClass(["first", 1]) + ) + def test_deep_generic_with_overrides(self): T = typing.TypeVar("T") U = typing.TypeVar("U") From db95e6428c14fc81a0d3092268bf67f8098ffb7a Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Thu, 27 Jun 2024 22:54:19 +0200 Subject: [PATCH 14/25] Rename function to be more descriptive --- marshmallow_dataclass/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 5f2fab0..5dc65a3 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -597,7 +597,7 @@ def _internal_class_schema( ( type_hints[field.name] if not is_generic_type(clazz) - else _get_generic_type_hints(field.type, schema_ctx) + else _get_type_hint_of_generic_object(field.type, schema_ctx) ), _get_field_default(field), field.metadata, @@ -1058,11 +1058,11 @@ def _get_type_hints( return type_hints -def _get_generic_type_hints( +def _get_type_hint_of_generic_object( obj, schema_ctx: _SchemaContext, ) -> type: - """typing.get_type_hints doesn't work with generic aliases. But this 'hack' works.""" + """typing.get_type_hints doesn't work with generic aliases, i.e.: A[int]. But this 'hack' works.""" class X: x: obj # type: ignore[name-defined] From a494984faa938efccde123815c0ca52da59aafa2 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Thu, 27 Jun 2024 22:57:57 +0200 Subject: [PATCH 15/25] Improved doc string --- marshmallow_dataclass/generic_resolver.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index 6276f31..033db4b 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -69,8 +69,16 @@ def set_result(self, result: _U) -> None: def is_generic_alias(clazz: type) -> bool: """ - Check if given class is a generic alias of a class. - If a class is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + Check if given object is a Generic Alias. + + A `generic alias`__ is a generic type bound to generic parameters. + + E.g., given + + class A(Generic[T]): + pass + + ``A[int]`` is a _generic alias_ (while ``A`` is a *generic type*, but not a *generic alias*). """ is_generic = is_generic_type(clazz) type_arguments = typing_extensions.get_args(clazz) From 78fcd4aa1b5ab5d2cb0f806b16f7353e2f111480 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Fri, 28 Jun 2024 00:05:51 +0200 Subject: [PATCH 16/25] Clean up * Differentiate between Generics and container type functions * Tie get_args and get_origin functions to Annotated import. * Rename function and add test to clarify forward ref use case --- marshmallow_dataclass/__init__.py | 62 +++++++++++------------ marshmallow_dataclass/generic_resolver.py | 17 +++---- tests/test_generics.py | 13 +++++ 3 files changed, 51 insertions(+), 41 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 5dc65a3..9a1b427 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -66,7 +66,6 @@ class User: ) import marshmallow -import typing_extensions import typing_inspect from marshmallow_dataclass.generic_resolver import ( @@ -78,9 +77,9 @@ class User: from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute if sys.version_info >= (3, 9): - from typing import Annotated + from typing import Annotated, get_args, get_origin else: - from typing_extensions import Annotated + from typing_extensions import Annotated, get_args, get_origin if sys.version_info >= (3, 11): from typing import dataclass_transform @@ -540,7 +539,7 @@ def _internal_class_schema( ) -> Type[marshmallow.Schema]: schema_ctx = _schema_ctx_stack.top - if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10): + if get_origin(clazz) is Annotated and sys.version_info < (3, 10): # https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977 class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined] else: @@ -597,7 +596,7 @@ def _internal_class_schema( ( type_hints[field.name] if not is_generic_type(clazz) - else _get_type_hint_of_generic_object(field.type, schema_ctx) + else _resolve_forward_type_refs(field.type, schema_ctx) ), _get_field_default(field), field.metadata, @@ -659,8 +658,8 @@ def _field_by_supertype( ) -def _generic_type_add_any(typ: type) -> type: - """if typ is generic type without arguments, replace them by Any.""" +def _container_type_add_any(typ: type) -> type: + """if typ is container type without arguments, replace them by Any.""" if typ is list or typ is List: typ = List[Any] elif typ is dict or typ is Dict: @@ -676,18 +675,20 @@ def _generic_type_add_any(typ: type) -> type: return typ -def _field_for_generic_type( +def _field_for_container_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ - If the type is a generic interface, resolve the arguments and construct the appropriate Field. + If the type is a container interface, resolve the arguments and construct the appropriate Field. + + We use the term 'container' to differentiate from the Generic support """ - origin = typing_extensions.get_origin(typ) - arguments = typing_extensions.get_args(typ) + origin = get_origin(typ) + arguments = get_args(typ) if origin: - # Override base_schema.TYPE_MAPPING to change the class used for generic types below + # Override base_schema.TYPE_MAPPING to change the class used for container types below type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): @@ -749,18 +750,15 @@ def _field_for_annotated_type( """ If the type is an Annotated interface, resolve the arguments and construct the appropriate Field. """ - origin = typing_extensions.get_origin(typ) - arguments = typing_extensions.get_args(typ) + origin = get_origin(typ) + arguments = get_args(typ) if origin and origin is Annotated: marshmallow_annotations = [ arg for arg in arguments[1:] if _is_marshmallow_field(arg) # Support `CustomGenericField[mf.String]` - or ( - is_generic_type(arg) - and _is_marshmallow_field(typing_extensions.get_origin(arg)) - ) + or (is_generic_type(arg) and _is_marshmallow_field(get_origin(arg))) ] if marshmallow_annotations: if len(marshmallow_annotations) > 1: @@ -782,7 +780,7 @@ def _field_for_union_type( base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: - arguments = typing_extensions.get_args(typ) + arguments = get_args(typ) if typing_inspect.is_union_type(typ): if typing_inspect.is_optional_type(typ): metadata["allow_none"] = metadata.get("allow_none", True) @@ -886,8 +884,8 @@ def _field_for_schema( if predefined_field: return predefined_field - # Generic types specified without type arguments - typ = _generic_type_add_any(typ) + # Container types (generics like List) specified without type arguments + typ = _container_type_add_any(typ) # Base types field = _field_by_type(typ, base_schema) @@ -900,7 +898,7 @@ def _field_for_schema( # i.e.: Literal['abc'] if typing_inspect.is_literal_type(typ): - arguments = typing_extensions.get_args(typ) + arguments = get_args(typ) return marshmallow.fields.Raw( validate=( marshmallow.validate.Equal(arguments[0]) @@ -912,7 +910,7 @@ def _field_for_schema( # i.e.: Final[str] = 'abc' if typing_inspect.is_final_type(typ): - arguments = typing_extensions.get_args(typ) + arguments = get_args(typ) if arguments: subtyp = arguments[0] elif default is not marshmallow.missing: @@ -953,10 +951,10 @@ def _field_for_schema( if union_field: return union_field - # Generic types - generic_field = _field_for_generic_type(typ, base_schema, **metadata) - if generic_field: - return generic_field + # Container types + container_field = _field_for_container_type(typ, base_schema, **metadata) + if container_field: + return container_field # typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a # __supertype__ attribute @@ -1034,9 +1032,7 @@ def is_generic_alias_of_dataclass(clazz: type) -> bool: Check if given class is a generic alias of a dataclass, if the dataclass is defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed """ - return is_generic_alias(clazz) and dataclasses.is_dataclass( - typing_extensions.get_origin(clazz) - ) + return is_generic_alias(clazz) and dataclasses.is_dataclass(get_origin(clazz)) def _get_type_hints( @@ -1058,11 +1054,13 @@ def _get_type_hints( return type_hints -def _get_type_hint_of_generic_object( +def _resolve_forward_type_refs( obj, schema_ctx: _SchemaContext, ) -> type: - """typing.get_type_hints doesn't work with generic aliases, i.e.: A[int]. But this 'hack' works.""" + """ + Resolve forward references, mainly applies to Generics i.e.: `A["int"]` -> `A[int]` + """ class X: x: obj # type: ignore[name-defined] diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index 033db4b..3d9e403 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -11,13 +11,12 @@ TypeVar, ) -import typing_extensions import typing_inspect if sys.version_info >= (3, 9): - from typing import Annotated + from typing import Annotated, get_args, get_origin else: - from typing_extensions import Annotated + from typing_extensions import Annotated, get_args, get_origin _U = TypeVar("_U") @@ -81,7 +80,7 @@ class A(Generic[T]): ``A[int]`` is a _generic alias_ (while ``A`` is a *generic type*, but not a *generic alias*). """ is_generic = is_generic_type(clazz) - type_arguments = typing_extensions.get_args(clazz) + type_arguments = get_args(clazz) return is_generic and len(type_arguments) > 0 @@ -89,7 +88,7 @@ def is_generic_type(clazz: type) -> bool: """ typing_inspect.is_generic_type explicitly ignores Union and Tuple """ - origin = typing_extensions.get_origin(clazz) + origin = get_origin(clazz) return origin is not Annotated and ( (isinstance(clazz, type) and issubclass(clazz, Generic)) # type: ignore[arg-type] or isinstance(clazz, typing_inspect.typingGenericAlias) @@ -108,7 +107,7 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: # Loop in reversed order and iteratively resolve types for subclass in reversed(clazz.mro()): if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type] - args = typing_extensions.get_args(subclass.__orig_bases__[0]) + args = get_args(subclass.__orig_bases__[0]) if parent_class and args_by_class.get(parent_class): subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = [] @@ -129,8 +128,8 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: # clazz itself is a generic alias i.e.: A[int]. So it hold the last types. if is_generic_alias(clazz): - origin = typing_extensions.get_origin(clazz) - args = typing_extensions.get_args(clazz) + origin = get_origin(clazz) + args = get_args(clazz) for (_arg, future), potential_type in zip(args_by_class[origin], args): # type: ignore[index] if not isinstance(potential_type, TypeVar): future.set_result(potential_type) @@ -154,7 +153,7 @@ def _replace_typevars( resolved_generics[arg].result() if arg in resolved_generics else arg ) ) - for arg in typing_extensions.get_args(clazz) + for arg in get_args(clazz) ) ) diff --git a/tests/test_generics.py b/tests/test_generics.py index 8b222ee..268cd7e 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -315,6 +315,19 @@ class AnnotatedValue: with self.assertRaises(marshmallow.exceptions.ValidationError): schema.load({"emails": "notavalidemail"}) + def test_generic_dataclass_with_forwardref(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + schema_s = class_schema(SimpleGeneric["str"])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + if __name__ == "__main__": unittest.main() From 8797b2b29c1b11f0a28354f0d5443dad5a50758d Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Fri, 28 Jun 2024 00:35:01 +0200 Subject: [PATCH 17/25] Rename our `is_generic_type` and only utilize where absolutely necessary --- marshmallow_dataclass/__init__.py | 12 +++++----- marshmallow_dataclass/generic_resolver.py | 20 ++++++++++++----- tests/test_generics.py | 27 ++++++++++++++++++----- 3 files changed, 43 insertions(+), 16 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 9a1b427..9ea09bf 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -72,7 +72,6 @@ class User: UnboundTypeVarError, get_generic_dataclass_fields, is_generic_alias, - is_generic_type, ) from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute @@ -586,7 +585,7 @@ def _internal_class_schema( # Update the schema members to contain marshmallow fields instead of dataclass fields type_hints = {} - if not is_generic_type(clazz): + if not typing_inspect.is_generic_type(clazz): type_hints = _get_type_hints(clazz, schema_ctx) attributes.update( @@ -595,7 +594,7 @@ def _internal_class_schema( _field_for_schema( ( type_hints[field.name] - if not is_generic_type(clazz) + if not typing_inspect.is_generic_type(clazz) else _resolve_forward_type_refs(field.type, schema_ctx) ), _get_field_default(field), @@ -758,7 +757,10 @@ def _field_for_annotated_type( for arg in arguments[1:] if _is_marshmallow_field(arg) # Support `CustomGenericField[mf.String]` - or (is_generic_type(arg) and _is_marshmallow_field(get_origin(arg))) + or ( + typing_inspect.is_generic_type(arg) + and _is_marshmallow_field(get_origin(arg)) + ) ] if marshmallow_annotations: if len(marshmallow_annotations) > 1: @@ -1069,7 +1071,7 @@ class X: def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: - if not is_generic_type(clazz): + if not typing_inspect.is_generic_type(clazz): return dataclasses.fields(clazz) else: diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index 3d9e403..cf48568 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -79,14 +79,18 @@ class A(Generic[T]): ``A[int]`` is a _generic alias_ (while ``A`` is a *generic type*, but not a *generic alias*). """ - is_generic = is_generic_type(clazz) + is_generic = typing_inspect.is_generic_type(clazz) type_arguments = get_args(clazz) return is_generic and len(type_arguments) > 0 -def is_generic_type(clazz: type) -> bool: +def may_contain_typevars(clazz: type) -> bool: """ - typing_inspect.is_generic_type explicitly ignores Union and Tuple + Check if the class can contain typevars. This includes Special Forms. + + Different from typing_inspect.is_generic_type as that explicitly ignores Union and Tuple. + + We still need to resolve typevars for Union and Tuple """ origin = get_origin(clazz) return origin is not Annotated and ( @@ -141,14 +145,18 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: def _replace_typevars( clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None ) -> type: - if not resolved_generics or inspect.isclass(clazz) or not is_generic_type(clazz): + if ( + not resolved_generics + or inspect.isclass(clazz) + or not may_contain_typevars(clazz) + ): return clazz return clazz.copy_with( # type: ignore tuple( ( _replace_typevars(arg, resolved_generics) - if is_generic_type(arg) + if may_contain_typevars(arg) else ( resolved_generics[arg].result() if arg in resolved_generics else arg ) @@ -179,7 +187,7 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: # Either the first time we see this field, or it got overridden # If it's a class we handle it later as a Nested. Nothing to resolve now. new_field = field - if not inspect.isclass(field.type) and is_generic_type(field.type): + if not inspect.isclass(field.type) and may_contain_typevars(field.type): new_field = copy.copy(field) new_field.type = _replace_typevars( field.type, resolved_typevars[subclass] diff --git a/tests/test_generics.py b/tests/test_generics.py index 268cd7e..ee853b7 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -3,6 +3,7 @@ import sys import typing import unittest +from typing_inspect import is_generic_type import marshmallow.fields from marshmallow import ValidationError @@ -14,7 +15,6 @@ dataclass, is_generic_alias_of_dataclass, ) -from marshmallow_dataclass.generic_resolver import is_generic_type if sys.version_info >= (3, 9): from typing import Annotated @@ -319,12 +319,29 @@ def test_generic_dataclass_with_forwardref(self): T = typing.TypeVar("T") @dataclasses.dataclass - class SimpleGeneric(typing.Generic[T]): + class ForwardGeneric(typing.Generic[T]): data: T - schema_s = class_schema(SimpleGeneric["str"])() - self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) - self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + schema_s = class_schema(ForwardGeneric["str"])() + self.assertEqual(ForwardGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(ForwardGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + def test_generic_dataclass_with_optional(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class OptionalGeneric(typing.Generic[T]): + data: typing.Optional[T] + + schema_s = class_schema(OptionalGeneric["str"])() + self.assertEqual(OptionalGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(OptionalGeneric(data="a")), {"data": "a"}) + + self.assertEqual(OptionalGeneric(data=None), schema_s.load({})) + self.assertEqual(schema_s.dump(OptionalGeneric(data=None)), {"data": None}) + with self.assertRaises(ValidationError): schema_s.load({"data": 2}) From c361f3ac706b89954a420b3d70ed96589113cf71 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sat, 28 Sep 2024 12:20:19 +0200 Subject: [PATCH 18/25] Clean up unnessary if statements and redundant function call --- marshmallow_dataclass/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 9ea09bf..6aacd11 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -230,8 +230,7 @@ def dataclass( ) def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: - if cls is not None: - _check_decorated_type(cls) + _check_decorated_type(cls) return add_schema( dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 @@ -240,8 +239,6 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: if _cls is None: return decorator - if _cls is not None: - _check_decorated_type(_cls) return decorator(_cls, stacklevel=stacklevel + 1) From c3f5da1c557b82554a2eaa1a50873b54e998c79f Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sun, 29 Sep 2024 17:48:25 +0200 Subject: [PATCH 19/25] Add support for TypeVar defaults Tested with 3.12.0rc2 --- marshmallow_dataclass/generic_resolver.py | 52 ++++++- tests/test_generics.py | 160 ++++++++++++++++++++++ 2 files changed, 210 insertions(+), 2 deletions(-) diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index cf48568..9854029 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -18,6 +18,16 @@ else: from typing_extensions import Annotated, get_args, get_origin +if sys.version_info >= (3, 13): + from typing import NoDefault +else: + from typing import final + + @final + class NoDefault: + pass + + _U = TypeVar("_U") @@ -25,7 +35,14 @@ class UnboundTypeVarError(TypeError): """TypeVar instance can not be resolved to a type spec. This exception is raised when an unbound TypeVar is encountered. + """ + +class InvalidTypeVarDefaultError(TypeError): + """TypeVar default can not be resolved to a type spec. + + This exception is raised when an invalid TypeVar default is encountered. + This is most likely a scoping error: https://peps.python.org/pep-0696/#scoping-rules """ @@ -42,9 +59,11 @@ class _Future(Generic[_U]): _done: bool _result: _U + _default: _U | "_Future[_U]" - def __init__(self) -> None: + def __init__(self, default=NoDefault) -> None: self._done = False + self._default = default def done(self) -> bool: """Return ``True`` if the value is available""" @@ -57,6 +76,13 @@ def result(self) -> _U: """ if self.done(): return self._result + + if self._default is not NoDefault: + if isinstance(self._default, _Future): + return self._default.result() + + return self._default + raise InvalidStateError("result has not been set") def set_result(self, result: _U) -> None: @@ -120,13 +146,35 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: ): if isinstance(potential_type, TypeVar): subclass_generic_params_to_args.append((potential_type, future)) + default = getattr(potential_type, "__default__", NoDefault) + if default is not None: + future._default = default else: future.set_result(potential_type) args_by_class[subclass] = tuple(subclass_generic_params_to_args) else: - args_by_class[subclass] = tuple((arg, _Future()) for arg in args) + # PEP-696: Typevar's may be used as defaults, but T1 must be used before T2 + # https://peps.python.org/pep-0696/#scoping-rules + seen_type_args: Dict[TypeVar, _Future] = {} + for arg in args: + default = getattr(arg, "__default__", NoDefault) + if default is not None: + if isinstance(default, TypeVar): + if default in seen_type_args: + # We've already seen this TypeVar, Set the default to it's _Future + default = seen_type_args[default] + + else: + # We haven't seen this yet, according to PEP-696 this is invalid. + raise InvalidTypeVarDefaultError( + f"{subclass.__name__} has an invalid TypeVar default for field {arg}" + ) + + seen_type_args[arg] = _Future(default=default) + + args_by_class[subclass] = tuple(seen_type_args.items()) parent_class = subclass diff --git a/tests/test_generics.py b/tests/test_generics.py index ee853b7..3d2e11e 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -6,6 +6,7 @@ from typing_inspect import is_generic_type import marshmallow.fields +import pytest from marshmallow import ValidationError from marshmallow_dataclass import ( @@ -345,6 +346,165 @@ class OptionalGeneric(typing.Generic[T]): with self.assertRaises(ValidationError): schema_s.load({"data": 2}) + @pytest.mark.skipif( + sys.version_info <= (3, 13), reason="requires python 3.13 or higher" + ) + def test_generic_default(self): + T = typing.TypeVar("T", default=str) + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric)() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + @pytest.mark.skipif( + sys.version_info <= (3, 13), reason="requires python 3.13 or higher" + ) + def test_deep_generic_with_default_overrides(self): + T = typing.TypeVar("T", default=bool) + U = typing.TypeVar("U", default=int) + V = typing.TypeVar("V", default=str) + W = typing.TypeVar("W", default=float) + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U, V]): + pairs: typing.List[typing.Tuple[T, U]] + gen: V + override: int + + test_schema = class_schema(TestClass)() + assert list(test_schema.fields) == ["pairs", "gen", "override"] + assert isinstance(test_schema.fields["pairs"], marshmallow.fields.List) + assert isinstance(test_schema.fields["pairs"].inner, marshmallow.fields.Tuple) + assert isinstance( + test_schema.fields["pairs"].inner.tuple_fields[0], + marshmallow.fields.Boolean, + ) + assert isinstance( + test_schema.fields["pairs"].inner.tuple_fields[1], + marshmallow.fields.Integer, + ) + + assert isinstance(test_schema.fields["gen"], marshmallow.fields.String) + assert isinstance(test_schema.fields["override"], marshmallow.fields.Integer) + + # Don't only override typevar, but switch order to further confuse things + @dataclasses.dataclass + class TestClass2(TestClass[str, W, U]): + override: str # type: ignore # Want to test that it works, even if incompatible types + + TestAlias = TestClass2[int, T] + test_schema2 = class_schema(TestClass2)() + assert list(test_schema2.fields) == ["pairs", "gen", "override"] + assert isinstance(test_schema2.fields["pairs"], marshmallow.fields.List) + assert isinstance(test_schema2.fields["pairs"].inner, marshmallow.fields.Tuple) + assert isinstance( + test_schema2.fields["pairs"].inner.tuple_fields[0], + marshmallow.fields.String, + ) + assert isinstance( + test_schema2.fields["pairs"].inner.tuple_fields[1], + marshmallow.fields.Float, + ) + + assert isinstance(test_schema2.fields["gen"], marshmallow.fields.Integer) + assert isinstance(test_schema2.fields["override"], marshmallow.fields.String) + + # inherit from alias + @dataclasses.dataclass + class TestClass3(TestAlias[typing.List[int]]): + pass + + test_schema3 = class_schema(TestClass3)() + assert list(test_schema3.fields) == ["pairs", "gen", "override"] + assert isinstance(test_schema3.fields["pairs"], marshmallow.fields.List) + assert isinstance(test_schema3.fields["pairs"].inner, marshmallow.fields.Tuple) + assert isinstance( + test_schema3.fields["pairs"].inner.tuple_fields[0], + marshmallow.fields.String, + ) + assert isinstance( + test_schema3.fields["pairs"].inner.tuple_fields[1], + marshmallow.fields.Integer, + ) + + assert isinstance(test_schema3.fields["gen"], marshmallow.fields.List) + assert isinstance(test_schema3.fields["gen"].inner, marshmallow.fields.Integer) + assert isinstance(test_schema3.fields["override"], marshmallow.fields.String) + + self.assertEqual( + test_schema3.load( + {"pairs": [("first", "1")], "gen": ["1", 2], "override": "overridden"} + ), + TestClass3([("first", 1)], [1, 2], "overridden"), + ) + + @pytest.mark.skipif( + sys.version_info <= (3, 13), reason="requires python 3.13 or higher" + ) + def test_generic_default_recursion(self): + T = typing.TypeVar("T", default=str) + U = typing.TypeVar("U", default=T) + V = typing.TypeVar("V", default=U) + + @dataclasses.dataclass + class DefaultGenerics(typing.Generic[T, U, V]): + a: T + b: U + c: V + + test_schema = class_schema(DefaultGenerics)() + assert list(test_schema.fields) == ["a", "b", "c"] + assert isinstance(test_schema.fields["a"], marshmallow.fields.String) + assert isinstance(test_schema.fields["b"], marshmallow.fields.String) + assert isinstance(test_schema.fields["c"], marshmallow.fields.String) + + test_schema2 = class_schema(DefaultGenerics[int])() + assert list(test_schema2.fields) == ["a", "b", "c"] + assert isinstance(test_schema2.fields["a"], marshmallow.fields.Integer) + assert isinstance(test_schema2.fields["b"], marshmallow.fields.Integer) + assert isinstance(test_schema2.fields["c"], marshmallow.fields.Integer) + if __name__ == "__main__": unittest.main() From 231b3b205a89d78ec040b587d41388729357aeee Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sun, 29 Sep 2024 17:57:02 +0200 Subject: [PATCH 20/25] fix: Use Union compatible with <3.10 --- marshmallow_dataclass/generic_resolver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index 9854029..a365104 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -9,6 +9,7 @@ Optional, Tuple, TypeVar, + Union, ) import typing_inspect @@ -59,7 +60,7 @@ class _Future(Generic[_U]): _done: bool _result: _U - _default: _U | "_Future[_U]" + _default: Union[_U, "_Future[_U]"] def __init__(self, default=NoDefault) -> None: self._done = False From 2ef5a71ba4840d4207ee24dc8a59b1c0bd418412 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sat, 30 Nov 2024 12:41:00 +0100 Subject: [PATCH 21/25] Add python 3.13 --- .github/workflows/python-package.yml | 2 +- .pre-commit-config.yaml | 2 +- marshmallow_dataclass/generic_resolver.py | 9 ++++--- setup.py | 1 + tests/test_generics.py | 18 +++++++------ tox.ini | 31 +++++++++++++++++++++++ 6 files changed, 49 insertions(+), 14 deletions(-) create mode 100644 tox.ini diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 0837335..bb7224b 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -12,7 +12,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest"] - python_version: ["3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.10"] + python_version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "pypy3.10"] runs-on: ${{ matrix.os }} steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1f1ea6..787e909 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - id: flake8 additional_dependencies: ['flake8-bugbear==22.10.27'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.1.1 + rev: v1.13.0 hooks: - id: mypy additional_dependencies: [typeguard,marshmallow] diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index a365104..3a1e556 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -236,14 +236,15 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: # Either the first time we see this field, or it got overridden # If it's a class we handle it later as a Nested. Nothing to resolve now. new_field = field - if not inspect.isclass(field.type) and may_contain_typevars(field.type): + field_type: type = field.type # type: ignore[assignment] + if not inspect.isclass(field_type) and may_contain_typevars(field_type): new_field = copy.copy(field) new_field.type = _replace_typevars( - field.type, resolved_typevars[subclass] + field_type, resolved_typevars[subclass] ) - elif isinstance(field.type, TypeVar): + elif isinstance(field_type, TypeVar): new_field = copy.copy(field) - new_field.type = resolved_typevars[subclass][field.type].result() + new_field.type = resolved_typevars[subclass][field_type].result() fields[field.name] = (field, new_field) except InvalidStateError: diff --git a/setup.py b/setup.py index ceb2555..d6afb80 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries", ] diff --git a/tests/test_generics.py b/tests/test_generics.py index 3d2e11e..8a9f2cc 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -210,15 +210,16 @@ class TestClass(typing.Generic[T, U, V]): override: int # Don't only override typevar, but switch order to further confuse things + # Ignoring 'override' Because I want to test that it works, even if incompatible types @dataclasses.dataclass - class TestClass2(TestClass[str, W, U]): - override: str # type: ignore # Want to test that it works, even if incompatible types + class TestClass2(TestClass[str, W, U]): # type: ignore[override] + override: str # type: ignore[override, assignment] - TestAlias = TestClass2[int, T] + TestAlias = TestClass2[int, T] # type: ignore[override] # inherit from alias @dataclasses.dataclass - class TestClass3(TestAlias[typing.List[int]]): + class TestClass3(TestAlias[typing.List[int]]): # type: ignore[override] pass test_schema = class_schema(TestClass3)() @@ -430,10 +431,11 @@ class TestClass(typing.Generic[T, U, V]): # Don't only override typevar, but switch order to further confuse things @dataclasses.dataclass - class TestClass2(TestClass[str, W, U]): - override: str # type: ignore # Want to test that it works, even if incompatible types + class TestClass2(TestClass[str, W, U]): # type: ignore[override] + # Want to test that it works, even if incompatible types + override: str # type: ignore[override, assignment] - TestAlias = TestClass2[int, T] + TestAlias = TestClass2[int, T] # type: ignore[override] test_schema2 = class_schema(TestClass2)() assert list(test_schema2.fields) == ["pairs", "gen", "override"] assert isinstance(test_schema2.fields["pairs"], marshmallow.fields.List) @@ -452,7 +454,7 @@ class TestClass2(TestClass[str, W, U]): # inherit from alias @dataclasses.dataclass - class TestClass3(TestAlias[typing.List[int]]): + class TestClass3(TestAlias[typing.List[int]]): # type: ignore[override] pass test_schema3 = class_schema(TestClass3)() diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..819c87d --- /dev/null +++ b/tox.ini @@ -0,0 +1,31 @@ +[tox] +requires = + tox>=4 + virtualenv-pyenv +env_list = + py{38,39,310,311,312,313} + cover-report +set_env = + VIRTUALENV_DISCOVERY = pyenv + +[testenv] +deps = + coverage + pytest +commands = coverage run -p -m pytest tests +extras = dev +set_env = + VIRTUALENV_DISCOVERY = pyenv +depends = + cover-report: py{38,39,310,311,312,313} + +[testenv:cover-report] +skip_install = true +deps = coverage +commands = + coverage combine + coverage html + coverage report + + +# - You can also run `tox` from the command line to test in all supported python versions. Note that this will require you to have all supported python versions installed. \ No newline at end of file From fc66fc98b6e9c13fcd90deece7bcbb1b9ab4d31f Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sat, 1 Feb 2025 14:32:45 +0100 Subject: [PATCH 22/25] :bug: support py3.9 native collection types with generics. i.e.: list[T] --- marshmallow_dataclass/generic_resolver.py | 41 +++++++++++++---------- tests/test_generics.py | 23 +++++++++++-- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py index 3a1e556..7f7be72 100644 --- a/marshmallow_dataclass/generic_resolver.py +++ b/marshmallow_dataclass/generic_resolver.py @@ -1,6 +1,5 @@ import copy import dataclasses -import inspect import sys from typing import ( Dict, @@ -11,14 +10,18 @@ TypeVar, Union, ) - import typing_inspect +import warnings if sys.version_info >= (3, 9): from typing import Annotated, get_args, get_origin + from types import GenericAlias else: from typing_extensions import Annotated, get_args, get_origin + GenericAlias = type(list) + + if sys.version_info >= (3, 13): from typing import NoDefault else: @@ -194,25 +197,27 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: def _replace_typevars( clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None ) -> type: - if ( - not resolved_generics - or inspect.isclass(clazz) - or not may_contain_typevars(clazz) - ): + if not resolved_generics or not may_contain_typevars(clazz): return clazz - return clazz.copy_with( # type: ignore - tuple( - ( - _replace_typevars(arg, resolved_generics) - if may_contain_typevars(arg) - else ( - resolved_generics[arg].result() if arg in resolved_generics else arg - ) - ) - for arg in get_args(clazz) + new_args = tuple( + ( + _replace_typevars(arg, resolved_generics) + if may_contain_typevars(arg) + else (resolved_generics[arg].result() if arg in resolved_generics else arg) ) + for arg in get_args(clazz) ) + # i.e.: typing.List, typing.Dict, but not list, and dict + if hasattr(clazz, "copy_with"): + return clazz.copy_with(new_args) + # i.e.: list, dict - inspired by typing._strip_annotations + if sys.version_info >= (3, 9) and isinstance(clazz, GenericAlias): + return GenericAlias(clazz.__origin__, new_args) # type:ignore[return-value] + + # I'm not sure how we'd end up here. But raise a warnings so people can create an issue + warnings.warn(f"Unable to replace typevars in {clazz}") + return clazz def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: @@ -237,7 +242,7 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: # If it's a class we handle it later as a Nested. Nothing to resolve now. new_field = field field_type: type = field.type # type: ignore[assignment] - if not inspect.isclass(field_type) and may_contain_typevars(field_type): + if may_contain_typevars(field_type): new_field = copy.copy(field) new_field.type = _replace_typevars( field_type, resolved_typevars[subclass] diff --git a/tests/test_generics.py b/tests/test_generics.py index 8a9f2cc..2ac5e6f 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -183,6 +183,23 @@ class TestClass(typing.Generic[T, U]): test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) ) + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python 3.9 or higher" + ) + def test_deep_generic_native(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + pairs: list[tuple[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) + ) + def test_deep_generic_with_union(self): T = typing.TypeVar("T") U = typing.TypeVar("U") @@ -348,7 +365,7 @@ class OptionalGeneric(typing.Generic[T]): schema_s.load({"data": 2}) @pytest.mark.skipif( - sys.version_info <= (3, 13), reason="requires python 3.13 or higher" + sys.version_info < (3, 13), reason="requires python 3.13 or higher" ) def test_generic_default(self): T = typing.TypeVar("T", default=str) @@ -399,7 +416,7 @@ class NestedGeneric(typing.Generic[T]): schema_nested_generic.load({"data": {"data": "str"}}) @pytest.mark.skipif( - sys.version_info <= (3, 13), reason="requires python 3.13 or higher" + sys.version_info < (3, 13), reason="requires python 3.13 or higher" ) def test_deep_generic_with_default_overrides(self): T = typing.TypeVar("T", default=bool) @@ -482,7 +499,7 @@ class TestClass3(TestAlias[typing.List[int]]): # type: ignore[override] ) @pytest.mark.skipif( - sys.version_info <= (3, 13), reason="requires python 3.13 or higher" + sys.version_info < (3, 13), reason="requires python 3.13 or higher" ) def test_generic_default_recursion(self): T = typing.TypeVar("T", default=str) From 740fa490fa8cd219eb41b54dbd36b22928aae01f Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sat, 1 Feb 2025 15:24:09 +0100 Subject: [PATCH 23/25] :bug: fix mypy type issue --- marshmallow_dataclass/union_field.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/marshmallow_dataclass/union_field.py b/marshmallow_dataclass/union_field.py index ffe998d..d834aa9 100644 --- a/marshmallow_dataclass/union_field.py +++ b/marshmallow_dataclass/union_field.py @@ -1,6 +1,7 @@ import copy import inspect from typing import List, Tuple, Any, Optional +import typing import typeguard from marshmallow import fields, Schema, ValidationError @@ -43,7 +44,9 @@ def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs): super().__init__(**kwargs) self.union_fields = union_fields - def _bind_to_schema(self, field_name: str, schema: Schema) -> None: + def _bind_to_schema( + self, field_name: str, schema: typing.Union[Schema, fields.Field] + ) -> None: super()._bind_to_schema(field_name, schema) new_union_fields = [] for typ, field in self.union_fields: From 4e0f2141aa435f98cb505f4381a604347456930e Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sun, 2 Feb 2025 00:12:20 +0100 Subject: [PATCH 24/25] :bug: Generics did not work when schema was retrieved a second time. --- marshmallow_dataclass/__init__.py | 53 ++++++++++++++++++++++----- tests/test_generics.py | 61 +++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 10 deletions(-) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index 6aacd11..d694026 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -98,6 +98,29 @@ class User: MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 +class LazyGenericSchema: + """Exists to cache generic instances""" + + def __init__(self, base_schema, frame): + self.base_schema = base_schema + self.frame = frame + + self.__resolved_generic_schemas = {} + + def get_schema(self, instance): + instance_args = get_args(instance) + schema = self.__resolved_generic_schemas.get(instance_args) + if schema is None: + schema = class_schema( + instance, + self.base_schema, + self.frame, + ) + self.__resolved_generic_schemas[instance_args] = schema + + return schema + + def _maybe_get_callers_frame( cls: type, stacklevel: int = 1 ) -> Optional[types.FrameType]: @@ -294,12 +317,17 @@ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: else: frame = _maybe_get_callers_frame(clazz, stacklevel=stacklevel) - # noinspection PyTypeHints - clazz.Schema = lazy_class_attribute( # type: ignore - partial(class_schema, clazz, base_schema, frame), - "Schema", - clazz.__name__, - ) + if not typing_inspect.is_generic_type(clazz): + # noinspection PyTypeHints + clazz.Schema = lazy_class_attribute( # type: ignore + partial(class_schema, clazz, base_schema, frame), + "Schema", + clazz.__name__, + ) + else: + # noinspection PyTypeHints + clazz.Schema = LazyGenericSchema(base_schema, frame) # type: ignore + return clazz if _cls is None: @@ -979,10 +1007,15 @@ def _field_for_schema( forward_reference = getattr(typ, "__forward_arg__", None) nested = ( - nested_schema - or forward_reference - or _schema_ctx_stack.top.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema) # type: ignore [arg-type] + # Pass the type instance. This is required for generics + nested_schema.get_schema(typ) + if isinstance(nested_schema, LazyGenericSchema) + else ( + nested_schema + or forward_reference + or _schema_ctx_stack.top.seen_classes.get(typ) + or _internal_class_schema(typ, base_schema) # type: ignore [arg-type] + ) ) return marshmallow.fields.Nested(nested, **metadata) diff --git a/tests/test_generics.py b/tests/test_generics.py index 2ac5e6f..09ebb63 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -114,6 +114,67 @@ class NestedGeneric(typing.Generic[T]): with self.assertRaises(ValidationError): schema_nested_generic.load({"data": {"data": "str"}}) + def test_generic_dataclass_cached(self): + T = typing.TypeVar("T") + + @dataclass + class SimpleGeneric(typing.Generic[T]): + data1: T + + @dataclass + class NestedFixed: + data2: SimpleGeneric[int] + + @dataclass + class NestedGeneric(typing.Generic[T]): + data3: SimpleGeneric[T] + + self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data1="a"), schema_s.load({"data1": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data1="a")), {"data1": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data1": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data2=SimpleGeneric(1)), + schema_nested.load({"data2": {"data1": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data2=SimpleGeneric(data1=1))), + {"data2": {"data1": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data2": {"data1": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data3=SimpleGeneric(1)), + schema_nested_generic.load({"data3": {"data1": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data3=SimpleGeneric(data1=1))), + {"data3": {"data1": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data3": {"data1": "str"}}) + + # Copy test again so that we trigger a cache hit + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data3=SimpleGeneric(1)), + schema_nested_generic.load({"data3": {"data1": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data3=SimpleGeneric(data1=1))), + {"data3": {"data1": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data3": {"data1": "str"}}) + def test_generic_dataclass_repeated_fields(self): T = typing.TypeVar("T") From b47f754955d8d55d848021fca150f9dac89f8f12 Mon Sep 17 00:00:00 2001 From: Michiel <918128+mvanderlee@users.noreply.github.com> Date: Sun, 2 Feb 2025 11:50:09 +0100 Subject: [PATCH 25/25] :bug: Ensure that Generic.Schema always throws a TypeError --- marshmallow_dataclass/__init__.py | 22 ++++++++++++++ tests/test_generics.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index d694026..241491c 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -107,6 +107,28 @@ def __init__(self, base_schema, frame): self.__resolved_generic_schemas = {} + def __call__(self): + """This get's called via `.Schema()`""" + # A .Schema attribute doesn't make sense on a generic alias — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic aliases " + "(hint: use class_schema directly instead)" + ) + + def __get__(self, instance, cls=None): + # I haven't found a better way, but `inspect.getmember ` causes this function to be called with + # the __origin__ as second arg. This solutions seems to work best. + if instance is None and cls is not None: + return self + + # A .Schema attribute doesn't make sense on a generic alias — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic aliases " + "(hint: use class_schema directly instead)" + ) + def get_schema(self, instance): instance_args = get_args(instance) schema = self.__resolved_generic_schemas.get(instance_args) diff --git a/tests/test_generics.py b/tests/test_generics.py index 09ebb63..a34de6b 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -150,6 +150,18 @@ class NestedGeneric(typing.Generic[T]): with self.assertRaises(ValidationError): schema_nested.load({"data2": {"data1": "str"}}) + schema_nested = NestedFixed.Schema() + self.assertEqual( + NestedFixed(data2=SimpleGeneric(1)), + schema_nested.load({"data2": {"data1": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data2=SimpleGeneric(data1=1))), + {"data2": {"data1": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data2": {"data1": "str"}}) + schema_nested_generic = class_schema(NestedGeneric[int])() self.assertEqual( NestedGeneric(data3=SimpleGeneric(1)), @@ -175,6 +187,9 @@ class NestedGeneric(typing.Generic[T]): with self.assertRaises(ValidationError): schema_nested_generic.load({"data3": {"data1": "str"}}) + with self.assertRaisesRegex(TypeError, "generic"): + NestedGeneric.Schema() + def test_generic_dataclass_repeated_fields(self): T = typing.TypeVar("T") @@ -230,6 +245,26 @@ class GenClass(typing.Generic[T]): with self.assertRaisesRegex(TypeError, "generic"): add_schema(GenClass[int]) + def test_schema_raises_on_generic(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. + This is a Python limitation, not a marshmallow_dataclass limitation. + """ + import marshmallow_dataclass + + T = typing.TypeVar("T") + + @marshmallow_dataclass.dataclass + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + GenClass.Schema() + + with self.assertRaisesRegex(TypeError, "generic"): + GenClass[int].Schema() + def test_deep_generic(self): T = typing.TypeVar("T") U = typing.TypeVar("U") @@ -356,6 +391,19 @@ class Base: with self.assertRaises(TypeError): class_schema(Base) + def test_marshmallow_dataclass_unbound_type_var(self) -> None: + T = typing.TypeVar("T") + + @dataclass + class Base: + answer: T # type: ignore[valid-type] + + with self.assertRaises(UnboundTypeVarError): + class_schema(Base) + + with self.assertRaises(TypeError): + class_schema(Base) + def test_annotated_generic_mf_field(self) -> None: T = typing.TypeVar("T")