Skip to content

[Feature] add model script, training recipe and training weights of TNT #507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions configs/tnt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

# TNT
> [Transformer in Transformer](https://arxiv.org/pdf/2103.00112.pdf)

## Introduction
![122160150-ff1bca80-cea1-11eb-9329-be5031bad78e](https://user-images.githubusercontent.com/41994229/224009923-02ad8d88-1cad-429e-b322-dc80660e8cbd.png)

Illustration of the proposed Transformer-iN-Transformer (TNT) framework. The inner
transformer block is shared in the same layer. The word position encodings are shared across visual
sentences.
## Results

**Implementation and configs for training were taken and adjusted from [this repository](https://gitee.com/cvisionlab/models/tree/tnt/release/research/cv/tnt), which implements tnt model in mindspore.**

Our reproduced model performance on ImageNet-1K is reported as follows.
<div align="center">

| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
|----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
| tnt_small | 8xRTX3090 | 74.14 | 92.07 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/tnt/tnt_s_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_s_patch16_224_ep138_acc_0.74.ckpt) |
| tnt_small | Converted from PyTorch | 72.51 | 90.68 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/tnt/tnt_s_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_s_converted_0.718.ckpt) |
| tnt_base | Converted from PyTorch | 79.62 | 94.81 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/tnt/tnt_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_b_converted_0.795.ckpt) |

</div>

#### Notes

- Context: The weights in the table were taken from [official repository](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch) and converted to mindspore format
- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.

## Quick Start

### Preparation

#### Installation
Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV.

#### Dataset Preparation
Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.

### Training

* Distributed Training


```shell
# distrubted training on multiple GPU/Ascend devices
mpirun -n 8 python train.py --config configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/imagenet --distributed True
```

> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.

Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.

For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).

**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.

* Standalone Training

If you want to train or finetune the model on a smaller dataset without distributed training, please run:

```shell
# standalone training on a CPU/GPU/Ascend device
python train.py --config configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/dataset --distribute False
```

### Validation

To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.

```shell
python validate.py -c configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
```

Or use '--pretrained' parameter to automatically download the checkpoint.

```shell
python validate.py -c configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/imagenet --pretrained
```

### Deployment

Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV.

## References

Paper - https://arxiv.org/pdf/2103.00112.pdf

Official PyTorch implementation - https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch

Official Mindspore implementation - https://gitee.com/cvisionlab/models/tree/tnt/release/research/cv/tnt
68 changes: 68 additions & 0 deletions configs/tnt/tnt_b_gpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# system
mode: 0
distribute: False
num_parallel_workers: 1
val_while_train: True

# dataset
dataset: 'imagenet'
data_dir: 'path/to/imagenet/'
shuffle: True
dataset_download: False
batch_size: 16
drop_remainder: True
val_split: val
train_split: val

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
auto_augment: 'randaug-m9-mstd0.1-inc1'
interpolation: bicubic
re_prob: 0.25
re_value: 'random'
cutmix: 1.0
mixup: 0.8
mixup_prob: 1.
mixup_mode: batch
switch_prob: 0.5
crop_pct: 0.9

# model
model: 'tnt_base'
num_classes: 1000
pretrained: False
ckpt_path: ''

keep_checkpoints_max: 10
ckpt_save_dir: './ckpt'

epoch_size: 300
dataset_sink_mode: True
amp_level: 'O0'
ema: False
clip_grad: True
clip_value: 5.0

drop_rate: 0.
drop_path_rate: 0.1

# loss
loss: 'CE'
label_smoothing: 0.1

# lr scheduler
lr_scheduler: 'cosine_decay'
lr: 0.0005
warmup_epochs: 20
warmup_factor: 0.00014
min_lr: 0.000006

# optimizer
opt: 'adamw'
momentum: 0.9
weight_decay: 0.05
dynamic_loss_scale: True
eps: 1e-8
68 changes: 68 additions & 0 deletions configs/tnt/tnt_s_gpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# system
mode: 0
distribute: False
num_parallel_workers: 1
val_while_train: True

# dataset
dataset: 'imagenet'
data_dir: 'path/to/imagenet/'
shuffle: True
dataset_download: False
batch_size: 32
drop_remainder: True
val_split: val
train_split: val

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
auto_augment: 'randaug-m9-mstd0.1-inc1'
interpolation: bicubic
re_prob: 0.25
re_value: 'random'
cutmix: 1.0
mixup: 0.8
mixup_prob: 1.
mixup_mode: batch
switch_prob: 0.5
crop_pct: 0.9

# model
model: 'tnt_small'
num_classes: 1000
pretrained: False
ckpt_path: ''

keep_checkpoints_max: 10
ckpt_save_dir: './ckpt'

epoch_size: 300
dataset_sink_mode: True
amp_level: 'O0'
ema: False
clip_grad: True
clip_value: 5.0

drop_rate: 0.
drop_path_rate: 0.1

# loss
loss: 'CE'
label_smoothing: 0.1

# lr scheduler
lr_scheduler: 'cosine_decay'
lr: 0.0005
warmup_epochs: 20
warmup_factor: 0.00014
min_lr: 0.000006

# optimizer
opt: 'adamw'
momentum: 0.9
weight_decay: 0.05
dynamic_loss_scale: True
eps: 1e-8
3 changes: 3 additions & 0 deletions mindcv/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
sknet,
squeezenet,
swin_transformer,
tnt,
vgg,
visformer,
vit,
Expand Down Expand Up @@ -79,6 +80,7 @@
from .sknet import *
from .squeezenet import *
from .swin_transformer import *
from .tnt import *
from .utils import *
from .vgg import *
from .visformer import *
Expand Down Expand Up @@ -125,6 +127,7 @@
__all__.extend(sknet.__all__)
__all__.extend(squeezenet.__all__)
__all__.extend(swin_transformer.__all__)
__all__.extend(tnt.__all__)
__all__.extend(vgg.__all__)
__all__.extend(visformer.__all__)
__all__.extend(vit.__all__)
Expand Down
Loading