
UNetBox

[Figure: UNetBox network schema]

UNetBox is a PyTorch neural network for image segmentation that improves on the popular UNet architecture (arXiv:1505.04597). It provides a box of techniques that are commonly used in computer vision but were not part of the original UNet, packaged as plugins that can be easily enabled or disabled. As of now, the following techniques are implemented in UNetBox:

The goal of the project is to add more techniques as they become available, in order to push the state of the art in image segmentation.

Ablation Study

To explore how each component influences UNetBox's performance, I perform ablation studies on two datasets: Google's contrails identification dataset and Sartorius' cell instance segmentation dataset.

For both datasets, training uses the Adam optimizer, mixed precision, and a 1-cycle cosine learning-rate schedule. Training runs until the focal loss converges on an out-of-sample dataset (3-fold cross-validation).

The tables use the following abbreviations:

  • DE : Deformable convolutions replace regular convolutions in the decoder blocks
  • SI : SiLU activation units
  • SE : Squeeze-and-Excitation block added to the last activation tensor before downsampling/upsampling
  • EC : Channel expansion followed by channel compression, added after downsampling/upsampling
  • BN : BatchNorm
  • CT : Transposed convolution replaces bilinear interpolation for upsampling
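To illustrate how such components can act as on/off plugins, here is a hypothetical decoder block whose SI, SE, BN and CT options are toggled by flags. It is a sketch, not UNetBox's actual API: the class names, the flag names, the SE reduction ratio and the placement of each layer are assumptions, and the skip connection, DE and EC are omitted for brevity (`torchvision.ops.DeformConv2d` is one possible building block for DE).

```python
import torch
from torch import nn


class SqueezeExcitation(nn.Module):
    """Channel attention: global average pool, bottleneck MLP, sigmoid gate."""

    def __init__(self, channels: int, reduction: int = 4) -> None:
        super().__init__()
        hidden = max(1, channels // reduction)  # reduction ratio is an assumption
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gate(x)  # rescale each channel by its learned weight


class DecoderBlock(nn.Module):
    """Hypothetical decoder block; each flag enables one component."""

    def __init__(self, in_ch: int, out_ch: int, *, silu: bool = False,
                 se: bool = False, bn: bool = False, transposed: bool = False) -> None:
        super().__init__()
        # CT: learnable transposed convolution instead of bilinear interpolation.
        if transposed:
            self.up = nn.ConvTranspose2d(in_ch, in_ch, kernel_size=2, stride=2)
        else:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        # The convolution bias is redundant when BatchNorm follows it.
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=not bn)]
        if bn:  # BN: normalize activations after the convolution
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.SiLU() if silu else nn.ReLU())  # SI: SiLU instead of ReLU
        self.conv = nn.Sequential(*layers)
        # SE: gate the block's last activation tensor before the next upsampling.
        self.se = SqueezeExcitation(out_ch) if se else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.se(self.conv(self.up(x)))


# Usage: the baseline and the all-enabled variant produce the same output shape.
x = torch.randn(2, 8, 16, 16)
baseline = DecoderBlock(8, 4)
full = DecoderBlock(8, 4, silu=True, se=True, bn=True, transposed=True)
print(tuple(baseline(x).shape), tuple(full(x).shape))  # (2, 4, 32, 32) (2, 4, 32, 32)
```

Keeping every option shape-preserving is what lets components be switched on and off independently in an ablation study.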

Google's Contrails Identification Dataset

[Figure: Contrails dataset example]

For this dataset, each batch contains 64 three-channel, 256×256-pixel images. UNetBox's model has depth=4 and expansion_layer=16.
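A quick back-of-the-envelope check of this configuration: the snippet computes the size of one input batch and the spatial resolution at each level. It assumes, as in the original UNet, that each of the `depth` levels halves the resolution; the text does not state how UNetBox implements `depth`, and activations (not the input) dominate real memory use.

```python
# Contrails setup from the text: 64 images, 3 channels, 256x256 pixels, depth=4.
batch, channels, height, width = 64, 3, 256, 256
depth = 4

input_elements = batch * channels * height * width
fp32_mib = input_elements * 4 / 2**20  # 4 bytes per float32 value
fp16_mib = input_elements * 2 / 2**20  # 2 bytes per float16 value (mixed precision)

# Assumption: each encoder level halves the spatial resolution.
resolutions = [height // 2**level for level in range(depth + 1)]

print(input_elements, fp32_mib, fp16_mib)  # 12582912 48.0 24.0
print(resolutions)                         # [256, 128, 64, 32, 16]
```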

DE SI SE EC BN CT   Focal Loss ± Std. Error (×10⁻³)   Training Time (min)   # Parameters
                    1.065 ± .006                      242                   -
                    1.081 ± .029                      195                   3,032,017
                    1.08 ± .082                       226                   3,087,317
                    1.057 ± .008                      132                   3,033,937
                    1.04 ± .005                       163                   6,063,761
                    1.022 ± .022                      147                   6,063,761
                    1.016 ± .014                      132                   6,119,061
                    0.997 ± .018                      -                     6,177,881
                    0.999 ± .042                      -                     6,177,881
                    1.026 ± .028                      -                     6,233,181
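For reference, the binary focal loss reported in these tables down-weights well-classified pixels: FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the probability the model assigns to the true class. The pure-Python sketch below averages it over pixels. The values α = 0.25 and γ = 2 are assumptions (they are the defaults of `torchvision.ops.sigmoid_focal_loss`); the text does not say which values the study used.

```python
import math


def _softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Mean binary focal loss over pixels, given raw logits and 0/1 targets."""
    total = 0.0
    for z, y in zip(logits, targets):
        s = z if y == 1 else -z           # signed logit, so that p_t = sigmoid(s)
        log_p_t = -_softplus(-s)          # log(sigmoid(s)) without overflow
        p_t = math.exp(log_p_t)
        alpha_t = alpha if y == 1 else 1.0 - alpha
        # (1 - p_t) ** gamma shrinks the loss of confident, correct pixels.
        total += -alpha_t * (1.0 - p_t) ** gamma * log_p_t
    return total / len(logits)


# An uncertain pixel (logit 0) versus a confident, correct one (logit 4).
print(binary_focal_loss([0.0], [1]))  # 0.25 * 0.25 * ln 2, about 0.0433
print(binary_focal_loss([4.0], [1]))  # orders of magnitude smaller
```

Because confident pixels contribute almost nothing, averaged focal losses tend to be small numbers, which is why the tables express them in units of 10⁻³ and 10⁻².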

A Jupyter notebook that runs all the tests in the table is provided at tests/contrails_ablation_study.ipynb. It also serves as an example of how to use the package.

Sartorius' Cell Instance Segmentation Dataset

[Figure: Cells dataset example]

For this dataset, each batch contains 16 grayscale, 512×512-pixel images. UNetBox's model has depth=2 and expansion_layer=8.

DE SI SE EC BN CT   Focal Loss ± Std. Error (×10⁻²)   # Parameters
                    1.004 ± .053                      43,353
                    1.085 ± .016                      43,353
                    0.9321 ± .0690                    43,545
                    0.9892 ± .0272                    44,243
                    0.9593 ± .0115                    43,545
                    0.9438 ± .0915                    86,393
                    0.8707 ± .0168                    87,283
                    0.8745 ± .0649                    87,283
                    0.958 ± .0156                     86,393
                    0.8714 ± .0316                    87,283
                    0.8906 ± .0328                    53,949
                    0.8835 ± .0236                    53,949
                    0.8799 ± .0320                    96,797
                    0.8722 ± .0248                    54,839
                    0.8584 ± .0179                    97,687
                    0.9126 ± .0349                    53,949
                    0.8668 ± .0241                    97,687
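The "± Standard Error" entries summarize the spread over cross-validation folds. Assuming they are the standard error of the mean over the 3 folds (sample standard deviation divided by √n; the text does not spell out the formula), they can be computed as below. The per-fold losses are hypothetical and are not results from the study.

```python
import math
import statistics


def mean_and_standard_error(fold_losses):
    """Return the mean and the standard error of the mean over CV folds."""
    mean = statistics.mean(fold_losses)
    # Sample standard deviation (n - 1 denominator) divided by sqrt(n).
    sem = statistics.stdev(fold_losses) / math.sqrt(len(fold_losses))
    return mean, sem


# Hypothetical per-fold focal losses (x10^-2), NOT values from the study.
folds = [0.85, 0.87, 0.89]
mean, sem = mean_and_standard_error(folds)
print(f"{mean:.4f} ± {sem:.4f}")  # 0.8700 ± 0.0115
```

With only 3 folds the standard error is itself a rough estimate, so configurations whose intervals overlap are hard to tell apart.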

[Figure: Cells dataset ablation study]

Dependencies

UNetBox has been tested with the following dependencies:

  • Python >= 3.10
  • PyTorch >= 2.0
  • Torchvision >= 0.15
  • Timm >= 0.9
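A small script can check an environment against these minimum versions before importing UNetBox. This is a convenience sketch, not part of the repository: the helper names are hypothetical, the version parser is deliberately simple (it ignores local tags such as `+cu118` and pre-release suffixes), and `packaging.version.Version` is the more complete choice when the `packaging` library is installed.

```python
import re
import sys
from importlib import metadata

# Minimum versions listed in the Dependencies section (PyPI distribution names).
MINIMUM_VERSIONS = {"torch": (2, 0), "torchvision": (0, 15), "timm": (0, 9)}


def parse_version(version: str) -> tuple:
    """Turn '2.1.0+cu118' or '2.1.0rc1' into (2, 1, 0); suffixes are ignored."""
    parts = []
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_environment() -> list:
    """Return human-readable problems; an empty list means every check passed."""
    problems = []
    if sys.version_info < (3, 10):
        problems.append(f"Python {sys.version.split()[0]} < 3.10")
    for package, minimum in MINIMUM_VERSIONS.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        # Tuple comparison: (2, 1, 0) >= (2, 0) and (0, 8, 1) < (0, 9).
        if parse_version(installed) < minimum:
            problems.append(f"{package} {installed} < {'.'.join(map(str, minimum))}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```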

Usage

Clone the UNetBox repository with the following command:

git clone https://github.com/MicheleDamian/UNetBox.git $absolute_path_to_repo

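Make sure that sys.path contains the absolute path to the cloned repository, then import the model:

import sys

# Same absolute path passed to `git clone` above (replace with your own).
absolute_path_to_repo = '/absolute/path/to/UNetBox'
sys.path.insert(0, absolute_path_to_repo)

from unetbox.net import UNetBox

model = UNetBox()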
