feat: add model script, training configs and trained weights of cait #742

Open · wants to merge 1 commit into main
88 changes: 88 additions & 0 deletions configs/cait/README.md
@@ -0,0 +1,88 @@
# Going deeper with Image Transformers

> [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239)

## Introduction

CaiT builds on ViT and introduces two changes that improve model performance.
First, LayerScale applies a learnable per-channel scaling to the output of each residual block, which facilitates convergence when training deeper transformers.
Second, class-attention separates the layers that process the patch tokens from the layers that aggregate them into the class embedding, offering more effective processing of the class token.
By combining these two components, CaiT achieves state-of-the-art performance on the ImageNet-1K dataset.
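
As a rough illustration of the LayerScale idea (a minimal sketch based on the paper, not the implementation added in this PR; the class name, the `init_values=1e-5` default, and the usage pattern in the comments are assumptions), the residual branch output is multiplied by a learnable per-channel vector initialized close to zero:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, Parameter, Tensor


class LayerScale(nn.Cell):
    """Learnable per-channel scaling of a residual branch (CaiT LayerScale)."""

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        # Initialized near zero so each block starts close to the identity mapping.
        self.gamma = Parameter(Tensor(init_values * np.ones(dim), ms.float32))

    def construct(self, x):
        return self.gamma * x


# Schematic usage inside a transformer block:
#   x = x + layer_scale_1(attention(norm_1(x)))
#   x = x + layer_scale_2(mlp(norm_2(x)))
```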


## Results

Our reproduced model performance on ImageNet-1K is reported as follows.

<div align="center">

| Model | Context | Top-1 (%) | Top-5 (%) | Params(M) | Recipe | Download |
|----------------| -------- |----------|-----------|-----------|--------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
| cait_xxs24_224 | D910x8-G | 77.71 | 94.10 | 11.94 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_xxs24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_xxs24-31b307a8.ckpt) |
| cait_xs24_224 | D910x8-G | 81.29 | 95.60 | 26.53 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_xs24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_xs24-ba0c2053.ckpt) |
| cait_s24_224 | D910x8-G | 82.25 | 95.95 | 46.88 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_s24_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_s24-0a06be71.ckpt) |
| cait_s36_224 | D910x8-G | 82.11 | 95.84 | 68.16 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/cait/cait_s36_224.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/cait/cait_s36-2e42bfc8.ckpt) |


</div>

#### Notes

- Context: The training context, denoted as {device}x{pieces}-{MS mode}, where the MindSpore mode is either G (graph mode) or F (pynative mode with ms_function). For example, D910x8-G means training on 8 Ascend 910 NPUs in graph mode.
- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.

## Quick Start

### Preparation

#### Installation

Please refer to the [installation instruction](https://github.com/mindspore-lab/mindcv#installation) in MindCV.

#### Dataset Preparation

Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.

### Training

* Distributed Training

It is easy to reproduce the reported results with the pre-defined training recipe. For distributed training on multiple Ascend 910 devices, please run

```shell
# distributed training on multiple GPU/Ascend devices
mpirun -n 8 python train.py --config configs/cait/cait_xxs24_224.yaml --data_dir /path/to/imagenet
```
> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.

Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.

For detailed descriptions of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).

**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.
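
For instance, the `cait_s24_224` recipe uses `batch_size: 64` per device on 8 devices (global batch size 512) with `lr: 0.001`. A small sketch of the linear scaling rule (whether linear scaling is appropriate for your target batch size is an assumption to verify empirically):

```python
def scale_lr(base_lr: float, base_global_batch: int, new_global_batch: int) -> float:
    """Linearly rescale the learning rate to a new global batch size."""
    return base_lr * new_global_batch / base_global_batch


# Reference values from cait_s24_224.yaml trained on 8 devices (D910x8-G).
base_lr = 0.001
base_global_batch = 64 * 8   # batch_size x num_devices = 512

# Example: halving the per-device batch size to 32 on the same 8 devices.
new_global_batch = 32 * 8    # 256
print(scale_lr(base_lr, base_global_batch, new_global_batch))  # 0.0005
```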

* Standalone Training

If you want to train or finetune the model on a smaller dataset without distributed training, please run:

```shell
# standalone training on a CPU/GPU/Ascend device
python train.py --config configs/cait/cait_xxs24_224.yaml --data_dir /path/to/dataset --distribute False
```

### Validation

To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.

```shell
python validate.py -c configs/cait/cait_xxs24_224.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
```
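
The trained checkpoint can also be loaded directly in Python for quick sanity checks. Below is a minimal inference sketch; it assumes MindCV exposes a `create_model` factory for the registered `cait_xxs24_224` architecture and omits the resize/normalize preprocessing used during validation:

```python
import numpy as np
import mindspore as ms
from mindcv.models import create_model  # assumed MindCV factory API

# Build the network and load the trained weights from this recipe.
net = create_model("cait_xxs24_224", num_classes=1000)
param_dict = ms.load_checkpoint("/path/to/ckpt")
ms.load_param_into_net(net, param_dict)
net.set_train(False)

# Dummy 224x224 input; real inference must apply the validation transforms.
x = ms.Tensor(np.random.rand(1, 3, 224, 224), ms.float32)
logits = net(x)
print(logits.shape)  # (1, 1000)
```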

### Deployment

Please refer to the [deployment tutorial](https://mindspore-lab.github.io/mindcv/tutorials/deployment/).

## References

<!--- Guideline: Citation format should follow GB/T 7714. -->
[1] Touvron H, Cord M, Sablayrolles A, et al. Going deeper with image transformers[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 32-42.
61 changes: 61 additions & 0 deletions configs/cait/cait_s24_224.yaml
@@ -0,0 +1,61 @@
# system config
mode: 0
distribute: True
num_parallel_workers: 8
val_while_train: True

# dataset config
dataset: 'imagenet'
data_dir: '/path/to/imagenet'
shuffle: True
dataset_download: False
batch_size: 64
drop_remainder: True

# augmentation config
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
vflip: 0.
interpolation: 'bicubic'
auto_augment: 'randaug-m9-mstd0.5-inc1'
re_prob: 0.25
mixup: 0.8
cutmix: 1
color_jitter: 0.3
crop_pct: 1.0
ema: True
ema_decay: 0.99996

# model config
model: 'cait_s24_224'
num_classes: 1000
pretrained: False
ckpt_path: ''
keep_checkpoint_max: 10
ckpt_save_dir: '/cache/output/'
epoch_size: 400
dataset_sink_mode: True
amp_level: 'O2'
drop_path_rate: 0.1

# loss config
loss: 'CE'
label_smoothing: 0.1

# lr scheduler config
scheduler: 'warmup_cosine_decay'
lr: 0.001
min_lr: 0.000001
warmup_epochs: 30
decay_epochs: 370
num_cycles: 2


# optimizer config
opt: 'adamw'
weight_decay: 0.05
filter_bias_and_bn: True
loss_scale: 1024
use_nesterov: False
61 changes: 61 additions & 0 deletions configs/cait/cait_s36_224.yaml
@@ -0,0 +1,61 @@
# system config
mode: 0
distribute: True
num_parallel_workers: 8
val_while_train: True

# dataset config
dataset: 'imagenet'
data_dir: '/path/to/imagenet'
shuffle: True
dataset_download: False
batch_size: 64
drop_remainder: True

# augmentation config
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
vflip: 0.
interpolation: 'bicubic'
auto_augment: 'randaug-m9-mstd0.5-inc1'
re_prob: 0.25
mixup: 0.8
cutmix: 1
color_jitter: 0.3
crop_pct: 1.0
ema: True
ema_decay: 0.99996

# model config
model: 'cait_s36_224'
num_classes: 1000
pretrained: False
ckpt_path: ''
keep_checkpoint_max: 10
ckpt_save_dir: './ckpt'
epoch_size: 400
dataset_sink_mode: True
amp_level: 'O2'
drop_path_rate: 0.1

# loss config
loss: 'CE'
label_smoothing: 0.1

# lr scheduler config
scheduler: 'warmup_cosine_decay'
lr: 0.002
min_lr: 0.000001
warmup_epochs: 30
decay_epochs: 370
num_cycles: 2


# optimizer config
opt: 'adamw'
weight_decay: 0.05
filter_bias_and_bn: True
loss_scale: 1024
use_nesterov: False
61 changes: 61 additions & 0 deletions configs/cait/cait_xs24_224.yaml
@@ -0,0 +1,61 @@
# system config
mode: 0
distribute: True
num_parallel_workers: 8
val_while_train: True

# dataset config
dataset: 'imagenet'
data_dir: '/path/to/imagenet'
shuffle: True
dataset_download: False
batch_size: 64
drop_remainder: True

# augmentation config
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
vflip: 0.
interpolation: 'bicubic'
auto_augment: 'randaug-m9-mstd0.5-inc1'
re_prob: 0.25
mixup: 0.8
cutmix: 1
color_jitter: 0.3
crop_pct: 1.0
ema: True
ema_decay: 0.99996

# model config
model: 'cait_xs24_224'
num_classes: 1000
pretrained: False
ckpt_path: ''
keep_checkpoint_max: 10
ckpt_save_dir: './ckpt'
epoch_size: 400
dataset_sink_mode: True
amp_level: 'O2'
drop_path_rate: 0.1

# loss config
loss: 'CE'
label_smoothing: 0.1

# lr scheduler config
scheduler: 'warmup_cosine_decay'
lr: 0.001
min_lr: 0.000001
warmup_epochs: 40
decay_epochs: 360
num_cycles: 2


# optimizer config
opt: 'adamw'
weight_decay: 0.05
filter_bias_and_bn: True
loss_scale: 1024
use_nesterov: False
61 changes: 61 additions & 0 deletions configs/cait/cait_xxs24_224.yaml
@@ -0,0 +1,61 @@
# system config
mode: 0
distribute: True
num_parallel_workers: 8
val_while_train: True

# dataset config
dataset: 'imagenet'
data_dir: '/path/to/imagenet'
shuffle: True
dataset_download: False
batch_size: 128
drop_remainder: True

# augmentation config
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
vflip: 0.
interpolation: 'bicubic'
auto_augment: 'randaug-m9-mstd0.5-inc1'
re_prob: 0.25
mixup: 0.8
cutmix: 1
color_jitter: 0.3
crop_pct: 1.0
ema: True
ema_decay: 0.99996

# model config
model: 'cait_xxs24_224'
num_classes: 1000
pretrained: False
ckpt_path: ''
keep_checkpoint_max: 10
ckpt_save_dir: './ckpt'
epoch_size: 500
dataset_sink_mode: True
amp_level: 'O2'
drop_path_rate: 0.1

# loss config
loss: 'CE'
label_smoothing: 0.1

# lr scheduler config
scheduler: 'warmup_cosine_decay'
lr: 0.001
min_lr: 0.000001
warmup_epochs: 40
decay_epochs: 460
num_cycles: 2


# optimizer config
opt: 'adamw'
weight_decay: 0.025
filter_bias_and_bn: True
loss_scale: 1024
use_nesterov: False