Skip to content

Commit 1b6287a

Browse files
authored
[Misc] Generics typehinting for RegistryMixin (#320)
* typehint Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * use type builtin, as typing.Type is deprecated Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 8916411 commit 1b6287a

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/compressed_tensors/registry/registry.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import importlib
2121
from collections import defaultdict
22-
from typing import Any, Dict, List, Optional, Type, Union
22+
from typing import Any, Dict, List, Optional, TypeVar, Union
2323

2424

2525
__all__ = [
@@ -32,8 +32,9 @@
3232
]
3333

3434

35-
_ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
36-
_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
35+
_ALIAS_REGISTRY: Dict[type, Dict[str, str]] = defaultdict(dict)
36+
_REGISTRY: Dict[type, Dict[str, Any]] = defaultdict(dict)
37+
T = TypeVar("", bound="RegistryMixin")
3738

3839

3940
def standardize_lookup_name(name: str) -> str:
@@ -159,7 +160,7 @@ def register_value(
159160
)
160161

161162
@classmethod
162-
def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
163+
def load_from_registry(cls: type[T], name: str, **constructor_kwargs) -> T:
163164
"""
164165
:param name: name of registered class to load
165166
:param constructor_kwargs: arguments to pass to the constructor retrieved
@@ -172,7 +173,7 @@ def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
172173
return constructor(**constructor_kwargs)
173174

174175
@classmethod
175-
def get_value_from_registry(cls, name: str):
176+
def get_value_from_registry(cls: type[T], name: str) -> T:
176177
"""
177178
:param name: name to retrieve from the registry
178179
:return: value from retrieved the registry for the given name, raises
@@ -200,7 +201,7 @@ def registered_aliases(cls) -> List[str]:
200201

201202

202203
def register(
203-
parent_class: Type,
204+
parent_class: type,
204205
value: Any,
205206
name: Optional[str] = None,
206207
alias: Union[List[str], str, None] = None,
@@ -240,7 +241,7 @@ def register(
240241

241242

242243
def get_from_registry(
243-
parent_class: Type, name: str, require_subclass: bool = False
244+
parent_class: type, name: str, require_subclass: bool = False
244245
) -> Any:
245246
"""
246247
:param parent_class: class that the name is registered under
@@ -276,15 +277,15 @@ def get_from_registry(
276277
return retrieved_value
277278

278279

279-
def registered_names(parent_class: Type) -> List[str]:
280+
def registered_names(parent_class: type) -> List[str]:
280281
"""
281282
:param parent_class: class to look up the registry of
282283
:return: all names registered to the given class
283284
"""
284285
return list(_REGISTRY[parent_class].keys())
285286

286287

287-
def registered_aliases(parent_class: Type) -> List[str]:
288+
def registered_aliases(parent_class: type) -> List[str]:
288289
"""
289290
:param parent_class: class to look up the registry of
290291
:return: all aliases registered to the given class
@@ -297,7 +298,7 @@ def registered_aliases(parent_class: Type) -> List[str]:
297298

298299

299300
def register_alias(
300-
name: str, parent_class: Type, alias: Union[str, List[str], None] = None
301+
name: str, parent_class: type, alias: Union[str, List[str], None] = None
301302
):
302303
"""
303304
Updates the mapping from the alias(es) to the given name.
@@ -352,7 +353,7 @@ def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
352353
return value
353354

354355

355-
def _validate_subclass(parent_class: Type, child_class: Type):
356+
def _validate_subclass(parent_class: type, child_class: type):
356357
if not issubclass(child_class, parent_class):
357358
raise ValueError(
358359
f"class {child_class} is not a subclass of the class it is "

0 commit comments

Comments
 (0)