19
19
20
20
import importlib
21
21
from collections import defaultdict
22
- from typing import Any , Dict , List , Optional , Type , Union
22
+ from typing import Any , Dict , List , Optional , TypeVar , Union
23
23
24
24
25
25
__all__ = [
32
32
]
33
33
34
34
35
- _ALIAS_REGISTRY : Dict [Type , Dict [str , str ]] = defaultdict (dict )
36
- _REGISTRY : Dict [Type , Dict [str , Any ]] = defaultdict (dict )
35
+ _ALIAS_REGISTRY : Dict [type , Dict [str , str ]] = defaultdict (dict )
36
+ _REGISTRY : Dict [type , Dict [str , Any ]] = defaultdict (dict )
37
+ T = TypeVar ("" , bound = "RegistryMixin" )
37
38
38
39
39
40
def standardize_lookup_name (name : str ) -> str :
@@ -159,7 +160,7 @@ def register_value(
159
160
)
160
161
161
162
@classmethod
162
- def load_from_registry (cls , name : str , ** constructor_kwargs ) -> object :
163
+ def load_from_registry (cls : type [ T ] , name : str , ** constructor_kwargs ) -> T :
163
164
"""
164
165
:param name: name of registered class to load
165
166
:param constructor_kwargs: arguments to pass to the constructor retrieved
@@ -172,7 +173,7 @@ def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
172
173
return constructor (** constructor_kwargs )
173
174
174
175
@classmethod
175
- def get_value_from_registry (cls , name : str ):
176
+ def get_value_from_registry (cls : type [ T ] , name : str ) -> T :
176
177
"""
177
178
:param name: name to retrieve from the registry
178
179
:return: value from retrieved the registry for the given name, raises
@@ -200,7 +201,7 @@ def registered_aliases(cls) -> List[str]:
200
201
201
202
202
203
def register (
203
- parent_class : Type ,
204
+ parent_class : type ,
204
205
value : Any ,
205
206
name : Optional [str ] = None ,
206
207
alias : Union [List [str ], str , None ] = None ,
@@ -240,7 +241,7 @@ def register(
240
241
241
242
242
243
def get_from_registry (
243
- parent_class : Type , name : str , require_subclass : bool = False
244
+ parent_class : type , name : str , require_subclass : bool = False
244
245
) -> Any :
245
246
"""
246
247
:param parent_class: class that the name is registered under
@@ -276,15 +277,15 @@ def get_from_registry(
276
277
return retrieved_value
277
278
278
279
279
- def registered_names (parent_class : Type ) -> List [str ]:
280
+ def registered_names (parent_class : type ) -> List [str ]:
280
281
"""
281
282
:param parent_class: class to look up the registry of
282
283
:return: all names registered to the given class
283
284
"""
284
285
return list (_REGISTRY [parent_class ].keys ())
285
286
286
287
287
- def registered_aliases (parent_class : Type ) -> List [str ]:
288
+ def registered_aliases (parent_class : type ) -> List [str ]:
288
289
"""
289
290
:param parent_class: class to look up the registry of
290
291
:return: all aliases registered to the given class
@@ -297,7 +298,7 @@ def registered_aliases(parent_class: Type) -> List[str]:
297
298
298
299
299
300
def register_alias (
300
- name : str , parent_class : Type , alias : Union [str , List [str ], None ] = None
301
+ name : str , parent_class : type , alias : Union [str , List [str ], None ] = None
301
302
):
302
303
"""
303
304
Updates the mapping from the alias(es) to the given name.
@@ -352,7 +353,7 @@ def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
352
353
return value
353
354
354
355
355
- def _validate_subclass (parent_class : Type , child_class : Type ):
356
+ def _validate_subclass (parent_class : type , child_class : type ):
356
357
if not issubclass (child_class , parent_class ):
357
358
raise ValueError (
358
359
f"class { child_class } is not a subclass of the class it is "
0 commit comments