@@ -36,7 +36,6 @@ class User:
36
36
"""
37
37
38
38
import collections .abc
39
- import copy
40
39
import dataclasses
41
40
import inspect
42
41
import sys
@@ -64,6 +63,12 @@ class User:
64
63
import typing_extensions
65
64
import typing_inspect
66
65
66
+ from marshmallow_dataclass .generic_resolver import (
67
+ UnboundTypeVarError ,
68
+ get_generic_dataclass_fields ,
69
+ is_generic_alias ,
70
+ is_generic_type ,
71
+ )
67
72
from marshmallow_dataclass .lazy_class_attribute import lazy_class_attribute
68
73
69
74
if sys .version_info >= (3 , 9 ):
@@ -134,55 +139,10 @@ def _maybe_get_callers_frame(
134
139
del frame
135
140
136
141
137
- class UnboundTypeVarError (TypeError ):
138
- """TypeVar instance can not be resolved to a type spec.
139
-
140
- This exception is raised when an unbound TypeVar is encountered.
141
-
142
- """
143
-
144
-
145
- class InvalidStateError (Exception ):
146
- """Raised when an operation is performed on a future that is not
147
- allowed in the current state.
148
- """
149
-
150
-
151
- class _Future (Generic [_U ]):
152
- """The _Future class allows deferred access to a result that is not
153
- yet available.
154
- """
155
-
156
- _done : bool
157
- _result : _U
158
-
159
- def __init__ (self ) -> None :
160
- self ._done = False
161
-
162
- def done (self ) -> bool :
163
- """Return ``True`` if the value is available"""
164
- return self ._done
165
-
166
- def result (self ) -> _U :
167
- """Return the deferred value.
168
-
169
- Raises ``InvalidStateError`` if the value has not been set.
170
- """
171
- if self .done ():
172
- return self ._result
173
- raise InvalidStateError ("result has not been set" )
174
-
175
- def set_result (self , result : _U ) -> None :
176
- if self .done ():
177
- raise InvalidStateError ("result has already been set" )
178
- self ._result = result
179
- self ._done = True
180
-
181
-
182
142
def _check_decorated_type (cls : object ) -> None :
183
143
if not isinstance (cls , type ):
184
144
raise TypeError (f"expected a class not { cls !r} " )
185
- if _is_generic_alias (cls ):
145
+ if is_generic_alias (cls ):
186
146
# A .Schema attribute doesn't make sense on a generic alias — there's
187
147
# no way for it to know the generic parameters at run time.
188
148
raise TypeError (
@@ -513,9 +473,7 @@ def class_schema(
513
473
>>> class_schema(Custom)().load({})
514
474
Custom(name=None)
515
475
"""
516
- if not dataclasses .is_dataclass (clazz ) and not _is_generic_alias_of_dataclass (
517
- clazz
518
- ):
476
+ if not dataclasses .is_dataclass (clazz ) and not is_generic_alias_of_dataclass (clazz ):
519
477
clazz = dataclasses .dataclass (clazz )
520
478
if localns is None :
521
479
if clazz_frame is None :
@@ -791,8 +749,16 @@ def _field_for_annotated_type(
791
749
marshmallow_annotations = [
792
750
arg
793
751
for arg in arguments [1 :]
794
- if (inspect .isclass (arg ) and issubclass (arg , marshmallow .fields .Field ))
795
- or isinstance (arg , marshmallow .fields .Field )
752
+ if _is_marshmallow_field (arg )
753
+ # Support `CustomGenericField[mf.String]`
754
+ or (
755
+ is_generic_type (arg )
756
+ and _is_marshmallow_field (typing_extensions .get_origin (arg ))
757
+ )
758
+ # Support `partial(mf.List, mf.String)`
759
+ or (isinstance (arg , partial ) and _is_marshmallow_field (arg .func ))
760
+ # Support `lambda *args, **kwargs: mf.List(mf.String, *args, **kwargs)`
761
+ or (_is_callable_marshmallow_field (arg ))
796
762
]
797
763
if marshmallow_annotations :
798
764
if len (marshmallow_annotations ) > 1 :
@@ -932,7 +898,7 @@ def _field_for_schema(
932
898
933
899
# i.e.: Literal['abc']
934
900
if typing_inspect .is_literal_type (typ ):
935
- arguments = typing_inspect .get_args (typ )
901
+ arguments = typing_extensions .get_args (typ )
936
902
return marshmallow .fields .Raw (
937
903
validate = (
938
904
marshmallow .validate .Equal (arguments [0 ])
@@ -944,7 +910,7 @@ def _field_for_schema(
944
910
945
911
# i.e.: Final[str] = 'abc'
946
912
if typing_inspect .is_final_type (typ ):
947
- arguments = typing_inspect .get_args (typ )
913
+ arguments = typing_extensions .get_args (typ )
948
914
if arguments :
949
915
subtyp = arguments [0 ]
950
916
elif default is not marshmallow .missing :
@@ -1061,14 +1027,14 @@ def _get_field_default(field: dataclasses.Field):
1061
1027
return field .default
1062
1028
1063
1029
1064
- def _is_generic_alias_of_dataclass (clazz : type ) -> bool :
1030
+ def is_generic_alias_of_dataclass (clazz : type ) -> bool :
1065
1031
"""
1066
1032
Check if given class is a generic alias of a dataclass, if the dataclass is
1067
1033
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1068
1034
"""
1069
1035
is_generic = is_generic_type (clazz )
1070
- type_arguments = typing_inspect .get_args (clazz )
1071
- origin_class = typing_inspect .get_origin (clazz )
1036
+ type_arguments = typing_extensions .get_args (clazz )
1037
+ origin_class = typing_extensions .get_origin (clazz )
1072
1038
return (
1073
1039
is_generic
1074
1040
and len (type_arguments ) > 0
@@ -1107,136 +1073,30 @@ class X:
1107
1073
return _get_type_hints (X , schema_ctx )["x" ]
1108
1074
1109
1075
1110
- def _is_generic_alias (clazz : type ) -> bool :
1111
- """
1112
- Check if given class is a generic alias of a class is
1113
- defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1114
- """
1115
- is_generic = is_generic_type (clazz )
1116
- type_arguments = typing_inspect .get_args (clazz )
1117
- return is_generic and len (type_arguments ) > 0
1118
-
1119
-
1120
- def is_generic_type (clazz : type ) -> bool :
1121
- """
1122
- typing_inspect.is_generic_type explicitly ignores Union, Tuple, Callable, ClassVar
1123
- """
1124
- return (
1125
- isinstance (clazz , type )
1126
- and issubclass (clazz , Generic ) # type: ignore[arg-type]
1127
- or isinstance (clazz , typing_inspect .typingGenericAlias )
1128
- )
1129
-
1130
-
1131
- def _resolve_typevars (clazz : type ) -> Dict [type , Dict [TypeVar , _Future ]]:
1132
- """
1133
- Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics.
1134
-
1135
- Returns a dict of each base class and the resolved generics.
1136
- """
1137
- # Use Tuples so can zip (order matters)
1138
- args_by_class : Dict [type , Tuple [Tuple [TypeVar , _Future ], ...]] = {}
1139
- parent_class : Optional [type ] = None
1140
- # Loop in reversed order and iteratively resolve types
1141
- for subclass in reversed (clazz .mro ()):
1142
- if issubclass (subclass , Generic ) and hasattr (subclass , "__orig_bases__" ): # type: ignore[arg-type]
1143
- args = typing_inspect .get_args (subclass .__orig_bases__ [0 ])
1144
-
1145
- if parent_class and args_by_class .get (parent_class ):
1146
- subclass_generic_params_to_args : List [Tuple [TypeVar , _Future ]] = []
1147
- for (_arg , future ), potential_type in zip (
1148
- args_by_class [parent_class ], args
1149
- ):
1150
- if isinstance (potential_type , TypeVar ):
1151
- subclass_generic_params_to_args .append ((potential_type , future ))
1152
- else :
1153
- future .set_result (potential_type )
1154
-
1155
- args_by_class [subclass ] = tuple (subclass_generic_params_to_args )
1156
-
1157
- else :
1158
- args_by_class [subclass ] = tuple ((arg , _Future ()) for arg in args )
1159
-
1160
- parent_class = subclass
1161
-
1162
- # clazz itself is a generic alias i.e.: A[int]. So it hold the last types.
1163
- if _is_generic_alias (clazz ):
1164
- origin = typing_inspect .get_origin (clazz )
1165
- args = typing_inspect .get_args (clazz )
1166
- for (_arg , future ), potential_type in zip (args_by_class [origin ], args ):
1167
- if not isinstance (potential_type , TypeVar ):
1168
- future .set_result (potential_type )
1169
-
1170
- # Convert to nested dict for easier lookup
1171
- return {k : {typ : fut for typ , fut in args } for k , args in args_by_class .items ()}
1172
-
1173
-
1174
- def _replace_typevars (
1175
- clazz : type , resolved_generics : Optional [Dict [TypeVar , _Future ]] = None
1176
- ) -> type :
1177
- if not resolved_generics or inspect .isclass (clazz ) or not is_generic_type (clazz ):
1178
- return clazz
1179
-
1180
- return clazz .copy_with ( # type: ignore
1181
- tuple (
1182
- (
1183
- _replace_typevars (arg , resolved_generics )
1184
- if is_generic_type (arg )
1185
- else (
1186
- resolved_generics [arg ].result () if arg in resolved_generics else arg
1187
- )
1188
- )
1189
- for arg in typing_inspect .get_args (clazz )
1190
- )
1191
- )
1192
-
1193
-
1194
1076
def _dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
1195
1077
if not is_generic_type (clazz ):
1196
1078
return dataclasses .fields (clazz )
1197
1079
1198
1080
else :
1199
- unbound_fields = set ()
1200
- # Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and
1201
- # looses the source class. Thus I don't know how to resolve this at later on.
1202
- # Instead we recreate the type but with all known TypeVars resolved to their actual types.
1203
- resolved_typevars = _resolve_typevars (clazz )
1204
- # Dict[field_name, Tuple[original_field, resolved_field]]
1205
- fields : Dict [str , Tuple [dataclasses .Field , dataclasses .Field ]] = {}
1206
-
1207
- for subclass in reversed (clazz .mro ()):
1208
- if not dataclasses .is_dataclass (subclass ):
1209
- continue
1210
-
1211
- for field in dataclasses .fields (subclass ):
1212
- try :
1213
- if field .name in fields and fields [field .name ][0 ] == field :
1214
- continue # identical, so already resolved.
1215
-
1216
- # Either the first time we see this field, or it got overridden
1217
- # If it's a class we handle it later as a Nested. Nothing to resolve now.
1218
- new_field = field
1219
- if not inspect .isclass (field .type ) and is_generic_type (field .type ):
1220
- new_field = copy .copy (field )
1221
- new_field .type = _replace_typevars (
1222
- field .type , resolved_typevars [subclass ]
1223
- )
1224
- elif isinstance (field .type , TypeVar ):
1225
- new_field = copy .copy (field )
1226
- new_field .type = resolved_typevars [subclass ][
1227
- field .type
1228
- ].result ()
1229
-
1230
- fields [field .name ] = (field , new_field )
1231
- except InvalidStateError :
1232
- unbound_fields .add (field .name )
1233
-
1234
- if unbound_fields :
1235
- raise UnboundTypeVarError (
1236
- f"{ clazz .__name__ } has unbound fields: { ', ' .join (unbound_fields )} "
1237
- )
1081
+ return get_generic_dataclass_fields (clazz )
1082
+
1083
+
1084
+ def _is_marshmallow_field (obj ) -> bool :
1085
+ return (
1086
+ inspect .isclass (obj ) and issubclass (obj , marshmallow .fields .Field )
1087
+ ) or isinstance (obj , marshmallow .fields .Field )
1088
+
1089
+
1090
+ def _is_callable_marshmallow_field (obj ) -> bool :
1091
+ """Checks if the object is a callable and if the callable returns a marshmallow field"""
1092
+ if callable (obj ):
1093
+ try :
1094
+ potential_field = obj ()
1095
+ return _is_marshmallow_field (potential_field )
1096
+ except Exception :
1097
+ return False
1238
1098
1239
- return tuple ( v [ 1 ] for v in fields . values ())
1099
+ return False
1240
1100
1241
1101
1242
1102
def NewType (
0 commit comments