**Python library for 2D cell/nuclei instance segmentation models written with [PyTorch](https://pytorch.org/).**

[![Generic badge](https://img.shields.io/badge/License-MIT-<COLOR>.svg?style=for-the-badge)](https://github.com/okunator/cellseg_models.pytorch/blob/master/LICENSE)
[![PyTorch - Version](https://img.shields.io/badge/PYTORCH-1.8.1+-red?style=for-the-badge&logo=pytorch)](https://pytorch.org/)
[![Python - Version](https://img.shields.io/badge/PYTHON-3.7+-red?style=for-the-badge&logo=python&logoColor=white)](https://www.python.org/)
<br>
[![Github Test](https://img.shields.io/github/workflow/status/okunator/cellseg_models.pytorch/Tests?label=Tests&logo=github&style=for-the-badge)](https://github.com/okunator/cellseg_models.pytorch/actions/workflows/tests.yml)
- Pre-trained backbones/encoders from the [timm](https://github.com/rwightman/pytorch-image-models) library (see the encoder sketch after this list).
- All the architectures can be augmented to output semantic segmentation maps along with instance segmentation maps (panoptic segmentation).
- A lot of flexibility to modify the components of the model architectures.
- Multi-GPU inference.
- Popular training losses and benchmarking metrics.
- Simple model training with [pytorch-lightning](https://www.pytorchlightning.ai/).
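The encoder names are plain timm model names; presumably any name returned by `timm.list_models` can be passed as `enc_name` in the examples below, although exotic architectures may not be supported. A quick way to browse candidates:

```python
import timm

# list a handful of pretrained ResNet-family encoders available in timm.
print(timm.list_models("resnet*", pretrained=True)[:5])
```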
## Models
Running a forward pass with the built-in Cellpose model:

```python
import cellseg_models_pytorch as csmp
import torch

# number of distinct cell types in the training data = 5.
model = csmp.models.cellpose_base(type_classes=5)

x = torch.rand([1, 3, 256, 256])

# NOTE: the outputs still need post-processing to obtain instance segmentation masks.
y = model(x)  # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256]}
```
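The heavy lifting of turning these soft outputs into labelled instance masks is done by the library's inferer classes (see the sliding window example further down). As a minimal sketch, a hard per-pixel cell-type map can already be read off the `"type"` output with plain torch:

```python
# softmax over the class channel, then argmax -> [1, 256, 256] cell-type map.
type_map = torch.softmax(y["type"], dim=1).argmax(dim=1)
```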
The `cellpose_plus` variant adds a semantic (tissue) segmentation decoder for panoptic segmentation:

```python
import cellseg_models_pytorch as csmp
import torch

# number of cell types = 5, number of tissue types = 3.
model = csmp.models.cellpose_plus(type_classes=5, sem_classes=3)

x = torch.rand([1, 3, 256, 256])

# NOTE: the outputs still need post-processing to obtain instance and semantic segmentation masks.
y = model(x)  # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256], "sem": [1, 3, 256, 256]}
```
The same architectures can also be built from their components, with full control over the details:

```python
import cellseg_models_pytorch as csmp
import torch

# two decoder branches.
decoders = ("cellpose", "sem")

# three segmentation heads from the decoders.
heads = {
    "cellpose": {"cellpose": 2, "type": 5},
    "sem": {"sem": 3},
}

model = csmp.CellPoseUnet(
    decoders=decoders,                    # cellpose and semantic decoders
    heads=heads,                          # three output heads
    depth=5,                              # encoder depth
    out_channels=(256, 128, 64, 32, 16),  # num of out channels at each decoder stage
    layer_depths=(4, 4, 4, 4, 4),         # num of conv blocks at each decoder layer
    style_channels=256,                   # num of style vector channels
    enc_name="resnet50",                  # timm encoder
    enc_pretrain=True,                    # imagenet-pretrained encoder
    long_skip="unetpp",                   # unet++ long skips ("unet", "unetpp", "unet3p")
    merge_policy="sum",                   # how to merge long skips ("cat", "sum")
    short_skip="residual",                # residual short skips ("basic", "residual", "dense")
    normalization="bcn",                  # batch-channel normalization ("bcn", "bn", "gn", "ln", "in")
    activation="gelu",                    # gelu activation instead of relu
    convolution="wsconv",                 # weight-standardized conv ("wsconv", "conv", "scaled_wsconv")
    attention="se",                       # squeeze-and-excitation attention ("se", "gc", "scse", "eca")
    pre_activate=False,                   # normalize and activate after convolution
)

x = torch.rand([1, 3, 256, 256])

# NOTE: the outputs still need post-processing to obtain instance and semantic segmentation masks.
y = model(x)  # {"cellpose": [1, 2, 256, 256], "type": [1, 5, 256, 256], "sem": [1, 3, 256, 256]}
```
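Since the model is a `torch.nn.Module` that returns a dict of tensors, a custom configuration can be sanity-checked by counting parameters and printing the output shapes:

```python
# rough size of the configured model plus the shape of every output head.
n_params = sum(p.numel() for p in model.parameters())
print(f"params: {n_params / 1e6:.1f}M")
for name, out in model(torch.rand([1, 3, 256, 256])).items():
    print(name, tuple(out.shape))
```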
For whole images, the models can be run through the built-in sliding window inferer, which also takes care of post-processing:

```python
import cellseg_models_pytorch as csmp

model = csmp.models.hovernet_base(type_classes=5)
# returns {"hovernet": [B, 2, H, W], "type": [B, 5, H, W], "inst": [B, 2, H, W]}

# the final activations for each model output.
out_activations = {"hovernet": "tanh", "type": "softmax", "inst": "softmax"}

# models perform the poorest at the image boundaries; with overlapping patches
# this causes artefacts that can be overcome by smoothing the prediction boundaries.
out_boundary_weights = {"hovernet": True, "type": False, "inst": False}

# sliding window inference for big images using overlapping patches.
inferer = csmp.inference.SlidingWindowInferer(
    model=model,
    input_folder="/path/to/images/",
    checkpoint_path="/path/to/model/weights/",
    out_activations=out_activations,
    out_boundary_weights=out_boundary_weights,
    instance_postproc="hovernet",  # the post-processing method
    patch_size=(256, 256),
    stride=128,
    normalization="percentile",  # same normalization as in training
)

# run sliding window inference.
inferer.infer()

inferer.out_masks
# {"image1": {"inst": [H, W], "type": [H, W]}, ..., "imageN": {"inst": [H, W], "type": [H, W]}}
```