3
3
import inspect
4
4
import sys
5
5
from typing import (
6
+ Any ,
6
7
Dict ,
8
+ ForwardRef ,
7
9
Generic ,
8
10
List ,
9
11
Optional ,
15
17
16
18
if sys .version_info >= (3 , 9 ):
17
19
from typing import Annotated , get_args , get_origin
20
+
21
+ def eval_forward_ref (t : ForwardRef , globalns , localns , recursive_guard = frozenset ()):
22
+ return t ._evaluate (globalns , localns , recursive_guard )
23
+
18
24
else :
19
25
from typing_extensions import Annotated , get_args , get_origin
20
26
27
+ def eval_forward_ref (t : ForwardRef , globalns , localns ):
28
+ return t ._evaluate (globalns , localns )
29
+
30
+
21
31
_U = TypeVar ("_U" )
22
32
23
33
@@ -99,7 +109,35 @@ def may_contain_typevars(clazz: type) -> bool:
99
109
)
100
110
101
111
102
- def _resolve_typevars (clazz : type ) -> Dict [type , Dict [TypeVar , _Future ]]:
112
+ def _get_namespaces (
113
+ clazz : type ,
114
+ globalns : Optional [Dict [str , Any ]] = None ,
115
+ localns : Optional [Dict [str , Any ]] = None ,
116
+ ) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
117
+ # region - Copied from typing.get_type_hints
118
+ if globalns is None :
119
+ base_globals = getattr (sys .modules .get (clazz .__module__ , None ), "__dict__" , {})
120
+ else :
121
+ base_globals = globalns
122
+ base_locals = dict (vars (clazz )) if localns is None else localns
123
+ if localns is None and globalns is None :
124
+ # This is surprising, but required. Before Python 3.10,
125
+ # get_type_hints only evaluated the globalns of
126
+ # a class. To maintain backwards compatibility, we reverse
127
+ # the globalns and localns order so that eval() looks into
128
+ # *base_globals* first rather than *base_locals*.
129
+ # This only affects ForwardRefs.
130
+ base_globals , base_locals = base_locals , base_globals
131
+ # endregion - Copied from typing.get_type_hints
132
+
133
+ return base_globals , base_locals
134
+
135
+
136
+ def _resolve_typevars (
137
+ clazz : type ,
138
+ globalns : Optional [Dict [str , Any ]] = None ,
139
+ localns : Optional [Dict [str , Any ]] = None ,
140
+ ) -> Dict [type , Dict [TypeVar , _Future ]]:
103
141
"""
104
142
Attemps to resolves all TypeVars in the class bases. Allows us to resolve inherited and aliased generics.
105
143
@@ -110,6 +148,7 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
110
148
parent_class : Optional [type ] = None
111
149
# Loop in reversed order and iteratively resolve types
112
150
for subclass in reversed (clazz .mro ()):
151
+ base_globals , base_locals = _get_namespaces (subclass , globalns , localns )
113
152
if issubclass (subclass , Generic ) and hasattr (subclass , "__orig_bases__" ): # type: ignore[arg-type]
114
153
args = get_args (subclass .__orig_bases__ [0 ])
115
154
@@ -121,10 +160,17 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
121
160
if isinstance (potential_type , TypeVar ):
122
161
subclass_generic_params_to_args .append ((potential_type , future ))
123
162
else :
124
- future .set_result (potential_type )
163
+ future .set_result (
164
+ eval_forward_ref (
165
+ potential_type ,
166
+ globalns = base_globals ,
167
+ localns = base_locals ,
168
+ )
169
+ if isinstance (potential_type , ForwardRef )
170
+ else potential_type
171
+ )
125
172
126
173
args_by_class [subclass ] = tuple (subclass_generic_params_to_args )
127
-
128
174
else :
129
175
args_by_class [subclass ] = tuple ((arg , _Future ()) for arg in args )
130
176
@@ -136,7 +182,11 @@ def _resolve_typevars(clazz: type) -> Dict[type, Dict[TypeVar, _Future]]:
136
182
args = get_args (clazz )
137
183
for (_arg , future ), potential_type in zip (args_by_class [origin ], args ): # type: ignore[index]
138
184
if not isinstance (potential_type , TypeVar ):
139
- future .set_result (potential_type )
185
+ future .set_result (
186
+ eval_forward_ref (potential_type , globalns = globalns , localns = localns )
187
+ if isinstance (potential_type , ForwardRef )
188
+ else potential_type
189
+ )
140
190
141
191
# Convert to nested dict for easier lookup
142
192
return {k : {typ : fut for typ , fut in args } for k , args in args_by_class .items ()}
@@ -166,12 +216,16 @@ def _replace_typevars(
166
216
)
167
217
168
218
169
- def get_generic_dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
219
+ def get_resolved_dataclass_fields (
220
+ clazz : type ,
221
+ globalns : Optional [Dict [str , Any ]] = None ,
222
+ localns : Optional [Dict [str , Any ]] = None ,
223
+ ) -> Tuple [dataclasses .Field , ...]:
170
224
unbound_fields = set ()
171
225
# Need to manually resolve fields because `dataclasses.fields` doesn't handle generics and
172
226
# looses the source class. Thus I don't know how to resolve this at later on.
173
227
# Instead we recreate the type but with all known TypeVars resolved to their actual types.
174
- resolved_typevars = _resolve_typevars (clazz )
228
+ resolved_typevars = _resolve_typevars (clazz , globalns = globalns , localns = localns )
175
229
# Dict[field_name, Tuple[original_field, resolved_field]]
176
230
fields : Dict [str , Tuple [dataclasses .Field , dataclasses .Field ]] = {}
177
231
@@ -190,14 +244,34 @@ def get_generic_dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
190
244
if not inspect .isclass (field .type ) and may_contain_typevars (field .type ):
191
245
new_field = copy .copy (field )
192
246
new_field .type = _replace_typevars (
193
- field .type , resolved_typevars [ subclass ]
247
+ field .type , resolved_typevars . get ( subclass )
194
248
)
195
249
elif isinstance (field .type , TypeVar ):
196
250
new_field = copy .copy (field )
197
251
new_field .type = resolved_typevars [subclass ][field .type ].result ()
252
+ elif isinstance (field .type , ForwardRef ):
253
+ base_globals , base_locals = _get_namespaces (
254
+ subclass , globalns , localns
255
+ )
256
+ new_field = copy .copy (field )
257
+ new_field .type = eval_forward_ref (
258
+ field .type , globalns = base_globals , localns = base_locals
259
+ )
260
+ elif isinstance (field .type , str ):
261
+ base_globals , base_locals = _get_namespaces (
262
+ subclass , globalns , localns
263
+ )
264
+ new_field = copy .copy (field )
265
+ new_field .type = eval_forward_ref (
266
+ ForwardRef (field .type , is_argument = False , is_class = True )
267
+ if sys .version_info >= (3 , 9 )
268
+ else ForwardRef (field .type , is_argument = False ),
269
+ globalns = base_globals ,
270
+ localns = base_locals ,
271
+ )
198
272
199
273
fields [field .name ] = (field , new_field )
200
- except InvalidStateError :
274
+ except ( InvalidStateError , KeyError ) :
201
275
unbound_fields .add (field .name )
202
276
203
277
if unbound_fields :
0 commit comments