diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 0837335..bb7224b 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -12,7 +12,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest"] - python_version: ["3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.10"] + python_version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "pypy3.10"] runs-on: ${{ matrix.os }} steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1f1ea6..787e909 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - id: flake8 additional_dependencies: ['flake8-bugbear==22.10.27'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.1.1 + rev: v1.13.0 hooks: - id: mypy additional_dependencies: [typeguard,marshmallow] diff --git a/marshmallow_dataclass/__init__.py b/marshmallow_dataclass/__init__.py index f82b13a..241491c 100644 --- a/marshmallow_dataclass/__init__.py +++ b/marshmallow_dataclass/__init__.py @@ -66,15 +66,19 @@ class User: ) import marshmallow -import typing_extensions import typing_inspect +from marshmallow_dataclass.generic_resolver import ( + UnboundTypeVarError, + get_generic_dataclass_fields, + is_generic_alias, +) from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute if sys.version_info >= (3, 9): - from typing import Annotated + from typing import Annotated, get_args, get_origin else: - from typing_extensions import Annotated + from typing_extensions import Annotated, get_args, get_origin if sys.version_info >= (3, 11): from typing import dataclass_transform @@ -94,6 +98,51 @@ class User: MAX_CLASS_SCHEMA_CACHE_SIZE = 1024 +class LazyGenericSchema: + """Exists to cache generic instances""" + + def __init__(self, base_schema, frame): + self.base_schema = base_schema + self.frame = frame + + self.__resolved_generic_schemas = {} + + def __call__(self): + """This get's called via `.Schema()`""" + # A .Schema 
attribute doesn't make sense on a generic alias — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic aliases " + "(hint: use class_schema directly instead)" + ) + + def __get__(self, instance, cls=None): + # I haven't found a better way, but `inspect.getmembers` causes this function to be called with + # the __origin__ as the second arg. This solution seems to work best. + if instance is None and cls is not None: + return self + + # A .Schema attribute doesn't make sense on a generic alias — there's + # no way for it to know the generic parameters at run time. + raise TypeError( + "decorator does not support generic aliases " + "(hint: use class_schema directly instead)" + ) + + def get_schema(self, instance): + instance_args = get_args(instance) + schema = self.__resolved_generic_schemas.get(instance_args) + if schema is None: + schema = class_schema( + instance, + self.base_schema, + self.frame, + ) + self.__resolved_generic_schemas[instance_args] = schema + + return schema + + def _maybe_get_callers_frame( cls: type, stacklevel: int = 1 ) -> Optional[types.FrameType]: @@ -139,6 +188,18 @@ def _maybe_get_callers_frame( del frame +def _check_decorated_type(cls: object) -> None: + if not isinstance(cls, type): + raise TypeError(f"expected a class not {cls!r}") + if is_generic_alias(cls): + # A .Schema attribute doesn't make sense on a generic alias — there's + # no way for it to know the generic parameters at run time. 
+ raise TypeError( + "decorator does not support generic aliases " + "(hint: use class_schema directly instead)" + ) + + @overload def dataclass( _cls: Type[_U], @@ -214,12 +275,15 @@ def dataclass( ) def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]: + _check_decorated_type(cls) + return add_schema( dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1 ) if _cls is None: return decorator + return decorator(_cls, stacklevel=stacklevel + 1) @@ -268,17 +332,24 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1): """ def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]: + _check_decorated_type(clazz) + if cls_frame is not None: frame = cls_frame else: frame = _maybe_get_callers_frame(clazz, stacklevel=stacklevel) - # noinspection PyTypeHints - clazz.Schema = lazy_class_attribute( # type: ignore - partial(class_schema, clazz, base_schema, frame), - "Schema", - clazz.__name__, - ) + if not typing_inspect.is_generic_type(clazz): + # noinspection PyTypeHints + clazz.Schema = lazy_class_attribute( # type: ignore + partial(class_schema, clazz, base_schema, frame), + "Schema", + clazz.__name__, + ) + else: + # noinspection PyTypeHints + clazz.Schema = LazyGenericSchema(base_schema, frame) # type: ignore + return clazz if _cls is None: @@ -453,7 +524,7 @@ def class_schema( >>> class_schema(Custom)().load({}) Custom(name=None) """ - if not dataclasses.is_dataclass(clazz): + if not dataclasses.is_dataclass(clazz) and not is_generic_alias_of_dataclass(clazz): clazz = dataclasses.dataclass(clazz) if localns is None: if clazz_frame is None: @@ -514,17 +585,19 @@ def _internal_class_schema( ) -> Type[marshmallow.Schema]: schema_ctx = _schema_ctx_stack.top - if typing_extensions.get_origin(clazz) is Annotated and sys.version_info < (3, 10): + if get_origin(clazz) is Annotated and sys.version_info < (3, 10): # https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977 class_name = clazz._name or 
clazz.__origin__.__name__ # type: ignore[attr-defined] else: - class_name = clazz.__name__ + # generic aliases do not have a __name__ prior python 3.10 + class_name = getattr(clazz, "__name__", repr(clazz)) schema_ctx.seen_classes[clazz] = class_name try: - # noinspection PyDataclass - fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz) + fields = _dataclass_fields(clazz) + except UnboundTypeVarError: + raise except TypeError: # Not a dataclass try: warnings.warn( @@ -540,6 +613,8 @@ def _internal_class_schema( ) created_dataclass: type = dataclasses.dataclass(clazz) return _internal_class_schema(created_dataclass, base_schema) + except UnboundTypeVarError: + raise except Exception as exc: raise TypeError( f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one." @@ -556,23 +631,19 @@ def _internal_class_schema( include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False) # Update the schema members to contain marshmallow fields instead of dataclass fields + type_hints = {} + if not typing_inspect.is_generic_type(clazz): + type_hints = _get_type_hints(clazz, schema_ctx) - if sys.version_info >= (3, 9): - type_hints = get_type_hints( - clazz, - globalns=schema_ctx.globalns, - localns=schema_ctx.localns, - include_extras=True, - ) - else: - type_hints = get_type_hints( - clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns - ) attributes.update( ( field.name, _field_for_schema( - type_hints[field.name], + ( + type_hints[field.name] + if not typing_inspect.is_generic_type(clazz) + else _resolve_forward_type_refs(field.type, schema_ctx) + ), _get_field_default(field), field.metadata, base_schema, @@ -582,7 +653,7 @@ def _internal_class_schema( if field.init or include_non_init ) - schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes) + schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes) return cast(Type[marshmallow.Schema], 
schema_class) @@ -633,8 +704,8 @@ def _field_by_supertype( ) -def _generic_type_add_any(typ: type) -> type: - """if typ is generic type without arguments, replace them by Any.""" +def _container_type_add_any(typ: type) -> type: + """if typ is container type without arguments, replace them by Any.""" if typ is list or typ is List: typ = List[Any] elif typ is dict or typ is Dict: @@ -650,18 +721,20 @@ def _generic_type_add_any(typ: type) -> type: return typ -def _field_for_generic_type( +def _field_for_container_type( typ: type, base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: """ - If the type is a generic interface, resolve the arguments and construct the appropriate Field. + If the type is a container interface, resolve the arguments and construct the appropriate Field. + + We use the term 'container' to differentiate from the Generic support """ - origin = typing_extensions.get_origin(typ) - arguments = typing_extensions.get_args(typ) + origin = get_origin(typ) + arguments = get_args(typ) if origin: - # Override base_schema.TYPE_MAPPING to change the class used for generic types below + # Override base_schema.TYPE_MAPPING to change the class used for container types below type_mapping = base_schema.TYPE_MAPPING if base_schema else {} if origin in (list, List): @@ -705,7 +778,7 @@ def _field_for_generic_type( ), ) return tuple_type(children, **metadata) - elif origin in (dict, Dict, collections.abc.Mapping, Mapping): + if origin in (dict, Dict, collections.abc.Mapping, Mapping): dict_type = type_mapping.get(Dict, marshmallow.fields.Dict) return dict_type( keys=_field_for_schema(arguments[0], base_schema=base_schema), @@ -723,14 +796,18 @@ def _field_for_annotated_type( """ If the type is an Annotated interface, resolve the arguments and construct the appropriate Field. 
""" - origin = typing_extensions.get_origin(typ) - arguments = typing_extensions.get_args(typ) + origin = get_origin(typ) + arguments = get_args(typ) if origin and origin is Annotated: marshmallow_annotations = [ arg for arg in arguments[1:] - if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field)) - or isinstance(arg, marshmallow.fields.Field) + if _is_marshmallow_field(arg) + # Support `CustomGenericField[mf.String]` + or ( + typing_inspect.is_generic_type(arg) + and _is_marshmallow_field(get_origin(arg)) + ) ] if marshmallow_annotations: if len(marshmallow_annotations) > 1: @@ -752,7 +829,7 @@ def _field_for_union_type( base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any, ) -> Optional[marshmallow.fields.Field]: - arguments = typing_extensions.get_args(typ) + arguments = get_args(typ) if typing_inspect.is_union_type(typ): if typing_inspect.is_optional_type(typ): metadata["allow_none"] = metadata.get("allow_none", True) @@ -838,6 +915,9 @@ def _field_for_schema( """ + if isinstance(typ, TypeVar): + raise UnboundTypeVarError(f"can not resolve type variable {typ.__name__}") + metadata = {} if metadata is None else dict(metadata) if default is not marshmallow.missing: @@ -853,8 +933,8 @@ def _field_for_schema( if predefined_field: return predefined_field - # Generic types specified without type arguments - typ = _generic_type_add_any(typ) + # Container types (generics like List) specified without type arguments + typ = _container_type_add_any(typ) # Base types field = _field_by_type(typ, base_schema) @@ -867,7 +947,7 @@ def _field_for_schema( # i.e.: Literal['abc'] if typing_inspect.is_literal_type(typ): - arguments = typing_inspect.get_args(typ) + arguments = get_args(typ) return marshmallow.fields.Raw( validate=( marshmallow.validate.Equal(arguments[0]) @@ -879,7 +959,7 @@ def _field_for_schema( # i.e.: Final[str] = 'abc' if typing_inspect.is_final_type(typ): - arguments = typing_inspect.get_args(typ) + arguments = get_args(typ) if 
arguments: subtyp = arguments[0] elif default is not marshmallow.missing: @@ -920,10 +1000,10 @@ def _field_for_schema( if union_field: return union_field - # Generic types - generic_field = _field_for_generic_type(typ, base_schema, **metadata) - if generic_field: - return generic_field + # Container types + container_field = _field_for_container_type(typ, base_schema, **metadata) + if container_field: + return container_field # typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a # __supertype__ attribute @@ -949,10 +1029,15 @@ def _field_for_schema( forward_reference = getattr(typ, "__forward_arg__", None) nested = ( - nested_schema - or forward_reference - or _schema_ctx_stack.top.seen_classes.get(typ) - or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME + # Pass the type instance. This is required for generics + nested_schema.get_schema(typ) + if isinstance(nested_schema, LazyGenericSchema) + else ( + nested_schema + or forward_reference + or _schema_ctx_stack.top.seen_classes.get(typ) + or _internal_class_schema(typ, base_schema) # type: ignore [arg-type] + ) ) return marshmallow.fields.Nested(nested, **metadata) @@ -996,6 +1081,61 @@ def _get_field_default(field: dataclasses.Field): return field.default +def is_generic_alias_of_dataclass(clazz: type) -> bool: + """ + Check if given class is a generic alias of a dataclass, if the dataclass is + defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed + """ + return is_generic_alias(clazz) and dataclasses.is_dataclass(get_origin(clazz)) + + +def _get_type_hints( + obj, + schema_ctx: _SchemaContext, +): + if sys.version_info >= (3, 9): + type_hints = get_type_hints( + obj, + globalns=schema_ctx.globalns, + localns=schema_ctx.localns, + include_extras=True, + ) + else: + type_hints = get_type_hints( + obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns + ) + + return type_hints + + +def _resolve_forward_type_refs( + 
obj, + schema_ctx: _SchemaContext, +) -> type: + """ + Resolve forward references, mainly applies to Generics i.e.: `A["int"]` -> `A[int]` + """ + + class X: + x: obj # type: ignore[name-defined] + + return _get_type_hints(X, schema_ctx)["x"] + + +def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: + if not typing_inspect.is_generic_type(clazz): + return dataclasses.fields(clazz) + + else: + return get_generic_dataclass_fields(clazz) + + +def _is_marshmallow_field(obj) -> bool: + return ( + inspect.isclass(obj) and issubclass(obj, marshmallow.fields.Field) + ) or isinstance(obj, marshmallow.fields.Field) + + def NewType( name: str, typ: Type[_U], diff --git a/marshmallow_dataclass/generic_resolver.py b/marshmallow_dataclass/generic_resolver.py new file mode 100644 index 0000000..7f7be72 --- /dev/null +++ b/marshmallow_dataclass/generic_resolver.py @@ -0,0 +1,263 @@ +import copy +import dataclasses +import sys +from typing import ( + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, +) +import typing_inspect +import warnings + +if sys.version_info >= (3, 9): + from typing import Annotated, get_args, get_origin + from types import GenericAlias +else: + from typing_extensions import Annotated, get_args, get_origin + + GenericAlias = type(list) + + +if sys.version_info >= (3, 13): + from typing import NoDefault +else: + from typing import final + + @final + class NoDefault: + pass + + +_U = TypeVar("_U") + + +class UnboundTypeVarError(TypeError): + """TypeVar instance can not be resolved to a type spec. + + This exception is raised when an unbound TypeVar is encountered. + """ + + +class InvalidTypeVarDefaultError(TypeError): + """TypeVar default can not be resolved to a type spec. + + This exception is raised when an invalid TypeVar default is encountered. 
+ This is most likely a scoping error: https://peps.python.org/pep-0696/#scoping-rules + """ + + +class InvalidStateError(Exception): + """Raised when an operation is performed on a future that is not + allowed in the current state. + """ + + +class _Future(Generic[_U]): + """The _Future class allows deferred access to a result that is not + yet available. + """ + + _done: bool + _result: _U + _default: Union[_U, "_Future[_U]"] + + def __init__(self, default=NoDefault) -> None: + self._done = False + self._default = default + + def done(self) -> bool: + """Return ``True`` if the value is available""" + return self._done + + def result(self) -> _U: + """Return the deferred value. + + Raises ``InvalidStateError`` if the value has not been set. + """ + if self.done(): + return self._result + + if self._default is not NoDefault: + if isinstance(self._default, _Future): + return self._default.result() + + return self._default + + raise InvalidStateError("result has not been set") + + def set_result(self, result: _U) -> None: + if self.done(): + raise InvalidStateError("result has already been set") + self._result = result + self._done = True + + +def is_generic_alias(clazz: type) -> bool: + """ + Check if given object is a Generic Alias. + + A `generic alias`__ is a generic type bound to generic parameters. + + E.g., given + + class A(Generic[T]): + pass + + ``A[int]`` is a _generic alias_ (while ``A`` is a *generic type*, but not a *generic alias*). + """ + is_generic = typing_inspect.is_generic_type(clazz) + type_arguments = get_args(clazz) + return is_generic and len(type_arguments) > 0 + + +def may_contain_typevars(clazz: type) -> bool: + """ + Check if the class can contain typevars. This includes Special Forms. + + Different from typing_inspect.is_generic_type as that explicitly ignores Union and Tuple. 
+ + We still need to resolve typevars for Union and Tuple + """ + origin = get_origin(clazz) + return origin is not Annotated and ( + (isinstance(clazz, type) and issubclass(clazz, Generic)) # type: ignore[arg-type] + or isinstance(clazz, typing_inspect.typingGenericAlias) + ) + + +def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]: + """ + Attempts to resolve all TypeVars in the class bases. Allows us to resolve inherited and aliased generics. + + Returns a dict of each base class and the resolved generics. + """ + # Use tuples so we can zip (order matters) + args_by_class: Dict[type, Tuple[Tuple[TypeVar, _Future], ...]] = {} + parent_class: Optional[type] = None + # Loop in reversed order and iteratively resolve types + for subclass in reversed(clazz.mro()): + if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type] + args = get_args(subclass.__orig_bases__[0]) + + if parent_class and args_by_class.get(parent_class): + subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = [] + for (_arg, future), potential_type in zip( + args_by_class[parent_class], args + ): + if isinstance(potential_type, TypeVar): + subclass_generic_params_to_args.append((potential_type, future)) + default = getattr(potential_type, "__default__", NoDefault) + if default is not None: + future._default = default + else: + future.set_result(potential_type) + + args_by_class[subclass] = tuple(subclass_generic_params_to_args) + + else: + # PEP-696: TypeVars may be used as defaults, but T1 must be used before T2 + # https://peps.python.org/pep-0696/#scoping-rules + seen_type_args: Dict[TypeVar, _Future] = {} + for arg in args: + default = getattr(arg, "__default__", NoDefault) + if default is not None: + if isinstance(default, TypeVar): + if default in seen_type_args: + # We've already seen this TypeVar, set the default to its _Future + default = seen_type_args[default] + + else: + # We haven't seen this yet, according to 
PEP-696 this is invalid. + raise InvalidTypeVarDefaultError( + f"{subclass.__name__} has an invalid TypeVar default for field {arg}" + ) + + seen_type_args[arg] = _Future(default=default) + + args_by_class[subclass] = tuple(seen_type_args.items()) + + parent_class = subclass + + # clazz itself is a generic alias i.e.: A[int]. So it holds the last types. + if is_generic_alias(clazz): + origin = get_origin(clazz) + args = get_args(clazz) + for (_arg, future), potential_type in zip(args_by_class[origin], args): # type: ignore[index] + if not isinstance(potential_type, TypeVar): + future.set_result(potential_type) + + # Convert to nested dict for easier lookup + return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()} + + +def _replace_typevars( + clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None +) -> type: + if not resolved_generics or not may_contain_typevars(clazz): + return clazz + + new_args = tuple( + ( + _replace_typevars(arg, resolved_generics) + if may_contain_typevars(arg) + else (resolved_generics[arg].result() if arg in resolved_generics else arg) + ) + for arg in get_args(clazz) + ) + # i.e.: typing.List, typing.Dict, but not list, and dict + if hasattr(clazz, "copy_with"): + return clazz.copy_with(new_args) + # i.e.: list, dict - inspired by typing._strip_annotations + if sys.version_info >= (3, 9) and isinstance(clazz, GenericAlias): + return GenericAlias(clazz.__origin__, new_args) # type:ignore[return-value] + + # I'm not sure how we'd end up here. But raise a warning so people can create an issue + warnings.warn(f"Unable to replace typevars in {clazz}") + return clazz + + +def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]: + unbound_fields = set() + # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and + # loses the source class. Thus I don't know how to resolve this later on. 
+ # Instead we recreate the type but with all known TypeVars resolved to their actual types. + resolved_typevars = _resolve_typevars(clazz) + # Dict[field_name, Tuple[original_field, resolved_field]] + fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {} + + for subclass in reversed(clazz.mro()): + if not dataclasses.is_dataclass(subclass): + continue + + for field in dataclasses.fields(subclass): + try: + if field.name in fields and fields[field.name][0] == field: + continue # identical, so already resolved. + + # Either the first time we see this field, or it got overridden + # If it's a class we handle it later as a Nested. Nothing to resolve now. + new_field = field + field_type: type = field.type # type: ignore[assignment] + if may_contain_typevars(field_type): + new_field = copy.copy(field) + new_field.type = _replace_typevars( + field_type, resolved_typevars[subclass] + ) + elif isinstance(field_type, TypeVar): + new_field = copy.copy(field) + new_field.type = resolved_typevars[subclass][field_type].result() + + fields[field.name] = (field, new_field) + except InvalidStateError: + unbound_fields.add(field.name) + + if unbound_fields: + raise UnboundTypeVarError( + f"{clazz.__name__} has unbound fields: {', '.join(unbound_fields)}" + ) + + return tuple(v[1] for v in fields.values()) diff --git a/marshmallow_dataclass/union_field.py b/marshmallow_dataclass/union_field.py index ffe998d..d834aa9 100644 --- a/marshmallow_dataclass/union_field.py +++ b/marshmallow_dataclass/union_field.py @@ -1,6 +1,7 @@ import copy import inspect from typing import List, Tuple, Any, Optional +import typing import typeguard from marshmallow import fields, Schema, ValidationError @@ -43,7 +44,9 @@ def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs): super().__init__(**kwargs) self.union_fields = union_fields - def _bind_to_schema(self, field_name: str, schema: Schema) -> None: + def _bind_to_schema( + self, field_name: str, schema: 
typing.Union[Schema, fields.Field] + ) -> None: super()._bind_to_schema(field_name, schema) new_union_fields = [] for typ, field in self.union_fields: diff --git a/setup.py b/setup.py index ceb2555..d6afb80 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries", ] diff --git a/tests/test_class_schema.py b/tests/test_class_schema.py index 28185a4..8a1beb7 100644 --- a/tests/test_class_schema.py +++ b/tests/test_class_schema.py @@ -1,7 +1,7 @@ import inspect import typing import unittest -from typing import Any, cast, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, cast from uuid import UUID try: @@ -10,11 +10,14 @@ from typing_extensions import Final, Literal # type: ignore[assignment] import dataclasses + from marshmallow import Schema, ValidationError -from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer +from marshmallow.fields import UUID as UUIDField +from marshmallow.fields import Field, Integer +from marshmallow.fields import List as ListField from marshmallow.validate import Validator -from marshmallow_dataclass import class_schema, NewType +from marshmallow_dataclass import NewType, class_schema class TestClassSchema(unittest.TestCase): diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000..a34de6b --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,638 @@ +import dataclasses +import inspect +import sys +import typing +import unittest +from typing_inspect import is_generic_type + +import marshmallow.fields +import pytest +from marshmallow import ValidationError + +from marshmallow_dataclass import ( + UnboundTypeVarError, + add_schema, + class_schema, + dataclass, + is_generic_alias_of_dataclass, +) + +if sys.version_info >= (3, 9): + 
from typing import Annotated +else: + from typing_extensions import Annotated + + +def get_orig_class(obj): + """ + Allows you to get the runtime origin class inside __init__ + + Near duplicate of https://github.com/Stewori/pytypes/blob/master/pytypes/type_util.py#L182 + """ + try: + # See https://github.com/Stewori/pytypes/pull/53: + # Returns `obj.__orig_class__` protecting from infinite recursion in `__getattr[ibute]__` + # wrapped in a `checker_tp`. + # (See `checker_tp` in `typechecker._typeinspect_func` for context) + # Necessary if: + # - we're wrapping a method (`obj` is `self`/`cls`) and either + # - the object's class defines __getattribute__ + # or + # - the object doesn't have an `__orig_class__` attribute + # and the object's class defines __getattr__. + # In such a situation, `parent_class = obj.__orig_class__` + # would call `__getattr[ibute]__`. But that method is wrapped in a `checker_tp` too, + # so then we'd go into the wrapped `__getattr[ibute]__` and do + # `parent_class = obj.__orig_class__`, which would call `__getattr[ibute]__` + # again, and so on. So to bypass `__getattr[ibute]__` we do this: + return object.__getattribute__(obj, "__orig_class__") + except AttributeError: + cls = object.__getattribute__(obj, "__class__") + if is_generic_type(cls): + # Searching from index 1 is sufficient: At 0 is get_orig_class, at 1 is the caller. 
+ frame = inspect.currentframe().f_back + try: + while frame: + try: + res = frame.f_locals["self"] + if res.__origin__ is cls: + return res + except (KeyError, AttributeError): + frame = frame.f_back + finally: + del frame + + raise + + +class TestGenerics(unittest.TestCase): + def test_generic_dataclass(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + def test_generic_dataclass_cached(self): + T = typing.TypeVar("T") + + @dataclass + class SimpleGeneric(typing.Generic[T]): + data1: T + + @dataclass + class NestedFixed: + data2: SimpleGeneric[int] + + @dataclass + class 
NestedGeneric(typing.Generic[T]): + data3: SimpleGeneric[T] + + self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric[str])() + self.assertEqual(SimpleGeneric(data1="a"), schema_s.load({"data1": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data1="a")), {"data1": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data1": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data2=SimpleGeneric(1)), + schema_nested.load({"data2": {"data1": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data2=SimpleGeneric(data1=1))), + {"data2": {"data1": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data2": {"data1": "str"}}) + + schema_nested = NestedFixed.Schema() + self.assertEqual( + NestedFixed(data2=SimpleGeneric(1)), + schema_nested.load({"data2": {"data1": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data2=SimpleGeneric(data1=1))), + {"data2": {"data1": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data2": {"data1": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data3=SimpleGeneric(1)), + schema_nested_generic.load({"data3": {"data1": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data3=SimpleGeneric(data1=1))), + {"data3": {"data1": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data3": {"data1": "str"}}) + + # Copy test again so that we trigger a cache hit + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data3=SimpleGeneric(1)), + schema_nested_generic.load({"data3": {"data1": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data3=SimpleGeneric(data1=1))), + {"data3": {"data1": 1}}, + ) + with 
self.assertRaises(ValidationError): + schema_nested_generic.load({"data3": {"data1": "str"}}) + + with self.assertRaisesRegex(TypeError, "generic"): + NestedGeneric.Schema() + + def test_generic_dataclass_repeated_fields(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class AA: + a: int + + @dataclasses.dataclass + class BB(typing.Generic[T]): + b: T + + @dataclasses.dataclass + class Nested: + x: BB[float] + z: BB[float] + # if y is the first field in this class, deserialisation will fail. + # see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027 + y: BB[AA] + + schema_nested = class_schema(Nested)() + self.assertEqual( + Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))), + schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}), + ) + + def test_marshmallow_dataclass_decorator_raises_on_generic_alias(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. + This is a Python limitation, not a marshmallow_dataclass limitation. + """ + import marshmallow_dataclass + + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + marshmallow_dataclass.dataclass(GenClass[int]) + + def test_add_schema_raises_on_generic_alias(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. + This is a Python limitation, not a marshmallow_dataclass limitation. + """ + T = typing.TypeVar("T") + + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + add_schema(GenClass[int]) + + def test_schema_raises_on_generic(self): + """ + We can't support `GenClass[int].Schema` because the class function was created on `GenClass` + Therefore the function does not know about the `int` type. 
+ This is a Python limitation, not a marshmallow_dataclass limitation. + """ + import marshmallow_dataclass + + T = typing.TypeVar("T") + + @marshmallow_dataclass.dataclass + class GenClass(typing.Generic[T]): + pass + + with self.assertRaisesRegex(TypeError, "generic"): + GenClass.Schema() + + with self.assertRaisesRegex(TypeError, "generic"): + GenClass[int].Schema() + + def test_deep_generic(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + pairs: typing.List[typing.Tuple[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python 3.9 or higher" + ) + def test_deep_generic_native(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + pairs: list[tuple[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"pairs": [("first", "1")]}), TestClass([("first", 1)]) + ) + + def test_deep_generic_with_union(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U]): + either: typing.List[typing.Union[T, U]] + + test_schema = class_schema(TestClass[str, int])() + + self.assertEqual( + test_schema.load({"either": ["first", 1]}), TestClass(["first", 1]) + ) + + def test_deep_generic_with_overrides(self): + T = typing.TypeVar("T") + U = typing.TypeVar("U") + V = typing.TypeVar("V") + W = typing.TypeVar("W") + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U, V]): + pairs: typing.List[typing.Tuple[T, U]] + gen: V + override: int + + # Don't only override typevar, but switch order to further confuse things + # Ignoring 'override' Because I want to test that it works, even if incompatible types + @dataclasses.dataclass + 
class TestClass2(TestClass[str, W, U]): # type: ignore[override] + override: str # type: ignore[override, assignment] + + TestAlias = TestClass2[int, T] # type: ignore[override] + + # inherit from alias + @dataclasses.dataclass + class TestClass3(TestAlias[typing.List[int]]): # type: ignore[override] + pass + + test_schema = class_schema(TestClass3)() + + self.assertEqual( + test_schema.load( + {"pairs": [("first", "1")], "gen": ["1", 2], "override": "overridden"} + ), + TestClass3([("first", 1)], [1, 2], "overridden"), + ) + + def test_generic_bases(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[T]): + pass + + test_schema = class_schema(TestClass[int])() + + self.assertEqual(test_schema.load({"answer": "1"}), TestClass(1)) + + def test_bound_generic_base(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base1(typing.Generic[T]): + answer: T + + @dataclasses.dataclass + class TestClass(Base1[int]): + pass + + with self.assertRaisesRegex( + UnboundTypeVarError, "Base1 has unbound fields: answer" + ): + class_schema(Base1) + + test_schema = class_schema(TestClass)() + self.assertEqual(test_schema.load({"answer": "1"}), TestClass(1)) + + def test_unbound_type_var(self) -> None: + T = typing.TypeVar("T") + + @dataclasses.dataclass + class Base: + answer: T # type: ignore[valid-type] + + with self.assertRaises(UnboundTypeVarError): + class_schema(Base) + + with self.assertRaises(TypeError): + class_schema(Base) + + def test_marshmallow_dataclass_unbound_type_var(self) -> None: + T = typing.TypeVar("T") + + @dataclass + class Base: + answer: T # type: ignore[valid-type] + + with self.assertRaises(UnboundTypeVarError): + class_schema(Base) + + with self.assertRaises(TypeError): + class_schema(Base) + + def test_annotated_generic_mf_field(self) -> None: + T = typing.TypeVar("T") + + class GenericList(marshmallow.fields.List, 
typing.Generic[T]): + """ + Generic Marshmallow List Field that can be used in Annotated and still get all kwargs + from marshmallow_dataclass. + """ + + def __init__( + self, + **kwargs, + ): + cls_or_instance = get_orig_class(self).__args__[0] + + super().__init__(cls_or_instance, **kwargs) + + @dataclass + class AnnotatedValue: + emails: Annotated[ + typing.List[str], GenericList[marshmallow.fields.Email] + ] = dataclasses.field(default_factory=lambda: ["default@email.com"]) + + schema = AnnotatedValue.Schema() # type: ignore[attr-defined] + + self.assertEqual( + schema.load({}), + AnnotatedValue(emails=["default@email.com"]), + ) + self.assertEqual( + schema.load({"emails": ["test@test.com"]}), + AnnotatedValue( + emails=["test@test.com"], + ), + ) + + with self.assertRaises(marshmallow.exceptions.ValidationError): + schema.load({"emails": "notavalidemail"}) + + def test_generic_dataclass_with_forwardref(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class ForwardGeneric(typing.Generic[T]): + data: T + + schema_s = class_schema(ForwardGeneric["str"])() + self.assertEqual(ForwardGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(ForwardGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + def test_generic_dataclass_with_optional(self): + T = typing.TypeVar("T") + + @dataclasses.dataclass + class OptionalGeneric(typing.Generic[T]): + data: typing.Optional[T] + + schema_s = class_schema(OptionalGeneric["str"])() + self.assertEqual(OptionalGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(OptionalGeneric(data="a")), {"data": "a"}) + + self.assertEqual(OptionalGeneric(data=None), schema_s.load({})) + self.assertEqual(schema_s.dump(OptionalGeneric(data=None)), {"data": None}) + + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + @pytest.mark.skipif( + sys.version_info < (3, 13), reason="requires python 3.13 or 
higher" + ) + def test_generic_default(self): + T = typing.TypeVar("T", default=str) + + @dataclasses.dataclass + class SimpleGeneric(typing.Generic[T]): + data: T + + @dataclasses.dataclass + class NestedFixed: + data: SimpleGeneric[int] + + @dataclasses.dataclass + class NestedGeneric(typing.Generic[T]): + data: SimpleGeneric[T] + + self.assertTrue(is_generic_alias_of_dataclass(SimpleGeneric[int])) + self.assertFalse(is_generic_alias_of_dataclass(SimpleGeneric)) + + schema_s = class_schema(SimpleGeneric)() + self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"})) + self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"}) + with self.assertRaises(ValidationError): + schema_s.load({"data": 2}) + + schema_nested = class_schema(NestedFixed)() + self.assertEqual( + NestedFixed(data=SimpleGeneric(1)), + schema_nested.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested.load({"data": {"data": "str"}}) + + schema_nested_generic = class_schema(NestedGeneric[int])() + self.assertEqual( + NestedGeneric(data=SimpleGeneric(1)), + schema_nested_generic.load({"data": {"data": 1}}), + ) + self.assertEqual( + schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))), + {"data": {"data": 1}}, + ) + with self.assertRaises(ValidationError): + schema_nested_generic.load({"data": {"data": "str"}}) + + @pytest.mark.skipif( + sys.version_info < (3, 13), reason="requires python 3.13 or higher" + ) + def test_deep_generic_with_default_overrides(self): + T = typing.TypeVar("T", default=bool) + U = typing.TypeVar("U", default=int) + V = typing.TypeVar("V", default=str) + W = typing.TypeVar("W", default=float) + + @dataclasses.dataclass + class TestClass(typing.Generic[T, U, V]): + pairs: typing.List[typing.Tuple[T, U]] + gen: V + override: int + + test_schema = class_schema(TestClass)() + assert 
list(test_schema.fields) == ["pairs", "gen", "override"] + assert isinstance(test_schema.fields["pairs"], marshmallow.fields.List) + assert isinstance(test_schema.fields["pairs"].inner, marshmallow.fields.Tuple) + assert isinstance( + test_schema.fields["pairs"].inner.tuple_fields[0], + marshmallow.fields.Boolean, + ) + assert isinstance( + test_schema.fields["pairs"].inner.tuple_fields[1], + marshmallow.fields.Integer, + ) + + assert isinstance(test_schema.fields["gen"], marshmallow.fields.String) + assert isinstance(test_schema.fields["override"], marshmallow.fields.Integer) + + # Don't only override typevar, but switch order to further confuse things + @dataclasses.dataclass + class TestClass2(TestClass[str, W, U]): # type: ignore[override] + # Want to test that it works, even if incompatible types + override: str # type: ignore[override, assignment] + + TestAlias = TestClass2[int, T] # type: ignore[override] + test_schema2 = class_schema(TestClass2)() + assert list(test_schema2.fields) == ["pairs", "gen", "override"] + assert isinstance(test_schema2.fields["pairs"], marshmallow.fields.List) + assert isinstance(test_schema2.fields["pairs"].inner, marshmallow.fields.Tuple) + assert isinstance( + test_schema2.fields["pairs"].inner.tuple_fields[0], + marshmallow.fields.String, + ) + assert isinstance( + test_schema2.fields["pairs"].inner.tuple_fields[1], + marshmallow.fields.Float, + ) + + assert isinstance(test_schema2.fields["gen"], marshmallow.fields.Integer) + assert isinstance(test_schema2.fields["override"], marshmallow.fields.String) + + # inherit from alias + @dataclasses.dataclass + class TestClass3(TestAlias[typing.List[int]]): # type: ignore[override] + pass + + test_schema3 = class_schema(TestClass3)() + assert list(test_schema3.fields) == ["pairs", "gen", "override"] + assert isinstance(test_schema3.fields["pairs"], marshmallow.fields.List) + assert isinstance(test_schema3.fields["pairs"].inner, marshmallow.fields.Tuple) + assert isinstance( + 
test_schema3.fields["pairs"].inner.tuple_fields[0], + marshmallow.fields.String, + ) + assert isinstance( + test_schema3.fields["pairs"].inner.tuple_fields[1], + marshmallow.fields.Integer, + ) + + assert isinstance(test_schema3.fields["gen"], marshmallow.fields.List) + assert isinstance(test_schema3.fields["gen"].inner, marshmallow.fields.Integer) + assert isinstance(test_schema3.fields["override"], marshmallow.fields.String) + + self.assertEqual( + test_schema3.load( + {"pairs": [("first", "1")], "gen": ["1", 2], "override": "overridden"} + ), + TestClass3([("first", 1)], [1, 2], "overridden"), + ) + + @pytest.mark.skipif( + sys.version_info < (3, 13), reason="requires python 3.13 or higher" + ) + def test_generic_default_recursion(self): + T = typing.TypeVar("T", default=str) + U = typing.TypeVar("U", default=T) + V = typing.TypeVar("V", default=U) + + @dataclasses.dataclass + class DefaultGenerics(typing.Generic[T, U, V]): + a: T + b: U + c: V + + test_schema = class_schema(DefaultGenerics)() + assert list(test_schema.fields) == ["a", "b", "c"] + assert isinstance(test_schema.fields["a"], marshmallow.fields.String) + assert isinstance(test_schema.fields["b"], marshmallow.fields.String) + assert isinstance(test_schema.fields["c"], marshmallow.fields.String) + + test_schema2 = class_schema(DefaultGenerics[int])() + assert list(test_schema2.fields) == ["a", "b", "c"] + assert isinstance(test_schema2.fields["a"], marshmallow.fields.Integer) + assert isinstance(test_schema2.fields["b"], marshmallow.fields.Integer) + assert isinstance(test_schema2.fields["c"], marshmallow.fields.Integer) + + +if __name__ == "__main__": + unittest.main() diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..819c87d --- /dev/null +++ b/tox.ini @@ -0,0 +1,31 @@ +[tox] +requires = + tox>=4 + virtualenv-pyenv +env_list = + py{38,39,310,311,312,313} + cover-report +set_env = + VIRTUALENV_DISCOVERY = pyenv + +[testenv] +deps = + coverage + pytest +commands = coverage run -p -m 
pytest tests +extras = dev +set_env = + VIRTUALENV_DISCOVERY = pyenv +depends = + cover-report: py{38,39,310,311,312,313} + +[testenv:cover-report] +skip_install = true +deps = coverage +commands = + coverage combine + coverage html + coverage report + + +# You can also run `tox` from the command line to test against all supported Python versions. Note that this requires every supported Python version to be installed. \ No newline at end of file