Skip to content

Commit 196c5de

Browse files
committed
chore: rm torch-optimizer from the optional deps
1 parent 614fa4d commit 196c5de

File tree

3 files changed

+46
-62
lines changed

3 files changed

+46
-62
lines changed

README.md

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
**Python library for 2D cell/nuclei instance segmentation models written with [PyTorch](https://pytorch.org/).**
66

77
[![Generic badge](https://img.shields.io/badge/License-MIT-<COLOR>.svg?style=for-the-badge)](https://github.com/okunator/cellseg_models.pytorch/blob/master/LICENSE)
8-
[![PyTorch - Version](https://img.shields.io/badge/PYTORCH-1.8+-red?style=for-the-badge&logo=pytorch)](https://pytorch.org/)
8+
[![PyTorch - Version](https://img.shields.io/badge/PYTORCH-1.8.1+-red?style=for-the-badge&logo=pytorch)](https://pytorch.org/)
99
[![Python - Version](https://img.shields.io/badge/PYTHON-3.7+-red?style=for-the-badge&logo=python&logoColor=white)](https://www.python.org/)
1010
<br>
1111
[![Github Test](https://img.shields.io/github/workflow/status/okunator/cellseg_models.pytorch/Tests?label=Tests&logo=github&style=for-the-badge)](https://github.com/okunator/cellseg_models.pytorch/actions/workflows/tests.yml)
@@ -51,10 +51,9 @@ pip install cellseg-models-pytorch[all]
5151
- Pre-trained backbones/encoders from the [timm](https://github.com/rwightman/pytorch-image-models) library.
5252
- All the architectures can be augmented to output semantic segmentation outputs along with instance semgentation outputs (panoptic segmentation).
5353
- A lot of flexibility to modify the components of the model architectures.
54-
- Optimized inference methods.
54+
- Multi-GPU inference.
5555
- Popular training losses and benchmarking metrics.
5656
- Simple model training with [pytorch-lightning](https://www.pytorchlightning.ai/).
57-
- Popular optimizers for training (provided by [pytorch-optimizer](https://github.com/jettify/pytorch-optimizer)).
5857

5958
## Models
6059

@@ -85,10 +84,10 @@ pip install cellseg-models-pytorch[all]
8584
import cellseg_models_pytorch as csmp
8685
import torch
8786

88-
model = csmp.models.cellpose_base(type_classes=5) # num of cell types in training data=5.
87+
model = csmp.models.cellpose_base(type_classes=5)
8988
x = torch.rand([1, 3, 256, 256])
9089

91-
# NOTE: these outputs still need post-processing to obtain instance segmentation masks.
90+
# NOTE: the outputs still need post-processing.
9291
y = model(x) # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256]}
9392
```
9493

@@ -98,10 +97,10 @@ y = model(x) # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256]}
9897
import cellseg_models_pytorch as csmp
9998
import torch
10099

101-
model = csmp.models.cellpose_plus(type_classes=5, sem_classes=3) # num cell types and tissue types
100+
model = csmp.models.cellpose_plus(type_classes=5, sem_classes=3)
102101
x = torch.rand([1, 3, 256, 256])
103102

104-
# NOTE: these outputs still need post-processing to obtain instance and semantic segmentation masks.
103+
# NOTE: the outputs still need post-processing.
105104
y = model(x) # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256], "sem": [1, 3, 256, 256]}
106105
```
107106

@@ -110,27 +109,37 @@ y = model(x) # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256], "sem": [
110109
```python
111110
import cellseg_models_pytorch as csmp
112111

112+
# two decoder branches.
113+
decoders = ("cellpose", "sem")
114+
115+
# three segmentation heads from the decoders.
116+
heads = {
117+
"cellpose": {"cellpose": 2, "type": 5},
118+
"sem": {"sem": 3}
119+
}
120+
113121
model = csmp.CellPoseUnet(
114-
decoders=("cellpose", "sem"), # cellpose and semantic decoders
115-
heads={"cellpose": {"cellpose": 2, "type": 5}, "sem": {"sem": 3}}, # three output heads
116-
depth=5, # encoder depth
117-
out_channels=(256, 128, 64, 32, 16), # number of out channels at each decoder stage
118-
layer_depths=(4, 4, 4, 4, 4), # number of conv blocks at each decoder layer
119-
style_channels=256, # Number of style vector channels
120-
enc_name="resnet50", # timm encoder
121-
enc_pretrain=True, # imagenet pretrained encoder
122-
long_skip="unetpp", # use unet++ long skips. ("unet", "unetpp", "unet3p")
123-
merge_policy="sum", # ("cat", "sum")
124-
short_skip="residual", # residual short skips. ("basic", "residual", "dense")
125-
normalization="bcn", # batch-channel-normalization. ("bcn", "bn", "gn", "ln", "in")
126-
activation="gelu", # gelu activation instead of relu. Several options for this.
127-
convolution="wsconv", # weight standardized conv. ("wsconv", "conv", "scaled_wsconv")
128-
attention="se", # squeeze-and-excitation attention. ("se", "gc", "scse", "eca")
129-
pre_activate=False, # normalize and activation after convolution.
122+
decoders=decoders, # cellpose and semantic decoders
123+
heads=heads, # three output heads
124+
depth=5, # encoder depth
125+
out_channels=(256, 128, 64, 32, 16), # num out channels at each decoder stage
126+
layer_depths=(4, 4, 4, 4, 4), # num of conv blocks at each decoder layer
127+
style_channels=256, # num of style vector channels
128+
enc_name="resnet50", # timm encoder
129+
enc_pretrain=True, # imagenet pretrained encoder
130+
long_skip="unetpp", # unet++ long skips ("unet", "unetpp", "unet3p")
131+
merge_policy="sum", # concatenate long skips ("cat", "sum")
132+
short_skip="residual", # residual short skips ("basic", "residual", "dense")
133+
normalization="bcn", # batch-channel-normalization.
134+
activation="gelu", # gelu activation.
135+
convolution="wsconv", # weight standardized conv.
136+
attention="se", # squeeze-and-excitation attention.
137+
pre_activate=False, # normalize and activation after convolution.
130138
)
131139

132140
x = torch.rand([1, 3, 256, 256])
133-
# NOTE: these outputs still need post-processing to obtain instance and semantic segmentation masks.
141+
142+
# NOTE: the outputs still need post-processing.
134143
y = model(x) # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256], "sem": [1, 3, 256, 256]}
135144
```
136145

@@ -142,13 +151,20 @@ import cellseg_models_pytorch as csmp
142151
model = csmp.models.hovernet_base(type_classes=5)
143152
# returns {"hovernet": [B, 2, H, W], "type": [B, 5, H, W], "inst": [B, 2, H, W]}
144153

154+
# the final activations for each model output
155+
out_activations = {"hovernet": "tanh", "type": "softmax", "inst": "softmax"}
156+
157+
# models perform the poorest at the image boundaries, with overlapping patches this
158+
# causes issues which can be overcome by adding smoothing to the prediction boundaries
159+
out_boundary_weights = {"hovernet": True, "type": False, "inst": False}
160+
145161
# Sliding window inference for big images using overlapping patches
146162
inferer = csmp.inference.SlidingWindowInferer(
147163
model=model,
148164
input_folder="/path/to/images/",
149165
checkpoint_path="/path/to/model/weights/",
150-
out_activations={"hovernet": "tanh", "type": "softmax", "inst": "softmax"},
151-
out_boundary_weights={"hovernet": True, "type": False, "inst": False}, # smooths boundary effects
166+
out_activations=out_activations,
167+
out_boundary_weights=out_boundary_weights,
152168
instance_postproc="hovernet", # THE POST-PROCESSING METHOD
153169
patch_size=(256, 256),
154170
stride=128,
@@ -157,7 +173,8 @@ inferer = csmp.inference.SlidingWindowInferer(
157173
normalization="percentile", # same normalization as in training
158174
)
159175

160-
inferer.infer() # Run sliding window inference.
176+
# Run sliding window inference.
177+
inferer.infer()
161178

162179
inferer.out_masks
163180
# {"image1" :{"inst": [H, W], "type": [H, W]}, ..., "imageN" :{"inst": [H, W], "type": [H, W]}}

poetry.lock

Lines changed: 2 additions & 33 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ numba = "^0.55.2"
2424
tqdm = "^4.64.0"
2525
scikit-learn = "^1.0.2"
2626
pytorch-lightning = {version = "^1.6.0", optional = true}
27-
torch-optimizer = {version = "^0.3.0", optional = true}
2827
tables = {version = "^3.6.0", optional = true}
2928
albumentations = {version = "^1.0.0", optional = true}
3029
requests = {version = "^2.28.0", optional = true}
@@ -38,7 +37,6 @@ all = [
3837
"requests",
3938
"albumentations",
4039
"geojson",
41-
"torch-optimizer",
4240
"torchmetrics"
4341
]
4442

0 commit comments

Comments
 (0)