Skip to content

Commit 4ef0bdf

Browse files
committed
Remove the need to call get_type_hints
I don't want to loop over the fields multiple times so internalized the relevant code from typing.get_type_hints into the generic_resolver
1 parent 8797b2b commit 4ef0bdf

File tree

2 files changed

+87
-61
lines changed

2 files changed

+87
-61
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 5 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ class User:
6161
TypeVar,
6262
Union,
6363
cast,
64-
get_type_hints,
6564
overload,
6665
)
6766

@@ -70,7 +69,7 @@ class User:
7069

7170
from marshmallow_dataclass.generic_resolver import (
7271
UnboundTypeVarError,
73-
get_generic_dataclass_fields,
72+
get_resolved_dataclass_fields,
7473
is_generic_alias,
7574
)
7675
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute
@@ -548,7 +547,9 @@ def _internal_class_schema(
548547
schema_ctx.seen_classes[clazz] = class_name
549548

550549
try:
551-
fields = _dataclass_fields(clazz)
550+
fields = get_resolved_dataclass_fields(
551+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
552+
)
552553
except UnboundTypeVarError:
553554
raise
554555
except TypeError: # Not a dataclass
@@ -584,19 +585,11 @@ def _internal_class_schema(
584585
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)
585586

586587
# Update the schema members to contain marshmallow fields instead of dataclass fields
587-
type_hints = {}
588-
if not typing_inspect.is_generic_type(clazz):
589-
type_hints = _get_type_hints(clazz, schema_ctx)
590-
591588
attributes.update(
592589
(
593590
field.name,
594591
_field_for_schema(
595-
(
596-
type_hints[field.name]
597-
if not typing_inspect.is_generic_type(clazz)
598-
else _resolve_forward_type_refs(field.type, schema_ctx)
599-
),
592+
field.type,
600593
_get_field_default(field),
601594
field.metadata,
602595
base_schema,
@@ -1037,47 +1030,6 @@ def is_generic_alias_of_dataclass(clazz: type) -> bool:
10371030
return is_generic_alias(clazz) and dataclasses.is_dataclass(get_origin(clazz))
10381031

10391032

1040-
def _get_type_hints(
1041-
obj,
1042-
schema_ctx: _SchemaContext,
1043-
):
1044-
if sys.version_info >= (3, 9):
1045-
type_hints = get_type_hints(
1046-
obj,
1047-
globalns=schema_ctx.globalns,
1048-
localns=schema_ctx.localns,
1049-
include_extras=True,
1050-
)
1051-
else:
1052-
type_hints = get_type_hints(
1053-
obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns
1054-
)
1055-
1056-
return type_hints
1057-
1058-
1059-
def _resolve_forward_type_refs(
1060-
obj,
1061-
schema_ctx: _SchemaContext,
1062-
) -> type:
1063-
"""
1064-
Resolve forward references, mainly applies to Generics i.e.: `A["int"]` -> `A[int]`
1065-
"""
1066-
1067-
class X:
1068-
x: obj # type: ignore[name-defined]
1069-
1070-
return _get_type_hints(X, schema_ctx)["x"]
1071-
1072-
1073-
def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
1074-
if not typing_inspect.is_generic_type(clazz):
1075-
return dataclasses.fields(clazz)
1076-
1077-
else:
1078-
return get_generic_dataclass_fields(clazz)
1079-
1080-
10811033
def _is_marshmallow_field(obj) -> bool:
10821034
return (
10831035
inspect.isclass(obj) and issubclass(obj, marshmallow.fields.Field)

marshmallow_dataclass/generic_resolver.py

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import inspect
44
import sys
55
from typing import (
6+
Any,
67
Dict,
8+
ForwardRef,
79
Generic,
810
List,
911
Optional,
@@ -15,9 +17,17 @@
1517

1618
if sys.version_info >= (3, 9):
1719
from typing import Annotated, get_args, get_origin
20+
21+
def eval_forward_ref(t: ForwardRef, globalns, localns, recursive_guard=frozenset()):
22+
return t._evaluate(globalns, localns, recursive_guard)
23+
1824
else:
1925
from typing_extensions import Annotated, get_args, get_origin
2026

27+
def eval_forward_ref(t: ForwardRef, globalns, localns):
28+
return t._evaluate(globalns, localns)
29+
30+
2131
_U = TypeVar("_U")
2232

2333

@@ -99,7 +109,35 @@ def may_contain_typevars(clazz: type) -> bool:
99109
)
100110

101111

102-
def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
112+
def _get_namespaces(
113+
clazz: type,
114+
globalns: Optional[Dict[str, Any]] = None,
115+
localns: Optional[Dict[str, Any]] = None,
116+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
117+
# region - Copied from typing.get_type_hints
118+
if globalns is None:
119+
base_globals = getattr(sys.modules.get(clazz.__module__, None), "__dict__", {})
120+
else:
121+
base_globals = globalns
122+
base_locals = dict(vars(clazz)) if localns is None else localns
123+
if localns is None and globalns is None:
124+
# This is surprising, but required. Before Python 3.10,
125+
# get_type_hints only evaluated the globalns of
126+
# a class. To maintain backwards compatibility, we reverse
127+
# the globalns and localns order so that eval() looks into
128+
# *base_globals* first rather than *base_locals*.
129+
# This only affects ForwardRefs.
130+
base_globals, base_locals = base_locals, base_globals
131+
# endregion - Copied from typing.get_type_hints
132+
133+
return base_globals, base_locals
134+
135+
136+
def _resolve_typevars(
137+
clazz: type,
138+
globalns: Optional[Dict[str, Any]] = None,
139+
localns: Optional[Dict[str, Any]] = None,
140+
) -> Dict[type, Dict[TypeVar, _Future]]:
103141
"""
104142
Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics.
105143
@@ -110,6 +148,7 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
110148
parent_class: Optional[type] = None
111149
# Loop in reversed order and iteratively resolve types
112150
for subclass in reversed(clazz.mro()):
151+
base_globals, base_locals = _get_namespaces(subclass, globalns, localns)
113152
if issubclass(subclass, Generic) and hasattr(subclass, "__orig_bases__"): # type: ignore[arg-type]
114153
args = get_args(subclass.__orig_bases__[0])
115154

@@ -121,10 +160,17 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
121160
if isinstance(potential_type, TypeVar):
122161
subclass_generic_params_to_args.append((potential_type, future))
123162
else:
124-
future.set_result(potential_type)
163+
future.set_result(
164+
eval_forward_ref(
165+
potential_type,
166+
globalns=base_globals,
167+
localns=base_locals,
168+
)
169+
if isinstance(potential_type, ForwardRef)
170+
else potential_type
171+
)
125172

126173
args_by_class[subclass] = tuple(subclass_generic_params_to_args)
127-
128174
else:
129175
args_by_class[subclass] = tuple((arg, _Future()) for arg in args)
130176

@@ -136,7 +182,11 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
136182
args = get_args(clazz)
137183
for (_arg, future), potential_type in zip(args_by_class[origin], args): # type: ignore[index]
138184
if not isinstance(potential_type, TypeVar):
139-
future.set_result(potential_type)
185+
future.set_result(
186+
eval_forward_ref(potential_type, globalns=globalns, localns=localns)
187+
if isinstance(potential_type, ForwardRef)
188+
else potential_type
189+
)
140190

141191
# Convert to nested dict for easier lookup
142192
return {k: {typ: fut for typ, fut in args} for k, args in args_by_class.items()}
@@ -166,12 +216,16 @@ def _replace_typevars(
166216
)
167217

168218

169-
def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
219+
def get_resolved_dataclass_fields(
220+
clazz: type,
221+
globalns: Optional[Dict[str, Any]] = None,
222+
localns: Optional[Dict[str, Any]] = None,
223+
) -> Tuple[dataclasses.Field, ...]:
170224
unbound_fields = set()
171225
# Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and
172226
# looses the source class. Thus I don't know how to resolve this at later on.
173227
# Instead we recreate the type but with all known TypeVars resolved to their actual types.
174-
resolved_typevars = _resolve_typevars(clazz)
228+
resolved_typevars = _resolve_typevars(clazz, globalns=globalns, localns=localns)
175229
# Dict[field_name, Tuple[original_field, resolved_field]]
176230
fields: Dict[str, Tuple[dataclasses.Field, dataclasses.Field]] = {}
177231

@@ -190,14 +244,34 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
190244
if not inspect.isclass(field.type) and may_contain_typevars(field.type):
191245
new_field = copy.copy(field)
192246
new_field.type = _replace_typevars(
193-
field.type, resolved_typevars[subclass]
247+
field.type, resolved_typevars.get(subclass)
194248
)
195249
elif isinstance(field.type, TypeVar):
196250
new_field = copy.copy(field)
197251
new_field.type = resolved_typevars[subclass][field.type].result()
252+
elif isinstance(field.type, ForwardRef):
253+
base_globals, base_locals = _get_namespaces(
254+
subclass, globalns, localns
255+
)
256+
new_field = copy.copy(field)
257+
new_field.type = eval_forward_ref(
258+
field.type, globalns=base_globals, localns=base_locals
259+
)
260+
elif isinstance(field.type, str):
261+
base_globals, base_locals = _get_namespaces(
262+
subclass, globalns, localns
263+
)
264+
new_field = copy.copy(field)
265+
new_field.type = eval_forward_ref(
266+
ForwardRef(field.type, is_argument=False, is_class=True)
267+
if sys.version_info >= (3, 9)
268+
else ForwardRef(field.type, is_argument=False),
269+
globalns=base_globals,
270+
localns=base_locals,
271+
)
198272

199273
fields[field.name] = (field, new_field)
200-
except InvalidStateError:
274+
except (InvalidStateError, KeyError):
201275
unbound_fields.add(field.name)
202276

203277
if unbound_fields:

0 commit comments

Comments
 (0)