Add Generic dataclasses #259

Open
wants to merge 25 commits into
base: master
bc46a23
support generic dataclasses
onursatici Jan 11, 2022
db32163
support nested generic dataclasses
onursatici Nov 28, 2022
8b82276
add test for repeated fields, fix __name__ attr for py<3.10
onursatici Dec 9, 2022
f2734cc
support py3.6
onursatici Dec 15, 2022
9a5dc09
Split generic tests into it's own file.
mvanderlee Mar 10, 2024
8443336
Add support for deep generics with swapped TypeVars.
mvanderlee Mar 11, 2024
8a0f837
Fix tests after rebase
mvanderlee Jun 24, 2024
f484596
Remove unnecessary whitespace
mvanderlee Jun 24, 2024
80dab91
fix call correct _field_for_schema function
mvanderlee Jun 24, 2024
4531c35
Break generic functions out into it's own file and add support for an…
mvanderlee Jun 25, 2024
dd34efc
Remove support for callable annotations
mvanderlee Jun 26, 2024
7ac088d
Remove support for annotated partials
mvanderlee Jun 26, 2024
b3362ba
Fix import style and some docstrings, and reuse is_generic_alias inst…
mvanderlee Jun 27, 2024
db95e64
Rename function to be more descriptive
mvanderlee Jun 27, 2024
a494984
Improved doc string
mvanderlee Jun 27, 2024
78fcd4a
Clean up
mvanderlee Jun 27, 2024
8797b2b
Rename our `is_generic_type` and only utilize where absolutely necessary
mvanderlee Jun 27, 2024
c361f3a
Clean up unnessary if statements and redundant function call
mvanderlee Sep 28, 2024
c3f5da1
Add support for TypeVar defaults
mvanderlee Sep 29, 2024
231b3b2
fix: Use Union compatible with <3.10
mvanderlee Sep 29, 2024
2ef5a71
Add python 3.13
mvanderlee Nov 30, 2024
fc66fc9
:bug: support py3.9 native collection types with generics. i.e.: list[T]
mvanderlee Feb 1, 2025
740fa49
:bug: fix mypy type issue
mvanderlee Feb 1, 2025
4e0f214
:bug: Generics did not work when schema was retrieved a second time.
mvanderlee Feb 1, 2025
b47f754
:bug: Ensure that Generic.Schema always throws a TypeError
mvanderlee Feb 2, 2025
149 changes: 118 additions & 31 deletions marshmallow_dataclass/__init__.py
@@ -44,15 +44,9 @@ class User:
import warnings
from enum import Enum
from functools import lru_cache, partial
from typing import Any, Callable, Dict, FrozenSet, Generic, List, Mapping
from typing import NewType as typing_NewType
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Generic,
List,
Mapping,
NewType as typing_NewType,
Optional,
Sequence,
Set,
@@ -69,6 +63,12 @@ class User:
import typing_extensions
import typing_inspect

from marshmallow_dataclass.generic_resolver import (
UnboundTypeVarError,
get_generic_dataclass_fields,
is_generic_alias,
is_generic_type,
)
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute

if sys.version_info >= (3, 9):
@@ -139,6 +139,18 @@ def _maybe_get_callers_frame(
del frame


def _check_decorated_type(cls: object) -> None:
if not isinstance(cls, type):
raise TypeError(f"expected a class not {cls!r}")
if is_generic_alias(cls):
# A .Schema attribute doesn't make sense on a generic alias — there's
# no way for it to know the generic parameters at run time.
raise TypeError(
"decorator does not support generic aliasses "
"(hint: use class_schema directly instead)"
)


@overload
def dataclass(
_cls: Type[_U],
@@ -214,12 +226,18 @@ def dataclass(
)

def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
if cls is not None:
_check_decorated_type(cls)

return add_schema(
dc(cls), base_schema, cls_frame=cls_frame, stacklevel=stacklevel + 1
)

if _cls is None:
return decorator

if _cls is not None:
_check_decorated_type(_cls)
return decorator(_cls, stacklevel=stacklevel + 1)


@@ -268,6 +286,8 @@ def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1):
"""

def decorator(clazz: Type[_U], stacklevel: int = stacklevel) -> Type[_U]:
_check_decorated_type(clazz)
Collaborator

_check_decorated_type requires only that isinstance(clazz, type).
Do we want to require that clazz is a dataclass (isinstance(clazz, type) and dataclasses.is_dataclass(clazz)) here?

Contributor Author

I don't think so, because we already have a big warning in _internal_class_schema when a class is not a dataclass.
So non-dataclasses are allowed, but not supported.

Collaborator

@dairiki dairiki Sep 28, 2024

I've just run a simple test.

You are correct that decorating non-dataclass classes with add_schema works, at least in simple cases. (That doesn't necessarily mean that we should allow it.)

It appears, however, that, as things stand in this PR, no big warning is emitted in that case. Further investigation reveals that get_resolved_dataclass_fields "just works" (with no warnings emitted) even if its argument is not a dataclass. (I suspect that fields from the class's __mro__ may not be handled properly in that case, but I haven't looked closely enough to say for sure.)

In any case, I think that either:

  • We should disallow using the add_schema decorator on non-dataclasses. (Why allow it if it's unsupported/untested?)
  • Or, we need a test to ensure the warning is, in fact, emitted when add_schema is used on a non-dataclass.

Contributor Author

I've had to do some digging.

  1. In today's version, 8.7.1, I can call class_schema(NonDataclass) just fine. It only shows the warning if one of its fields is not a dataclass, i.e. for nested non-dataclasses.
  2. This goes back to when the warning was originally added: e31faa8
  3. The behaviour still works the same with this PR.

I don't disagree with removing support for non-dataclasses, but I don't see why that should be part of this PR.


if cls_frame is not None:
frame = cls_frame
else:
@@ -453,7 +473,7 @@ def class_schema(
>>> class_schema(Custom)().load({})
Custom(name=None)
"""
if not dataclasses.is_dataclass(clazz):
if not dataclasses.is_dataclass(clazz) and not is_generic_alias_of_dataclass(clazz):
clazz = dataclasses.dataclass(clazz)
if localns is None:
if clazz_frame is None:
@@ -518,13 +538,15 @@ def _internal_class_schema(
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined]
else:
class_name = clazz.__name__
# generic aliases do not have a __name__ prior python 3.10
class_name = getattr(clazz, "__name__", repr(clazz))

schema_ctx.seen_classes[clazz] = class_name

try:
# noinspection PyDataclass
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
fields = _dataclass_fields(clazz)
except UnboundTypeVarError:
raise
except TypeError: # Not a dataclass
try:
warnings.warn(
@@ -540,6 +562,8 @@ def _internal_class_schema(
)
created_dataclass: type = dataclasses.dataclass(clazz)
return _internal_class_schema(created_dataclass, base_schema)
except UnboundTypeVarError:
raise
except Exception as exc:
raise TypeError(
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
@@ -556,23 +580,19 @@ def _internal_class_schema(
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)

# Update the schema members to contain marshmallow fields instead of dataclass fields
type_hints = {}
if not is_generic_type(clazz):
type_hints = _get_type_hints(clazz, schema_ctx)

if sys.version_info >= (3, 9):
type_hints = get_type_hints(
clazz,
globalns=schema_ctx.globalns,
localns=schema_ctx.localns,
include_extras=True,
)
else:
type_hints = get_type_hints(
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)
attributes.update(
(
field.name,
_field_for_schema(
type_hints[field.name],
(
type_hints[field.name]
if not is_generic_type(clazz)
else _get_generic_type_hints(field.type, schema_ctx)
),
_get_field_default(field),
field.metadata,
base_schema,
@@ -582,7 +602,7 @@ def _internal_class_schema(
if field.init or include_non_init
)

schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
return cast(Type[marshmallow.Schema], schema_class)


@@ -705,7 +725,7 @@ def _field_for_generic_type(
),
)
return tuple_type(children, **metadata)
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
if origin in (dict, Dict, collections.abc.Mapping, Mapping):
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
return dict_type(
keys=_field_for_schema(arguments[0], base_schema=base_schema),
@@ -729,8 +749,12 @@ def _field_for_annotated_type(
marshmallow_annotations = [
arg
for arg in arguments[1:]
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
or isinstance(arg, marshmallow.fields.Field)
if _is_marshmallow_field(arg)
# Support `CustomGenericField[mf.String]`
or (
is_generic_type(arg)
and _is_marshmallow_field(typing_extensions.get_origin(arg))
)
]
if marshmallow_annotations:
if len(marshmallow_annotations) > 1:
@@ -838,6 +862,9 @@ def _field_for_schema(

"""

if isinstance(typ, TypeVar):
raise UnboundTypeVarError(f"can not resolve type variable {typ.__name__}")

metadata = {} if metadata is None else dict(metadata)

if default is not marshmallow.missing:
@@ -867,7 +894,7 @@ def _field_for_schema(

# i.e.: Literal['abc']
if typing_inspect.is_literal_type(typ):
arguments = typing_inspect.get_args(typ)
arguments = typing_extensions.get_args(typ)
return marshmallow.fields.Raw(
validate=(
marshmallow.validate.Equal(arguments[0])
@@ -879,7 +906,7 @@ def _field_for_schema(

# i.e.: Final[str] = 'abc'
if typing_inspect.is_final_type(typ):
arguments = typing_inspect.get_args(typ)
arguments = typing_extensions.get_args(typ)
if arguments:
subtyp = arguments[0]
elif default is not marshmallow.missing:
@@ -952,7 +979,7 @@ def _field_for_schema(
nested_schema
or forward_reference
or _schema_ctx_stack.top.seen_classes.get(typ)
or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME
or _internal_class_schema(typ, base_schema) # type: ignore [arg-type]
)

return marshmallow.fields.Nested(nested, **metadata)
@@ -996,6 +1023,66 @@ def _get_field_default(field: dataclasses.Field):
return field.default


def is_generic_alias_of_dataclass(clazz: type) -> bool:
"""
Check if given class is a generic alias of a dataclass, if the dataclass is
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
"""
is_generic = is_generic_type(clazz)
type_arguments = typing_extensions.get_args(clazz)
origin_class = typing_extensions.get_origin(clazz)
return (
is_generic
and len(type_arguments) > 0
and dataclasses.is_dataclass(origin_class)
)


def _get_type_hints(
obj,
schema_ctx: _SchemaContext,
):
if sys.version_info >= (3, 9):
type_hints = get_type_hints(
obj,
globalns=schema_ctx.globalns,
localns=schema_ctx.localns,
include_extras=True,
)
else:
type_hints = get_type_hints(
obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns
)

return type_hints


def _get_generic_type_hints(
obj,
schema_ctx: _SchemaContext,
) -> type:
"""typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works."""
Collaborator

Perhaps this function should be renamed _resolve_forward_type_refs (or similar).

get_type_hints returns a dict containing the type hints for all members of a class (while resolving forward type references found along the way).
The purpose of this function, by contrast, is to resolve forward type references in a single type hint. (If there are no forward type references in obj, I think this function returns obj unchanged.)

(Spelling nit: "aliases")

Contributor Author

Hmm, the forward refs have already been resolved. This really exists purely to get the type hint for generics.
Maybe _get_type_hint_of_generic_object?

import typing

T = typing.TypeVar('T')

class A(typing.Generic[T]):
  a: T

class B(A[int]):
  pass

print(typing.get_type_hints(B))
print(typing.get_type_hints(A[int]))

>=========================
{'a': ~T}
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[51], line 12
      9   pass
     11 print(typing.get_type_hints(B))
---> 12 print(typing.get_type_hints(A[int]))

File ~\.pyenv\pyenv-win\versions\3.11.3\Lib\typing.py:2347, in get_type_hints(obj, globalns, localns, include_extras)
   2345         return {}
   2346     else:
-> 2347         raise TypeError('{!r} is not a module, class, method, '
   2348                         'or function.'.format(obj))
   2349 hints = dict(hints)
   2350 for name, value in hints.items():

TypeError: __main__.A[int] is not a module, class, method, or function.

def _get_generic_type_hints(obj) -> type:
    """typing.get_type_hints doesn't work with generic aliases. But this 'hack' works."""

    class X:
        x: obj  # type: ignore[name-defined]

    return typing.get_type_hints(X)['x']

print(_get_generic_type_hints(B))
print(_get_generic_type_hints(A[int]))

>=========================
<class '__main__.B'>
__main__.A[int]

Collaborator

@dairiki dairiki Jun 27, 2024

In your last examples, _get_generic_type_hints is just the identity, right?

>>> _get_generic_type_hints(B) is B
True
>>> _get_generic_type_hints(A[int]) is A[int]
True

But, the value of _get_generic_type_hints comes in resolving forward references:

>>> _get_generic_type_hints(A["int"])
__main__.A[int]
>>> _get_generic_type_hints(A["int"]) is A[int]
False
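
A self-contained sketch of the behaviour described above: wrapping a single hint in a throwaway class lets `typing.get_type_hints` resolve forward references inside a generic alias. The helper name `resolve_single_hint` is illustrative only and is not the PR's function.

```python
import typing

T = typing.TypeVar("T")


class A(typing.Generic[T]):
    a: T


def resolve_single_hint(hint):
    """Illustrative helper: resolve forward references in one type hint."""

    class _Holder:
        # The hint is attached to a throwaway class so that
        # typing.get_type_hints has a class to inspect.
        x: hint  # type: ignore[valid-type]

    return typing.get_type_hints(_Holder)["x"]


# No forward references: the result compares equal to the input.
print(resolve_single_hint(A[int]) == A[int])

# A["int"] stores a ForwardRef; after resolution the argument is the int class.
resolved = resolve_single_hint(A["int"])
print(typing.get_args(A["int"]), typing.get_args(resolved))
```

The sketch compares with `==` rather than `is`, since, as noted above, the resolved alias is not the same object as `A[int]`.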

Collaborator

@dairiki dairiki Jun 27, 2024

Maybe what is really wanted is (untested)

def _get_generic_type_hints(generic_alias):
    class X(generic_alias): pass

    return typing.get_type_hints(X)

Or, in one line:

    return typing.get_type_hints(types.new_class("_", bases=(generic_alias,)))

That version of _get_generic_type_hints then matches the signature of typing.get_type_hints. As things stand, it does not.

Contributor Author

Wow, good catch! You're absolutely right.

  1. This just worked 🤦‍♂️ I clearly spent too long trying to make it work to see the obvious solution:
(
    type_hints[field.name]
    if not is_generic_type(clazz)
    else field.type
)
  2. But it didn't work for A["int"], and no tests caught that yet. So I'll add a test and rename the function as per your recommendation.

Collaborator

And, if that works, maybe just roll this into _get_type_hints so that there's a single _get_type_hints function that can be used for any dataclass, be it generic or not.

Contributor Author

def _get_generic_type_hints(generic_alias):
    class X(generic_alias): pass

    return typing.get_type_hints(X)

Does not work.

print(_get_generic_type_hints(A[int]))

>=====================
{'a': ~T}

Instead of {'a': __main__.A[int]}
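
For illustration only, one way to get concrete field types from a generic alias is to pair the origin class's `__parameters__` with the alias's arguments and substitute. This sketch is not the PR's `generic_resolver`; the helper name `substituted_hints` is hypothetical, and it only handles hints that are themselves a bare TypeVar (nested cases such as `List[T]` are out of scope here).

```python
import typing

T = typing.TypeVar("T")


class A(typing.Generic[T]):
    a: T
    b: str


def substituted_hints(alias):
    """Hypothetical helper: replace top-level TypeVars with the alias's arguments."""
    origin = typing.get_origin(alias)  # A[int] -> A
    # Map each declared type parameter to the argument supplied in the alias.
    mapping = dict(zip(origin.__parameters__, typing.get_args(alias)))
    hints = typing.get_type_hints(origin)  # still contains the unbound TypeVar
    return {
        name: mapping.get(hint, hint) if isinstance(hint, typing.TypeVar) else hint
        for name, hint in hints.items()
    }


print(typing.get_type_hints(A))
print(substituted_hints(A[int]))
```

The first print shows the unresolved TypeVar, as in the output above; the second shows the same hints after substitution.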

Collaborator

Does not work.

That's too bad.

I still think that refactoring so that _get_type_hints will work with generic aliases of dataclasses as well as plain dataclasses is probably worth it. It will move all the is_generic_type special-casing out of

type_hints = {}
if not is_generic_type(clazz):
type_hints = _get_type_hints(clazz, schema_ctx)
attributes.update(
(
field.name,
_field_for_schema(
(
type_hints[field.name]
if not is_generic_type(clazz)
else _get_generic_type_hints(field.type, schema_ctx)
),
_get_field_default(field),
field.metadata,
base_schema,
),
)
for field in fields
if field.init or include_non_init
)

Contributor Author

@mvanderlee mvanderlee Jun 27, 2024

I don't disagree, but I've already sunk too many hours into this particular problem. If it is even possible to do, it'll have to be done by someone else, as I don't have the knowledge to take this further.

Contributor Author

@dairiki I found a way around it by resolving the forward refs when we get the dataclass fields. We no longer call get_type_hints at all, as its relevant code is internalized.

We were already looping over the mro and fields just like get_type_hints.


class X:
x: obj # type: ignore[name-defined]

return _get_type_hints(X, schema_ctx)["x"]


def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
if not is_generic_type(clazz):
return dataclasses.fields(clazz)

else:
return get_generic_dataclass_fields(clazz)


def _is_marshmallow_field(obj) -> bool:
return (
inspect.isclass(obj) and issubclass(obj, marshmallow.fields.Field)
) or isinstance(obj, marshmallow.fields.Field)


def NewType(
name: str,
typ: Type[_U],