Skip to content

Commit ea5ed32

Browse files
committed
docs: docstr fixes
1 parent 238598b commit ea5ed32

File tree

4 files changed

+39
-17
lines changed

4 files changed

+39
-17
lines changed

cellseg_models_pytorch/models/__init__.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import torch.nn as nn
2+
13
from .base._multitask_unet import MultiTaskUnet
24
from .cellpose.cellpose import (
35
CellPoseUnet,
@@ -26,6 +28,7 @@
2628
"omnipose_base": omnipose_base,
2729
"omnipose_plus": omnipose_plus,
2830
"hovernet_base": hovernet_base,
31+
"hovernet_plus": hovernet_plus,
2932
"hovernet_small": hovernet_small,
3033
"hovernet_small_plus": hovernet_small_plus,
3134
"stardist_base": stardist_base,
@@ -34,43 +37,63 @@
3437
}
3538

3639

37-
def get_model(name: str, type: str, ntypes: int = None, ntissues: int = None):
38-
"""Get the corect model at hand given name and type."""
40+
def get_model(
41+
name: str, type: str, ntypes: int = None, ntissues: int = None, **kwargs
42+
) -> nn.Module:
43+
"""Get the corect model at hand given name and type.
44+
45+
Parameters
46+
----------
47+
name : str
48+
Name of the model.
49+
type : str
50+
Type of the model. One of "base", "plus", "small", "small_plus".
51+
ntypes : int
52+
Number of cell types to segment.
53+
ntissues : int
54+
Number of tissue types to segment.
55+
**kwargs : dict
56+
Additional keyword arguments.
57+
58+
Returns
59+
-------
60+
nn.Module: The specified model.
61+
"""
3962
if name == "stardist":
4063
if type == "base":
4164
model = MODEL_LOOKUP["stardist_base_multiclass"](
42-
n_rays=32, type_classes=ntypes
65+
n_rays=32, type_classes=ntypes, **kwargs
4366
)
4467
elif type == "plus":
4568
model = MODEL_LOOKUP["stardist_plus"](
46-
n_rays=32, type_classes=ntypes, sem_classes=ntissues
69+
n_rays=32, type_classes=ntypes, sem_classes=ntissues, **kwargs
4770
)
4871
elif name == "cellpose":
4972
if type == "base":
50-
model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes)
73+
model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes, **kwargs)
5174
elif type == "plus":
5275
model = MODEL_LOOKUP["cellpose_plus"](
53-
type_classes=ntypes, sem_classes=ntissues
76+
type_classes=ntypes, sem_classes=ntissues, **kwargs
5477
)
5578
elif name == "omnipose":
5679
if type == "base":
57-
model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes)
80+
model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes, **kwargs)
5881
elif type == "plus":
5982
model = MODEL_LOOKUP["omnipose_plus"](
60-
type_classes=ntypes, sem_classes=ntissues
83+
type_classes=ntypes, sem_classes=ntissues, **kwargs
6184
)
6285
elif name == "hovernet":
6386
if type == "base":
64-
model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes)
87+
model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes, **kwargs)
6588
elif type == "small":
66-
model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes)
89+
model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes, **kwargs)
6790
elif type == "plus":
6891
model = MODEL_LOOKUP["hovernet_plus"](
69-
type_classes=ntypes, sem_classes=ntissues
92+
type_classes=ntypes, sem_classes=ntissues, **kwargs
7093
)
7194
elif type == "small_plus":
7295
model = MODEL_LOOKUP["hovernet_small_plus"](
73-
type_classes=ntypes, sem_classes=ntissues
96+
type_classes=ntypes, sem_classes=ntissues, **kwargs
7497
)
7598
else:
7699
raise ValueError("Unknown model type or name.")

cellseg_models_pytorch/modules/transformers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,6 @@ def __init__(
243243
f"Illegal args: {illegal_args}"
244244
)
245245

246-
# self.tr_blocks = nn.ModuleDict()
247246
self.tr_blocks = nn.ModuleList()
248247
self.layer_scales = nn.ModuleList()
249248
blocks = list(range(n_blocks))

cellseg_models_pytorch/training/lit/lightning_experiment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
Use underscores to create joint loss functions. e.g. "dice_ce_tversky".
5050
branch_loss_params : Dict[str, Dict[str, Any]], optional
5151
Params for the different losses at different branches. For example to
52-
use LS or class weighting when computing the losses.
52+
use label smoothing or class weighting when computing the losses.
5353
E.g. {"inst": {"apply_ls": True}, "sem": {"edge_weight": False}}
5454
branch_metrics : Dict[str, List[str]], optional
5555
A Dict of branch names mapped to a list of strings specifying a metrics.
@@ -65,7 +65,7 @@ def __init__(
6565
optim_params : Dict[str, Dict[str, Any]]
6666
optim paramas like learning rates, weight decays etc for diff parts of
6767
the network.
68-
E.g. {"encoder": {"weight_decay: 0.1, "lr": 0.1}, "sem": {"lr": 0.01}}
68+
E.g. {"encoder": {"weight_decay": 0.1, "lr": 0.1}, "sem": {"lr": 0.01}}
6969
or {"learning_rate": 0.005, "weight_decay": 0.03}
7070
lookahead : bool, default=False
7171
Flag whether the optimizer uses lookahead.

cellseg_models_pytorch/utils/file_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,8 @@ def write_gson(
469469
import geojson
470470
except ModuleNotFoundError:
471471
raise ModuleNotFoundError(
472-
"To use the `FileHandler.mask2geojson`, pytorch-lightning is required. "
473-
"Install with `pip install pytorch-lightning`"
472+
"To use the `FileHandler.mask2geojson`, geojson is required. "
473+
"Install with `pip install geojson`"
474474
)
475475

476476
fname = Path(fname)

0 commit comments

Comments
 (0)