Skip to content

Commit ac667c4

Browse files
Add docstrings to kimm.list_models, kimm.utils.get_reparameterized_model (#49)
1 parent 7de107b commit ac667c4

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

kimm/_src/utils/model_registry.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,23 @@ def list_models(
7878
name: typing.Optional[str] = None,
7979
feature_extractor: typing.Optional[bool] = None,
8080
weights: typing.Optional[typing.Union[bool, str]] = None,
81-
) -> typing.List[str]:
82-
result_names: typing.Set = set()
81+
):
82+
"""List the models with the given arguments.
83+
84+
Args:
85+
name: An optional `str` specifying the substring of the name of the
86+
model to seatch for. If not specified, all models will be included.
87+
feature_extractor: Whether to include models that support
88+
feature extraction. Defaults to `None`, which means this
89+
argument is not considered.
90+
weights: An optional boolean or `str` specifying the name of the
91+
pretrained weights. The available values are (`"imagenet"`).
92+
Defaults to `None`, which means this argument is not considered.
93+
94+
Returns:
95+
A list of model names.
96+
"""
97+
result_names: typing.Set[str] = set()
8398
for info in MODEL_REGISTRY:
8499
# Add by default
85100
result_names.add(info["name"])

kimm/_src/utils/model_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@
44

55
@kimm_export(parent_path=["kimm.utils"])
66
def get_reparameterized_model(model: BaseModel):
7+
"""Get the reparameterized model.
8+
9+
Internally, this function calls `get_reparameterized_model` from the
10+
provided `model`.
11+
12+
Args:
13+
model: A `BaseModel` to convert to its reparameterized form.
14+
15+
Returns:
16+
An instance of the same class as `model` in its reparameterized form.
17+
"""
718
if not hasattr(model, "get_reparameterized_model"):
819
raise ValueError(
920
"There is no 'get_reparameterized_model' method in the model. "

0 commit comments

Comments
 (0)