diff --git a/configs/tnt/README.md b/configs/tnt/README.md
new file mode 100644
index 000000000..f4756e9ae
--- /dev/null
+++ b/configs/tnt/README.md
@@ -0,0 +1,92 @@
+
+# TNT
+> [Transformer in Transformer](https://arxiv.org/pdf/2103.00112.pdf)
+
+## Introduction
+
+
+Illustration of the proposed Transformer-iN-Transformer (TNT) framework. The inner
+transformer block is shared in the same layer. The word position encodings are shared across visual
+sentences.
+## Results
+
+**Implementation and configs for training were taken and adjusted from [this repository](https://gitee.com/cvisionlab/models/tree/tnt/release/research/cv/tnt), which implements tnt model in mindspore.**
+
+Our reproduced model performance on ImageNet-1K is reported as follows.
+
+
+| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
+|----------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------|
+| tnt_small | 8xRTX3090 | 74.14 | 92.07 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/tnt/tnt_s_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_s_patch16_224_ep138_acc_0.74.ckpt) |
+| tnt_small | Converted from PyTorch | 72.51 | 90.68 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/tnt/tnt_s_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_s_converted_0.718.ckpt) |
+| tnt_base | Converted from PyTorch | 79.62 | 94.81 | - | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/tnt/tnt_b_gpu.yaml) | [weights](https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_b_converted_0.795.ckpt) |
+
+
+
+#### Notes
+
+- Context: The weights in the table were taken from [official repository](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch) and converted to mindspore format
+- Top-1 and Top-5: Accuracy reported on the validation set of ImageNet-1K.
+
+## Quick Start
+
+### Preparation
+
+#### Installation
+Please refer to the [installation instruction](https://github.com/mindspore-ecosystem/mindcv#installation) in MindCV.
+
+#### Dataset Preparation
+Please download the [ImageNet-1K](https://www.image-net.org/challenges/LSVRC/2012/index.php) dataset for model training and validation.
+
+### Training
+
+* Distributed Training
+
+
+```shell
+# distrubted training on multiple GPU/Ascend devices
+mpirun -n 8 python train.py --config configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/imagenet --distributed True
+```
+
+> If the script is executed by the root user, the `--allow-run-as-root` parameter must be added to `mpirun`.
+
+Similarly, you can train the model on multiple GPU devices with the above `mpirun` command.
+
+For detailed illustration of all hyper-parameters, please refer to [config.py](https://github.com/mindspore-lab/mindcv/blob/main/config.py).
+
+**Note:** As the global batch size (batch_size x num_devices) is an important hyper-parameter, it is recommended to keep the global batch size unchanged for reproduction or adjust the learning rate linearly to a new global batch size.
+
+* Standalone Training
+
+If you want to train or finetune the model on a smaller dataset without distributed training, please run:
+
+```shell
+# standalone training on a CPU/GPU/Ascend device
+python train.py --config configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/dataset --distribute False
+```
+
+### Validation
+
+To validate the accuracy of the trained model, you can use `validate.py` and parse the checkpoint path with `--ckpt_path`.
+
+```shell
+python validate.py -c configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/imagenet --ckpt_path /path/to/ckpt
+```
+
+Or use '--pretrained' parameter to automatically download the checkpoint.
+
+```shell
+python validate.py -c configs/tnt/tnt_s_gpu.yaml --data_dir /path/to/imagenet --pretrained
+```
+
+### Deployment
+
+Please refer to the [deployment tutorial](https://github.com/mindspore-lab/mindcv/blob/main/tutorials/deployment.md) in MindCV.
+
+## References
+
+Paper - https://arxiv.org/pdf/2103.00112.pdf
+
+Official PyTorch implementation - https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
+
+Official Mindspore implementation - https://gitee.com/cvisionlab/models/tree/tnt/release/research/cv/tnt
diff --git a/configs/tnt/tnt_b_gpu.yaml b/configs/tnt/tnt_b_gpu.yaml
new file mode 100644
index 000000000..fd72f4934
--- /dev/null
+++ b/configs/tnt/tnt_b_gpu.yaml
@@ -0,0 +1,68 @@
+# system
+mode: 0
+distribute: False
+num_parallel_workers: 1
+val_while_train: True
+
+# dataset
+dataset: 'imagenet'
+data_dir: 'path/to/imagenet/'
+shuffle: True
+dataset_download: False
+batch_size: 16
+drop_remainder: True
+val_split: val
+train_split: val
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+auto_augment: 'randaug-m9-mstd0.1-inc1'
+interpolation: bicubic
+re_prob: 0.25
+re_value: 'random'
+cutmix: 1.0
+mixup: 0.8
+mixup_prob: 1.
+mixup_mode: batch
+switch_prob: 0.5
+crop_pct: 0.9
+
+# model
+model: 'tnt_base'
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+
+keep_checkpoints_max: 10
+ckpt_save_dir: './ckpt'
+
+epoch_size: 300
+dataset_sink_mode: True
+amp_level: 'O0'
+ema: False
+clip_grad: True
+clip_value: 5.0
+
+drop_rate: 0.
+drop_path_rate: 0.1
+
+# loss
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler
+lr_scheduler: 'cosine_decay'
+lr: 0.0005
+warmup_epochs: 20
+warmup_factor: 0.00014
+min_lr: 0.000006
+
+# optimizer
+opt: 'adamw'
+momentum: 0.9
+weight_decay: 0.05
+dynamic_loss_scale: True
+eps: 1e-8
diff --git a/configs/tnt/tnt_s_gpu.yaml b/configs/tnt/tnt_s_gpu.yaml
new file mode 100644
index 000000000..e2e44b815
--- /dev/null
+++ b/configs/tnt/tnt_s_gpu.yaml
@@ -0,0 +1,68 @@
+# system
+mode: 0
+distribute: False
+num_parallel_workers: 1
+val_while_train: True
+
+# dataset
+dataset: 'imagenet'
+data_dir: 'path/to/imagenet/'
+shuffle: True
+dataset_download: False
+batch_size: 32
+drop_remainder: True
+val_split: val
+train_split: val
+
+# augmentation
+image_resize: 224
+scale: [0.08, 1.0]
+ratio: [0.75, 1.333]
+hflip: 0.5
+auto_augment: 'randaug-m9-mstd0.1-inc1'
+interpolation: bicubic
+re_prob: 0.25
+re_value: 'random'
+cutmix: 1.0
+mixup: 0.8
+mixup_prob: 1.
+mixup_mode: batch
+switch_prob: 0.5
+crop_pct: 0.9
+
+# model
+model: 'tnt_small'
+num_classes: 1000
+pretrained: False
+ckpt_path: ''
+
+keep_checkpoints_max: 10
+ckpt_save_dir: './ckpt'
+
+epoch_size: 300
+dataset_sink_mode: True
+amp_level: 'O0'
+ema: False
+clip_grad: True
+clip_value: 5.0
+
+drop_rate: 0.
+drop_path_rate: 0.1
+
+# loss
+loss: 'CE'
+label_smoothing: 0.1
+
+# lr scheduler
+lr_scheduler: 'cosine_decay'
+lr: 0.0005
+warmup_epochs: 20
+warmup_factor: 0.00014
+min_lr: 0.000006
+
+# optimizer
+opt: 'adamw'
+momentum: 0.9
+weight_decay: 0.05
+dynamic_loss_scale: True
+eps: 1e-8
diff --git a/mindcv/models/__init__.py b/mindcv/models/__init__.py
index d0521efff..7f075eed7 100644
--- a/mindcv/models/__init__.py
+++ b/mindcv/models/__init__.py
@@ -37,6 +37,7 @@
sknet,
squeezenet,
swin_transformer,
+ tnt,
vgg,
visformer,
vit,
@@ -79,6 +80,7 @@
from .sknet import *
from .squeezenet import *
from .swin_transformer import *
+from .tnt import *
from .utils import *
from .vgg import *
from .visformer import *
@@ -125,6 +127,7 @@
__all__.extend(sknet.__all__)
__all__.extend(squeezenet.__all__)
__all__.extend(swin_transformer.__all__)
+__all__.extend(tnt.__all__)
__all__.extend(vgg.__all__)
__all__.extend(visformer.__all__)
__all__.extend(vit.__all__)
diff --git a/mindcv/models/tnt.py b/mindcv/models/tnt.py
new file mode 100644
index 000000000..d01368bed
--- /dev/null
+++ b/mindcv/models/tnt.py
@@ -0,0 +1,574 @@
+"""Transformer in Transformer (TNT)"""
+import math
+
+import numpy as np
+from scipy.stats import truncnorm
+
+import mindspore.common.initializer as weight_init
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+from mindspore import Parameter, Tensor
+from mindspore import dtype as mstype
+from mindspore import ops
+
+from .registry import register_model
+from .utils import _ntuple, load_pretrained, make_divisible
+
+__all__ = [
+ "tnt_small",
+ "tnt_base"
+]
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000,
+ 'first_conv': 'patch_embed.proj',
+ 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ "tnt_small": _cfg(
+ url="https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_s_patch16_224_ep138_acc_0.74.ckpt"),
+ "tnt_base": _cfg(url="https://storage.googleapis.com/huawei-mindspore-hk/TNT/tnt_b_converted_0.795.ckpt")
+}
+
+
+class DropPath(nn.Cell):
+ """
+ Drop paths (Stochastic Depth) per sample
+ (when applied in main path of residual blocks).
+
+ Args:
+ drop_prob(float): Probability of dropout
+ ndim(int): Number of dimensions in input tensor
+
+ Returns:
+ Tensor: Output tensor after dropout
+ """
+
+ def __init__(self, drop_prob, ndim):
+ super(DropPath, self).__init__()
+ self.drop = nn.Dropout(keep_prob=1 - drop_prob)
+ shape = (1,) + (1,) * (ndim + 1)
+ self.ndim = ndim
+ self.mask = Tensor(np.ones(shape), dtype=mstype.float32)
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ if not self.training:
+ return x
+ mask = ops.Tile()(self.mask, (x.shape[0],) + (1,) * (self.ndim + 1))
+ out = self.drop(mask)
+ out = out * x
+ return out
+
+
+class DropPath1D(DropPath):
+ """DropPath1D"""
+
+ def __init__(self, drop_prob):
+ super(DropPath1D, self).__init__(drop_prob=drop_prob, ndim=1)
+
+
+def trunc_array(shape, sigma=0.02):
+ """output truncnormal array in shape"""
+ return truncnorm.rvs(-2, 2, loc=0, scale=sigma, size=shape, random_state=None)
+
+
+to_2tuple = _ntuple(2)
+
+
+class UnfoldKernelEqPatch(nn.Cell):
+ """
+ UnfoldKernelEqPatch with better performance
+
+ Args:
+ kernel_size(tuple): kernel size (along each side)
+ strides(tuple): Stride (along each side)
+
+ Returns:
+ Tensor, output tensor
+ """
+
+ def __init__(self, kernel_size, strides):
+ super(UnfoldKernelEqPatch, self).__init__()
+ assert kernel_size == strides
+ self.kernel_size = kernel_size
+ self.reshape = P.Reshape()
+ self.transpose = P.Transpose()
+
+ def construct(self, *inputs, **kwargs):
+ inputs = inputs[0]
+ b, c, h, w = inputs.shape
+ inputs = self.reshape(inputs,
+ (b, c, h // self.kernel_size[0], self.kernel_size[0], w))
+ inputs = self.transpose(inputs, (0, 2, 1, 3, 4))
+ inputs = self.reshape(inputs, (-1, c, self.kernel_size[0], w // self.kernel_size[1], self.kernel_size[1]))
+ inputs = self.transpose(inputs, (0, 3, 1, 2, 4))
+ inputs = self.reshape(inputs, (-1, c, self.kernel_size[0], self.kernel_size[1]))
+ # inputs = self.reshape(
+ # inputs,
+ # (B, C,
+ # H // self.kernel_size[0], self.kernel_size[0],
+ # W // self.kernel_size[1], self.kernel_size[1])
+ # )
+ # inputs = self.transpose(inputs, )
+
+ return inputs
+
+
+class PatchEmbed(nn.Cell):
+ """
+ Image to Visual Word Embedding
+
+ Args:
+ img_size(int): Image size (side, px)
+ patch_size(int): Output patch size (side, px)
+ in_chans(int): Number of input channels
+ outer_dim(int): Number of output features (not used)
+ inner_dim(int): Number of internal features
+ inner_stride(int): Stride of patches (px)
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3,
+ outer_dim=768, inner_dim=24, inner_stride=4):
+ super().__init__()
+ _ = outer_dim
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.inner_dim = inner_dim
+ self.num_words = math.ceil(patch_size[0] / inner_stride) * math.ceil(patch_size[1] / inner_stride)
+
+ self.unfold = UnfoldKernelEqPatch(kernel_size=patch_size, strides=patch_size)
+ # unfold_shape = [1, *patch_size, 1]
+ # self.unfold = nn.Unfold(unfold_shape, unfold_shape, unfold_shape)
+ self.proj = nn.Conv2d(in_channels=in_chans, out_channels=inner_dim, kernel_size=7, stride=inner_stride,
+ pad_mode='pad', padding=3, has_bias=True)
+
+ self.reshape = P.Reshape()
+ self.transpose = P.Transpose()
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ b, _ = x.shape[0], x.shape[1]
+ x = self.unfold(x) # B, Ck2, N
+ x = self.proj(x) # B*N, C, 8, 8
+ x = self.reshape(x, (b * self.num_patches, self.inner_dim, -1,)) # B*N, 8*8, C
+ x = self.transpose(x, (0, 2, 1))
+ return x
+
+
+class Attention(nn.Cell):
+ """
+ Attention layer
+
+ Args:
+ dim(int): Number of output features
+ hidden_dim(int): Number of hidden features
+ num_heads(int): Number of output heads
+ qkv_bias(bool): Enable bias weights in Qk / v dense layers
+ qk_scale(float): Qk scale (multiplier)
+ attn_drop(float): Attention dropout rate
+ proj_drop(float): Projection dropout rate
+ """
+
+ def __init__(self, dim, hidden_dim,
+ num_heads=8, qkv_bias=False, qk_scale=None,
+ attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ head_dim = hidden_dim // num_heads
+ self.head_dim = head_dim
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qk = nn.Dense(in_channels=dim, out_channels=hidden_dim * 2, has_bias=qkv_bias)
+ # self.q = nn.Dense(in_channels=dim, out_channels=hidden_dim, has_bias=qkv_bias)
+ # self.k = nn.Dense(in_channels=dim, out_channels=hidden_dim, has_bias=qkv_bias)
+ self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias)
+ self.attn_drop = nn.Dropout(keep_prob=1.0 - attn_drop)
+ self.proj = nn.Dense(in_channels=dim, out_channels=dim, has_bias=True)
+ self.proj_drop = nn.Dropout(keep_prob=1.0 - proj_drop)
+ self.softmax = nn.Softmax(axis=-1)
+ self.matmul = P.BatchMatMul()
+
+ self.reshape = P.Reshape()
+ self.transpose = P.Transpose()
+
+ def construct(self, *inputs, **kwargs):
+ """Attention construct"""
+ x = inputs[0]
+ b, n, _ = x.shape
+ qk = self.reshape(self.qk(x),
+ (b, n, 2, self.num_heads, self.head_dim))
+ qk = self.transpose(qk, (2, 0, 3, 1, 4))
+ q, k = qk[0], qk[1]
+
+ v = self.reshape(self.v(x),
+ (b, n, self.num_heads, -1))
+ v = self.transpose(v, (0, 2, 1, 3))
+
+ attn = self.matmul(q, self.transpose(k, (0, 1, 3, 2))
+ ) * self.scale
+ attn = self.softmax(attn)
+ attn = self.attn_drop(attn)
+
+ x = self.transpose(self.matmul(attn, v), (0, 2, 1, 3))
+ x = self.reshape(x, (b, n, -1))
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Mlp(nn.Cell):
+ """
+ Multi-layer perceptron
+
+ Args:
+ in_features(int): Number of input features
+ hidden_features(int): Number of hidden features
+ out_features(int): Number of output features
+ act_layer(class): Activation layer (base class)
+ drop(float): Dropout rate
+ """
+
+ def __init__(self, in_features, hidden_features=None,
+ out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Dense(in_channels=in_features, out_channels=hidden_features, has_bias=True)
+ self.act = act_layer()
+ self.fc2 = nn.Dense(in_channels=hidden_features, out_channels=out_features, has_bias=True)
+ self.drop = nn.Dropout(keep_prob=1.0 - drop) # if drop > 0. else Identity()
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SE(nn.Cell):
+ """SE Block"""
+
+ def __init__(self, dim, hidden_ratio=None):
+ super().__init__()
+ hidden_ratio = hidden_ratio or 1
+ self.dim = dim
+ hidden_dim = int(dim * hidden_ratio)
+ self.fc = nn.SequentialCell([
+ nn.LayerNorm(normalized_shape=dim, epsilon=1e-5),
+ nn.Dense(in_channels=dim, out_channels=hidden_dim),
+ nn.ReLU(),
+ nn.Dense(in_channels=hidden_dim, out_channels=dim),
+ nn.Tanh()
+ ])
+
+ self.reduce_mean = P.ReduceMean()
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ a = self.reduce_mean(True, x, 1) # B, 1, C
+ a = self.fc(a)
+ x = a * x
+ return x
+
+
+class Block(nn.Cell):
+ """
+ TNT base block
+
+ Args:
+ outer_dim(int): Number of output features
+ inner_dim(int): Number of internal features
+ outer_num_heads(int): Number of output heads
+ inner_num_heads(int): Number of internal heads
+ num_words(int): Number of 'visual words' (feature groups)
+ mlp_ratio(float): Rate of MLP per hidden features
+ qkv_bias(bool): Use Qk / v bias
+ qk_scale(float): Qk scale
+ drop(float): Dropout rate
+ attn_drop(float): Dropout rate of attention layer
+ drop_path(float): Path dropout rate
+ act_layer(class): Activation layer (class)
+ norm_layer(class): Normalization layer
+ se(int): SE parameter
+ """
+
+ def __init__(self, outer_dim, inner_dim, outer_num_heads,
+ inner_num_heads, num_words, mlp_ratio=4.,
+ qkv_bias=False, qk_scale=None, drop=0.,
+ attn_drop=0., drop_path=0., act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm, se=0):
+ super().__init__()
+ self.has_inner = inner_dim > 0
+ if self.has_inner:
+ # Inner
+ self.inner_norm1 = norm_layer((inner_dim,), epsilon=1e-5)
+ self.inner_attn = Attention(
+ inner_dim, inner_dim, num_heads=inner_num_heads, qkv_bias=qkv_bias,
+ qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ self.inner_norm2 = norm_layer((inner_dim,), epsilon=1e-5)
+ self.inner_mlp = Mlp(in_features=inner_dim, hidden_features=int(inner_dim * mlp_ratio),
+ out_features=inner_dim, act_layer=act_layer, drop=drop)
+
+ self.proj_norm1 = norm_layer((num_words * inner_dim,), epsilon=1e-5)
+ self.proj = nn.Dense(in_channels=num_words * inner_dim, out_channels=outer_dim, has_bias=False)
+ self.proj_norm2 = norm_layer((outer_dim,), epsilon=1e-5)
+ # Outer
+ self.outer_norm1 = norm_layer((outer_dim,), epsilon=1e-5)
+ self.outer_attn = Attention(
+ outer_dim, outer_dim, num_heads=outer_num_heads, qkv_bias=qkv_bias,
+ qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath1D(drop_path)
+ self.outer_norm2 = norm_layer((outer_dim,), epsilon=1e-5)
+ self.outer_mlp = Mlp(in_features=outer_dim, hidden_features=int(outer_dim * mlp_ratio),
+ out_features=outer_dim, act_layer=act_layer, drop=drop)
+ # SE
+ self.se = se
+ self.se_layer = 0
+ if self.se > 0:
+ self.se_layer = SE(outer_dim, 0.25)
+ self.zeros = Tensor(np.zeros([1, 1, 1]), dtype=mstype.float32)
+
+ self.reshape = P.Reshape()
+ self.cast = P.Cast()
+
+ def construct(self, *inputs, **kwargs):
+ """TNT Block construct"""
+
+ inner_tokens, outer_tokens = inputs[0], inputs[1]
+ if self.has_inner:
+ in1 = self.inner_norm1(inner_tokens)
+ attn1 = self.inner_attn(in1)
+ inner_tokens = inner_tokens + self.drop_path(attn1) # B*N, k*k, c
+ in2 = self.inner_norm2(inner_tokens)
+ mlp = self.inner_mlp(in2)
+ inner_tokens = inner_tokens + self.drop_path(mlp) # B*N, k*k, c
+ b, n, _ = P.Shape()(outer_tokens)
+ # zeros = P.Tile()(self.zeros, (B, 1, C))
+ proj = self.proj_norm2(self.proj(self.proj_norm1(
+ self.reshape(inner_tokens, (b, n - 1, -1,))
+ )))
+ proj = self.cast(proj, mstype.float32)
+ # proj = P.Concat(1)((zeros, proj))
+ # outer_tokens = outer_tokens + proj # B, N, C
+ outer_tokens[:, 1:] = outer_tokens[:, 1:] + proj
+ if self.se > 0:
+ outer_tokens = outer_tokens + self.drop_path(
+ self.outer_attn(self.outer_norm1(outer_tokens)))
+ tmp_ = self.outer_mlp(self.outer_norm2(outer_tokens))
+ outer_tokens = outer_tokens + self.drop_path(
+ tmp_ + self.se_layer(tmp_))
+ else:
+ outer_tokens = outer_tokens + self.drop_path(
+ self.outer_attn(self.outer_norm1(outer_tokens)))
+ outer_tokens = outer_tokens + self.drop_path(
+ self.outer_mlp(self.outer_norm2(outer_tokens)))
+ return inner_tokens, outer_tokens
+
+
+class TNT(nn.Cell):
+ """
+ TNT (Transformer in Transformer) for computer vision
+
+ Args:
+ img_size(int): Image size (side, px)
+ patch_size(int): Patch size (side, px)
+ in_chans(int): Number of input channels
+ num_classes(int): Number of output classes
+ outer_dim(int): Number of output features
+ inner_dim(int): Number of internal features
+ depth(int): Number of TNT base blocks
+ outer_num_heads(int): Number of output heads
+ inner_num_heads(int): Number of internal heads
+ mlp_ratio(float): Rate of MLP per hidden features
+ qkv_bias(bool): Use Qk / v bias
+ qk_scale(float): Qk scale
+ drop_rate(float): Dropout rate
+ attn_drop_rate(float): Dropout rate for attention layer
+ drop_path_rate(float): Dropout rate for DropPath layer
+ norm_layer(class): Normalization layer
+ inner_stride(int): Number of strides for internal patches
+ se(int): SE parameter
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3,
+ num_classes=1000, outer_dim=768, inner_dim=48,
+ depth=12, outer_num_heads=12, inner_num_heads=4,
+ mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+ norm_layer=nn.LayerNorm, inner_stride=4, se=0,
+ **kwargs):
+ super().__init__()
+ _ = kwargs
+ self.num_classes = num_classes
+ self.outer_dim = outer_dim
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, outer_dim=outer_dim,
+ inner_dim=inner_dim, inner_stride=inner_stride)
+ self.num_patches = num_patches = self.patch_embed.num_patches
+ num_words = self.patch_embed.num_words
+
+ self.proj_norm1 = norm_layer((num_words * inner_dim,), epsilon=1e-5)
+ self.proj = nn.Dense(in_channels=num_words * inner_dim, out_channels=outer_dim, has_bias=True)
+ self.proj_norm2 = norm_layer((outer_dim,), epsilon=1e-5)
+
+ self.cls_token = Parameter(Tensor(trunc_array([1, 1, outer_dim]), dtype=mstype.float32), name="cls_token",
+ requires_grad=True)
+ self.outer_pos = Parameter(Tensor(trunc_array([1, num_patches + 1, outer_dim]), dtype=mstype.float32),
+ name="outer_pos")
+ self.inner_pos = Parameter(Tensor(trunc_array([1, num_words, inner_dim]), dtype=mstype.float32))
+ self.pos_drop = nn.Dropout(keep_prob=1.0 - drop_rate)
+
+ dpr = [x for x in np.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ vanilla_idxs = []
+ blocks = []
+ for i in range(depth):
+ if i in vanilla_idxs:
+ blocks.append(Block(
+ outer_dim=outer_dim, inner_dim=-1, outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads,
+ num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
+ attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, se=se))
+ else:
+ blocks.append(Block(
+ outer_dim=outer_dim, inner_dim=inner_dim, outer_num_heads=outer_num_heads,
+ inner_num_heads=inner_num_heads,
+ num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
+ attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, se=se))
+ self.blocks = nn.CellList(blocks)
+ # self.norm = norm_layer(outer_dim, eps=1e-5)
+ self.norm = norm_layer((outer_dim,))
+
+ # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+ # self.repr = nn.Linear(outer_dim, representation_size)
+ # self.repr_act = nn.Tanh()
+
+ # Classifier head
+ mask = np.zeros([1, num_patches + 1, 1])
+ mask[:, 0] = 1
+ self.mask = Tensor(mask, dtype=mstype.float32)
+ self.head = nn.Dense(in_channels=outer_dim, out_channels=num_classes, has_bias=True)
+
+ self.reshape = P.Reshape()
+ self.concat = P.Concat(1)
+ self.tile = P.Tile()
+ self.cast = P.Cast()
+
+ self.init_weights()
+ print("================================success================================")
+
+ def init_weights(self):
+ """init_weights"""
+ for _, cell in self.cells_and_names():
+ if isinstance(cell, nn.Dense):
+ cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02),
+ cell.weight.shape,
+ cell.weight.dtype))
+ if isinstance(cell, nn.Dense) and cell.bias is not None:
+ cell.bias.set_data(weight_init.initializer(weight_init.Zero(),
+ cell.bias.shape,
+ cell.bias.dtype))
+ elif isinstance(cell, nn.LayerNorm):
+ cell.gamma.set_data(weight_init.initializer(weight_init.One(),
+ cell.gamma.shape,
+ cell.gamma.dtype))
+ cell.beta.set_data(weight_init.initializer(weight_init.Zero(),
+ cell.beta.shape,
+ cell.beta.dtype))
+
+ def forward_features(self, x):
+ """TNT forward_features"""
+ b = x.shape[0]
+ inner_tokens = self.patch_embed(x) + self.inner_pos # B*N, 8*8, C
+
+ outer_tokens = self.proj_norm2(
+ self.proj(self.proj_norm1(
+ self.reshape(inner_tokens, (b, self.num_patches, -1,))
+ ))
+ )
+ outer_tokens = self.cast(outer_tokens, mstype.float32)
+ outer_tokens = self.concat((
+ self.tile(self.cls_token, (b, 1, 1)), outer_tokens
+ ))
+
+ outer_tokens = outer_tokens + self.outer_pos
+ outer_tokens = self.pos_drop(outer_tokens)
+
+ for blk in self.blocks:
+ inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens)
+
+ outer_tokens = self.norm(outer_tokens) # [batch_size, num_patch+1, outer_dim)
+ return outer_tokens[:, 0]
+
+ def construct(self, *inputs, **kwargs):
+ x = inputs[0]
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+@register_model
+def tnt_small(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+ """tnt_s_patch16_224"""
+
+ patch_size = 16
+ inner_stride = 4
+ outer_dim = 384
+ inner_dim = 24
+ outer_num_heads = 6
+ inner_num_heads = 4
+ depth = 12
+ num_classes = num_classes
+ outer_dim = make_divisible(outer_dim, outer_num_heads)
+ inner_dim = make_divisible(inner_dim, inner_num_heads)
+ model = TNT(patch_size=patch_size, in_chans=in_channels, num_classes=num_classes,
+ outer_dim=outer_dim, inner_dim=inner_dim, depth=depth,
+ outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads, qkv_bias=False,
+ inner_stride=inner_stride, **kwargs)
+ default_cfg = default_cfgs["tnt_small"]
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model
+
+
+@register_model
+def tnt_base(pretrained: bool = False, num_classes=1000, in_channels=3, **kwargs):
+ """tnt_b_patch16_224"""
+
+ patch_size = 16
+ inner_stride = 4
+ outer_dim = 640
+ inner_dim = 40
+ outer_num_heads = 10
+ inner_num_heads = 4
+ depth = 12
+ num_classes = num_classes
+ outer_dim = make_divisible(outer_dim, outer_num_heads)
+ inner_dim = make_divisible(inner_dim, inner_num_heads)
+ model = TNT(patch_size=patch_size, in_chans=in_channels, num_classes=num_classes,
+ outer_dim=outer_dim, inner_dim=inner_dim, depth=depth,
+ outer_num_heads=outer_num_heads, inner_num_heads=inner_num_heads, qkv_bias=False,
+ inner_stride=inner_stride, **kwargs)
+ default_cfg = default_cfgs["tnt_base"]
+
+ if pretrained:
+ load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
+
+ return model