Skip to content

Commit 4531c35

Browse files
committed
Break generic functions out into it's own file and add support for annotated generics, partials, and callables
1 parent 80dab91 commit 4531c35

File tree

4 files changed

+392
-185
lines changed

4 files changed

+392
-185
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 41 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class User:
3636
"""
3737

3838
import collections.abc
39-
import copy
4039
import dataclasses
4140
import inspect
4241
import sys
@@ -64,6 +63,12 @@ class User:
6463
import typing_extensions
6564
import typing_inspect
6665

66+
from marshmallow_dataclass.generic_resolver import (
67+
UnboundTypeVarError,
68+
get_generic_dataclass_fields,
69+
is_generic_alias,
70+
is_generic_type,
71+
)
6772
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute
6873

6974
if sys.version_info >= (3, 9):
@@ -134,55 +139,10 @@ def _maybe_get_callers_frame(
134139
del frame
135140

136141

137-
class UnboundTypeVarError(TypeError):
138-
"""TypeVar instance can not be resolved to a type spec.
139-
140-
This exception is raised when an unbound TypeVar is encountered.
141-
142-
"""
143-
144-
145-
class InvalidStateError(Exception):
146-
"""Raised when an operation is performed on a future that is not
147-
allowed in the current state.
148-
"""
149-
150-
151-
class _Future(Generic[_U]):
152-
"""The _Future class allows deferred access to a result that is not
153-
yet available.
154-
"""
155-
156-
_done: bool
157-
_result: _U
158-
159-
def __init__(self) -> None:
160-
self._done = False
161-
162-
def done(self) -> bool:
163-
"""Return ``True`` if the value is available"""
164-
return self._done
165-
166-
def result(self) -> _U:
167-
"""Return the deferred value.
168-
169-
Raises ``InvalidStateError`` if the value has not been set.
170-
"""
171-
if self.done():
172-
return self._result
173-
raise InvalidStateError("result has not been set")
174-
175-
def set_result(self, result: _U) -> None:
176-
if self.done():
177-
raise InvalidStateError("result has already been set")
178-
self._result = result
179-
self._done = True
180-
181-
182142
def _check_decorated_type(cls: object) -> None:
183143
if not isinstance(cls, type):
184144
raise TypeError(f"expected a class not {cls!r}")
185-
if _is_generic_alias(cls):
145+
if is_generic_alias(cls):
186146
# A .Schema attribute doesn't make sense on a generic alias — there's
187147
# no way for it to know the generic parameters at run time.
188148
raise TypeError(
@@ -513,9 +473,7 @@ def class_schema(
513473
>>> class_schema(Custom)().load({})
514474
Custom(name=None)
515475
"""
516-
if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass(
517-
clazz
518-
):
476+
if not dataclasses.is_dataclass(clazz) and not is_generic_alias_of_dataclass(clazz):
519477
clazz = dataclasses.dataclass(clazz)
520478
if localns is None:
521479
if clazz_frame is None:
@@ -791,8 +749,16 @@ def _field_for_annotated_type(
791749
marshmallow_annotations = [
792750
arg
793751
for arg in arguments[1:]
794-
if (inspect.isclass(arg) and issubclass(arg, marshmallow.fields.Field))
795-
or isinstance(arg, marshmallow.fields.Field)
752+
if _is_marshmallow_field(arg)
753+
# Support `CustomGenericField[mf.String]`
754+
or (
755+
is_generic_type(arg)
756+
and _is_marshmallow_field(typing_extensions.get_origin(arg))
757+
)
758+
# Support `partial(mf.List, mf.String)`
759+
or (isinstance(arg, partial) and _is_marshmallow_field(arg.func))
760+
# Support `lambda *args, **kwargs: mf.List(mf.String, *args, **kwargs)`
761+
or (_is_callable_marshmallow_field(arg))
796762
]
797763
if marshmallow_annotations:
798764
if len(marshmallow_annotations) > 1:
@@ -932,7 +898,7 @@ def _field_for_schema(
932898

933899
# i.e.: Literal['abc']
934900
if typing_inspect.is_literal_type(typ):
935-
arguments = typing_inspect.get_args(typ)
901+
arguments = typing_extensions.get_args(typ)
936902
return marshmallow.fields.Raw(
937903
validate=(
938904
marshmallow.validate.Equal(arguments[0])
@@ -944,7 +910,7 @@ def _field_for_schema(
944910

945911
# i.e.: Final[str] = 'abc'
946912
if typing_inspect.is_final_type(typ):
947-
arguments = typing_inspect.get_args(typ)
913+
arguments = typing_extensions.get_args(typ)
948914
if arguments:
949915
subtyp = arguments[0]
950916
elif default is not marshmallow.missing:
@@ -1061,14 +1027,14 @@ def _get_field_default(field: dataclasses.Field):
10611027
return field.default
10621028

10631029

1064-
def _is_generic_alias_of_dataclass(clazz: type) -> bool:
1030+
def is_generic_alias_of_dataclass(clazz: type) -> bool:
10651031
"""
10661032
Check if given class is a generic alias of a dataclass, if the dataclass is
10671033
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
10681034
"""
10691035
is_generic = is_generic_type(clazz)
1070-
type_arguments = typing_inspect.get_args(clazz)
1071-
origin_class = typing_inspect.get_origin(clazz)
1036+
type_arguments = typing_extensions.get_args(clazz)
1037+
origin_class = typing_extensions.get_origin(clazz)
10721038
return (
10731039
is_generic
10741040
and len(type_arguments) > 0
@@ -1107,136 +1073,30 @@ class X:
11071073
return _get_type_hints(X, schema_ctx)["x"]
11081074

11091075

1110-
def _is_generic_alias(clazz: type) -> bool:
1111-
"""
1112-
Check if given class is a generic alias of a class is
1113-
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1114-
"""
1115-
is_generic = is_generic_type(clazz)
1116-
type_arguments = typing_inspect.get_args(clazz)
1117-
return is_generic and len(type_arguments) > 0
1118-
1119-
1120-
def is_generic_type(clazz: type) -> bool:
1121-
"""
1122-
typing_inspect.is_generic_type explicitly ignores Union, Tuple, Callable, ClassVar
1123-
"""
1124-
return (
1125-
isinstance(clazz, type)
1126-
and issubclass(clazz, Generic) # type: ignore[arg-type]
1127-
or isinstance(clazz, typing_inspect.typingGenericAlias)
1128-
)
1129-
1130-
1131-
def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
1132-
"""
1133-
Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics.
1134-
1135-
Returns a dict of each base class and the resolved generics.
1136-
"""
1137-
# Use Tuples so can zip (order matters)
1138-
args_by_class: Dict[type, Tuple[Tuple[TypeVar, _Future], ...]] = {}
1139-
parent_class: Optional[type] = None
1140-
# Loop in reversed order and iteratively resolve types
1141-
for subclass in reversed(clazz.mro()):
1142-
if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type]
1143-
args = typing_inspect.get_args(subclass.__orig_bases__[0])
1144-
1145-
if parent_class and args_by_class.get(parent_class):
1146-
subclass_generic_params_to_args: List[Tuple[TypeVar, _Future]] = []
1147-
for (_arg, future), potential_type in zip(
1148-
args_by_class[parent_class], args
1149-
):
1150-
if isinstance(potential_type, TypeVar):
1151-
subclass_generic_params_to_args.append((potential_type, future))
1152-
else:
1153-
future.set_result(potential_type)
1154-
1155-
args_by_class[subclass] = tuple(subclass_generic_params_to_args)
1156-
1157-
else:
1158-
args_by_class[subclass] = tuple((arg, _Future()) for arg in args)
1159-
1160-
parent_class = subclass
1161-
1162-
# clazz itself is a generic alias i.e.: A[int]. So it hold the last types.
1163-
if _is_generic_alias(clazz):
1164-
origin = typing_inspect.get_origin(clazz)
1165-
args = typing_inspect.get_args(clazz)
1166-
for (_arg, future), potential_type in zip(args_by_class[origin], args):
1167-
if not isinstance(potential_type, TypeVar):
1168-
future.set_result(potential_type)
1169-
1170-
# Convert to nested dict for easier lookup
1171-
return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()}
1172-
1173-
1174-
def _replace_typevars(
1175-
clazz: type, resolved_generics: Optional[Dict[TypeVar, _Future]] = None
1176-
) -> type:
1177-
if not resolved_generics or inspect.isclass(clazz) or not is_generic_type(clazz):
1178-
return clazz
1179-
1180-
return clazz.copy_with( # type: ignore
1181-
tuple(
1182-
(
1183-
_replace_typevars(arg, resolved_generics)
1184-
if is_generic_type(arg)
1185-
else (
1186-
resolved_generics[arg].result() if arg in resolved_generics else arg
1187-
)
1188-
)
1189-
for arg in typing_inspect.get_args(clazz)
1190-
)
1191-
)
1192-
1193-
11941076
def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
11951077
if not is_generic_type(clazz):
11961078
return dataclasses.fields(clazz)
11971079

11981080
else:
1199-
unbound_fields = set()
1200-
# Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and
1201-
# looses the source class. Thus I don't know how to resolve this at later on.
1202-
# Instead we recreate the type but with all known TypeVars resolved to their actual types.
1203-
resolved_typevars = _resolve_typevars(clazz)
1204-
# Dict[field_name, Tuple[original_field, resolved_field]]
1205-
fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {}
1206-
1207-
for subclass in reversed(clazz.mro()):
1208-
if not dataclasses.is_dataclass(subclass):
1209-
continue
1210-
1211-
for field in dataclasses.fields(subclass):
1212-
try:
1213-
if field.name in fields and fields[field.name][0] == field:
1214-
continue # identical, so already resolved.
1215-
1216-
# Either the first time we see this field, or it got overridden
1217-
# If it's a class we handle it later as a Nested. Nothing to resolve now.
1218-
new_field = field
1219-
if not inspect.isclass(field.type) and is_generic_type(field.type):
1220-
new_field = copy.copy(field)
1221-
new_field.type = _replace_typevars(
1222-
field.type, resolved_typevars[subclass]
1223-
)
1224-
elif isinstance(field.type, TypeVar):
1225-
new_field = copy.copy(field)
1226-
new_field.type = resolved_typevars[subclass][
1227-
field.type
1228-
].result()
1229-
1230-
fields[field.name] = (field, new_field)
1231-
except InvalidStateError:
1232-
unbound_fields.add(field.name)
1233-
1234-
if unbound_fields:
1235-
raise UnboundTypeVarError(
1236-
f"{clazz.__name__} has unbound fields: {', '.join(unbound_fields)}"
1237-
)
1081+
return get_generic_dataclass_fields(clazz)
1082+
1083+
1084+
def _is_marshmallow_field(obj) -> bool:
1085+
return (
1086+
inspect.isclass(obj) and issubclass(obj, marshmallow.fields.Field)
1087+
) or isinstance(obj, marshmallow.fields.Field)
1088+
1089+
1090+
def _is_callable_marshmallow_field(obj) -> bool:
1091+
"""Checks if the object is a callable and if the callable returns a marshmallow field"""
1092+
if callable(obj):
1093+
try:
1094+
potential_field = obj()
1095+
return _is_marshmallow_field(potential_field)
1096+
except Exception:
1097+
return False
12381098

1239-
return tuple(v[1] for v in fields.values())
1099+
return False
12401100

12411101

12421102
def NewType(

0 commit comments

Comments
 (0)