|
| 1 | +import torch.nn as nn |
| 2 | + |
1 | 3 | from .base._multitask_unet import MultiTaskUnet
|
2 | 4 | from .cellpose.cellpose import (
|
3 | 5 | CellPoseUnet,
|
|
26 | 28 | "omnipose_base": omnipose_base,
|
27 | 29 | "omnipose_plus": omnipose_plus,
|
28 | 30 | "hovernet_base": hovernet_base,
|
| 31 | + "hovernet_plus": hovernet_plus, |
29 | 32 | "hovernet_small": hovernet_small,
|
30 | 33 | "hovernet_small_plus": hovernet_small_plus,
|
31 | 34 | "stardist_base": stardist_base,
|
|
34 | 37 | }
|
35 | 38 |
|
36 | 39 |
|
37 |
| -def get_model(name: str, type: str, ntypes: int = None, ntissues: int = None): |
38 |
| - """Get the corect model at hand given name and type.""" |
| 40 | +def get_model( |
| 41 | + name: str, type: str, ntypes: int = None, ntissues: int = None, **kwargs |
| 42 | +) -> nn.Module: |
| 43 | + """Get the corect model at hand given name and type. |
| 44 | +
|
| 45 | + Parameters |
| 46 | + ---------- |
| 47 | + name : str |
| 48 | + Name of the model. |
| 49 | + type : str |
| 50 | + Type of the model. One of "base", "plus", "small", "small_plus". |
| 51 | + ntypes : int |
| 52 | + Number of cell types to segment. |
| 53 | + ntissues : int |
| 54 | + Number of tissue types to segment. |
| 55 | + **kwargs : dict |
| 56 | + Additional keyword arguments. |
| 57 | +
|
| 58 | + Returns |
| 59 | + ------- |
| 60 | + nn.Module: The specified model. |
| 61 | + """ |
39 | 62 | if name == "stardist":
|
40 | 63 | if type == "base":
|
41 | 64 | model = MODEL_LOOKUP["stardist_base_multiclass"](
|
42 |
| - n_rays=32, type_classes=ntypes |
| 65 | + n_rays=32, type_classes=ntypes, **kwargs |
43 | 66 | )
|
44 | 67 | elif type == "plus":
|
45 | 68 | model = MODEL_LOOKUP["stardist_plus"](
|
46 |
| - n_rays=32, type_classes=ntypes, sem_classes=ntissues |
| 69 | + n_rays=32, type_classes=ntypes, sem_classes=ntissues, **kwargs |
47 | 70 | )
|
48 | 71 | elif name == "cellpose":
|
49 | 72 | if type == "base":
|
50 |
| - model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes) |
| 73 | + model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes, **kwargs) |
51 | 74 | elif type == "plus":
|
52 | 75 | model = MODEL_LOOKUP["cellpose_plus"](
|
53 |
| - type_classes=ntypes, sem_classes=ntissues |
| 76 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
54 | 77 | )
|
55 | 78 | elif name == "omnipose":
|
56 | 79 | if type == "base":
|
57 |
| - model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes) |
| 80 | + model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes, **kwargs) |
58 | 81 | elif type == "plus":
|
59 | 82 | model = MODEL_LOOKUP["omnipose_plus"](
|
60 |
| - type_classes=ntypes, sem_classes=ntissues |
| 83 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
61 | 84 | )
|
62 | 85 | elif name == "hovernet":
|
63 | 86 | if type == "base":
|
64 |
| - model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes) |
| 87 | + model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes, **kwargs) |
65 | 88 | elif type == "small":
|
66 |
| - model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes) |
| 89 | + model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes, **kwargs) |
67 | 90 | elif type == "plus":
|
68 | 91 | model = MODEL_LOOKUP["hovernet_plus"](
|
69 |
| - type_classes=ntypes, sem_classes=ntissues |
| 92 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
70 | 93 | )
|
71 | 94 | elif type == "small_plus":
|
72 | 95 | model = MODEL_LOOKUP["hovernet_small_plus"](
|
73 |
| - type_classes=ntypes, sem_classes=ntissues |
| 96 | + type_classes=ntypes, sem_classes=ntissues, **kwargs |
74 | 97 | )
|
75 | 98 | else:
|
76 | 99 | raise ValueError("Unknown model type or name.")
|
|
0 commit comments