
Commit 5fd01e2

Authored by sageyou and SamitHuang

extend vit and add mae model and finetune checkpoint file (#707)

* Refactor ViT to support relative positional embedding and layer scale; checkpoint updated
* fix format
* undo attention
* add MAE model and finetune checkpoint file
* extend ViT and MAE

Co-authored-by: samithuang <285365963@qq.com>
1 parent d714673 commit 5fd01e2

File tree

9 files changed: +969 −778 lines


configs/vit/README.md

Lines changed: 3 additions & 3 deletions
@@ -36,9 +36,9 @@ Our reproduced model performance on ImageNet-1K is reported as follows.
 
 | Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
 |--------------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------|
-| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-7553218f.ckpt) |
-| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-f02b2487.ckpt) |
-| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-3a961018.ckpt) |
+| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt) |
+| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt) |
+| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt) |
 
 </div>
 
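These rows only swap the checkpoint hashes in the download URLs; loading a pretrained ViT is unchanged. For reference, a minimal sketch (assuming mindcv is installed; `create_model` fetches the registered default checkpoint when `pretrained=True`):

```python
# Minimal sketch: load one of the re-trained checkpoints listed above.
import mindcv

model = mindcv.create_model("vit_b_32_224", pretrained=True)
model.set_train(False)  # inference mode
```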

mindcv/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,7 @@
     inceptionv3,
     inceptionv4,
     layers,
+    mae,
     mixnet,
     mlpmixer,
     mnasnet,
@@ -74,6 +75,7 @@
 from .inceptionv3 import *
 from .inceptionv4 import *
 from .layers import *
+from .mae import *
 from .mixnet import *
 from .mlpmixer import *
 from .mnasnet import *
@@ -132,6 +134,7 @@
 __all__.extend(["InceptionV3", "inception_v3"])
 __all__.extend(["InceptionV4", "inception_v4"])
 __all__.extend(layers.__all__)
+__all__.extend(mae.__all__)
 __all__.extend(mixnet.__all__)
 __all__.extend(mlpmixer.__all__)
 __all__.extend(mnasnet.__all__)
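With `mae` wired into the package exports, the new models become discoverable through the model registry. A hedged sketch (the exact registered MAE model names live in `mindcv/models/mae.py`, which is not shown in this section, so the `"mae*"` pattern below is an assumption):

```python
# Sketch: list the newly registered MAE models via the mindcv registry.
# The "mae*" name pattern is an assumption; see mindcv/models/mae.py for
# the exact registered names.
import mindcv

print(mindcv.list_models("mae*"))
```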

mindcv/models/layers/__init__.py

Lines changed: 15 additions & 1 deletion
@@ -1,10 +1,24 @@
 """layers init"""
-from . import activation, conv_norm_act, drop_path, identity, pooling, selective_kernel, squeeze_excite
+from . import (
+    activation,
+    conv_norm_act,
+    drop_path,
+    format,
+    identity,
+    patch_dropout,
+    pooling,
+    pos_embed,
+    selective_kernel,
+    squeeze_excite,
+)
 from .activation import *
 from .conv_norm_act import *
 from .drop_path import *
+from .format import *
 from .identity import *
+from .patch_dropout import *
 from .pooling import *
+from .pos_embed import *
 from .selective_kernel import *
 from .squeeze_excite import *

mindcv/models/layers/format.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from enum import Enum
+from typing import Union
+
+import mindspore
+
+
+class Format(str, Enum):
+    NCHW = 'NCHW'
+    NHWC = 'NHWC'
+    NCL = 'NCL'
+    NLC = 'NLC'
+
+
+FormatT = Union[str, Format]
+
+
+def nchw_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NHWC:
+        x = x.permute(0, 2, 3, 1)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=2).transpose((0, 2, 1))
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=2)
+    return x
+
+
+def nhwc_to(x: mindspore.Tensor, fmt: Format):
+    if fmt == Format.NCHW:
+        x = x.permute(0, 3, 1, 2)
+    elif fmt == Format.NLC:
+        x = x.flatten(start_dim=1, end_dim=2)
+    elif fmt == Format.NCL:
+        x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1))
+    return x
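These helpers only permute/flatten between channel layouts. A small usage sketch, assuming a MindSpore version (2.x) where `Tensor.permute` and `Tensor.flatten(start_dim=...)` are available:

```python
# Sketch: turn a NCHW feature map into a NLC token sequence with nchw_to.
import numpy as np
import mindspore as ms
from mindcv.models.layers.format import Format, nchw_to

x = ms.Tensor(np.zeros((2, 768, 14, 14)), ms.float32)  # (B, C, H, W)
tokens = nchw_to(x, Format.NLC)                        # (B, H*W, C)
print(tokens.shape)  # (2, 196, 768)
```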

mindcv/models/layers/patch_dropout.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+import numpy as np
+
+import mindspore as ms
+from mindspore import nn, ops
+
+
+class PatchDropout(nn.Cell):
+    """
+    https://arxiv.org/abs/2212.00794
+    """
+    def __init__(
+        self,
+        prob: float = 0.5,
+        num_prefix_tokens: int = 1,
+        ordered: bool = False,
+        return_indices: bool = False,
+    ):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
+        self.ordered = ordered
+        self.return_indices = return_indices
+        self.sort = ops.Sort()
+
+    def construct(self, x):
+        if not self.training or self.prob == 0.:
+            if self.return_indices:
+                return x, None
+            return x
+
+        if self.num_prefix_tokens:
+            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
+        else:
+            prefix_tokens = None
+
+        B = x.shape[0]
+        L = x.shape[1]
+        num_keep = max(1, int(L * (1. - self.prob)))
+        _, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32))
+        keep_indices = indices[:, :num_keep]
+        if self.ordered:
+            # NOTE does not need to maintain patch order in typical transformer use,
+            # but possibly useful for debug / visualization
+            keep_indices, _ = self.sort(keep_indices)
+        keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2]))
+        x = ops.gather_elements(x, dim=1, index=keep_indices)
+
+        if prefix_tokens is not None:
+            x = ops.concat((prefix_tokens, x), axis=1)
+
+        if self.return_indices:
+            return x, keep_indices
+        return x
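A hedged sketch of how this layer behaves at train time (names as in the diff above; shapes follow from `num_keep = max(1, int(L * (1 - prob)))`):

```python
# Sketch: PatchDropout keeps prefix tokens and a random subset of patches.
import numpy as np
import mindspore as ms
from mindcv.models.layers.patch_dropout import PatchDropout

drop = PatchDropout(prob=0.25, num_prefix_tokens=1)
drop.set_train(True)  # token dropping only happens in training mode
tokens = ms.Tensor(np.random.rand(4, 197, 768), ms.float32)  # CLS + 14x14 patches
out = drop(tokens)
print(out.shape)  # (4, 148, 768): 1 CLS + int(196 * 0.75) = 147 kept patches
```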

mindcv/models/layers/patch_embed.py

Lines changed: 50 additions & 15 deletions
@@ -4,6 +4,7 @@
 
 from mindspore import Tensor, nn, ops
 
+from .format import Format, nchw_to
 from .helpers import to_2tuple
 
 
@@ -17,29 +18,45 @@ class PatchEmbed(nn.Cell):
         embed_dim (int): Number of linear projection output channels. Default: 96.
         norm_layer (nn.Cell, optional): Normalization layer. Default: None
     """
+    output_fmt: Format
 
     def __init__(
         self,
-        image_size: int = 224,
+        image_size: Optional[int] = 224,
         patch_size: int = 4,
         in_chans: int = 3,
         embed_dim: int = 96,
         norm_layer: Optional[nn.Cell] = None,
+        flatten: bool = True,
+        output_fmt: Optional[str] = None,
+        bias: bool = True,
+        strict_img_size: bool = True,
+        dynamic_img_pad: bool = False,
     ) -> None:
         super().__init__()
-        image_size = to_2tuple(image_size)
-        patch_size = to_2tuple(patch_size)
-        patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
-        self.image_size = image_size
-        self.patch_size = patch_size
-        self.patches_resolution = patches_resolution
-        self.num_patches = patches_resolution[0] * patches_resolution[1]
-
-        self.in_chans = in_chans
+        self.patch_size = to_2tuple(patch_size)
+        if image_size is not None:
+            self.image_size = to_2tuple(image_size)
+            self.patches_resolution = tuple([s // p for s, p in zip(self.image_size, self.patch_size)])
+            self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
+        else:
+            self.image_size = None
+            self.patches_resolution = None
+            self.num_patches = None
+
+        if output_fmt is not None:
+            self.flatten = False
+            self.output_fmt = Format(output_fmt)
+        else:
+            self.flatten = flatten
+            self.output_fmt = Format.NCHW
+
+        self.strict_img_size = strict_img_size
+        self.dynamic_img_pad = dynamic_img_pad
         self.embed_dim = embed_dim
 
         self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size,
-                              pad_mode='pad', has_bias=True, weight_init="TruncatedNormal")
+                              pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal")
 
         if norm_layer is not None:
             if isinstance(embed_dim, int):
@@ -50,11 +67,29 @@ def __init__(
 
     def construct(self, x: Tensor) -> Tensor:
         """docstring"""
-        B = x.shape[0]
-        # FIXME look at relaxing size constraints
-        x = ops.Reshape()(self.proj(x), (B, self.embed_dim, -1))  # B Ph*Pw C
-        x = ops.Transpose()(x, (0, 2, 1))
+        B, C, H, W = x.shape
+        if self.image_size is not None:
+            if self.strict_img_size:
+                if (H, W) != (self.image_size[0], self.image_size[1]):
+                    raise ValueError(f"Input height and width ({H},{W}) doesn't match model ({self.image_size[0]},"
+                                     f"{self.image_size[1]}).")
+            elif not self.dynamic_img_pad:
+                if H % self.patch_size[0] != 0:
+                    raise ValueError(f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).")
+                if W % self.patch_size[1] != 0:
+                    raise ValueError(f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).")
+        if self.dynamic_img_pad:
+            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
+            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
+            x = ops.pad(x, (0, pad_w, 0, pad_h))
 
+        # FIXME look at relaxing size constraints
+        x = self.proj(x)
+        if self.flatten:
+            x = ops.Reshape()(x, (B, self.embed_dim, -1))  # B Ph*Pw C
+            x = ops.Transpose()(x, (0, 2, 1))
+        elif self.output_fmt != "NCHW":
+            x = nchw_to(x, self.output_fmt)
         if self.norm is not None:
             x = self.norm(x)
         return x
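The new flags relax the fixed-size constraint: `image_size=None` skips the strict shape check, and `dynamic_img_pad=True` pads the input up to a multiple of the patch size instead of raising. A hedged sketch of the extended options (argument names as in this diff; behavior per the `construct` logic above):

```python
# Sketch: PatchEmbed with dynamic padding and NHWC output.
import numpy as np
import mindspore as ms
from mindcv.models.layers.patch_embed import PatchEmbed

embed = PatchEmbed(image_size=None, patch_size=4, in_chans=3, embed_dim=96,
                   output_fmt="NHWC", dynamic_img_pad=True)
x = ms.Tensor(np.random.rand(1, 3, 222, 224), ms.float32)  # 222 % 4 != 0
y = embed(x)  # padded to (224, 224), projected, returned as NHWC
print(y.shape)  # (1, 56, 56, 96)
```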

mindcv/models/layers/pos_embed.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+"""positional embedding"""
+import math
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import Parameter, Tensor, nn, ops
+
+from .compatibility import Interpolate
+
+
+def resample_abs_pos_embed(
+    posemb,
+    new_size: List[int],
+    old_size: Optional[List[int]] = None,
+    num_prefix_tokens: int = 1,
+    interpolation: str = 'nearest',
+):
+    # sort out sizes, assume square if old size not provided
+    num_pos_tokens = posemb.shape[1]
+    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
+
+    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
+        return posemb
+
+    if old_size is None:
+        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
+        old_size = hw, hw
+
+    if num_prefix_tokens:
+        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
+    else:
+        posemb_prefix, posemb = None, posemb
+
+    # do the interpolation
+    embed_dim = posemb.shape[-1]
+    orig_dtype = posemb.dtype
+    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
+    interpolate = Interpolate(mode=interpolation, align_corners=True)
+    posemb = interpolate(posemb, size=new_size)
+    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
+    posemb = posemb.astype(orig_dtype)
+
+    # add back extra (class, etc.) prefix tokens
+    if posemb_prefix is not None:
+        posemb = ops.concat((posemb_prefix, posemb), axis=1)
+
+    return posemb
+
+
+class RelativePositionBiasWithCLS(nn.Cell):
+    def __init__(
+        self,
+        window_size: Tuple[int],
+        num_heads: int
+    ):
+        super(RelativePositionBiasWithCLS, self).__init__()
+        self.window_size = window_size
+        self.num_tokens = window_size[0] * window_size[1]
+
+        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        # 3: cls to token, token to cls, cls to cls
+        self.relative_position_bias_table = Parameter(
+            Tensor(np.zeros((num_relative_distance, num_heads)), dtype=ms.float16)
+        )
+        coords_h = np.arange(window_size[0]).reshape(window_size[0], 1).repeat(window_size[1], 1).reshape(1, -1)
+        coords_w = np.arange(window_size[1]).reshape(1, window_size[1]).repeat(window_size[0], 0).reshape(1, -1)
+        coords_flatten = np.concatenate([coords_h, coords_w], axis=0)  # [2, Wh * Ww]
+
+        relative_coords = coords_flatten[:, :, np.newaxis] - coords_flatten[:, np.newaxis, :]  # [2, Wh * Ww, Wh * Ww]
+        relative_coords = relative_coords.transpose(1, 2, 0)  # [Wh * Ww, Wh * Ww, 2]
+        relative_coords[:, :, 0] += window_size[0] - 1
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[0] - 1
+
+        relative_position_index = np.zeros((self.num_tokens + 1, self.num_tokens + 1),
+                                           dtype=relative_coords.dtype)  # [Wh * Ww + 1, Wh * Ww + 1]
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)
+        relative_position_index[0, 0:] = num_relative_distance - 3
+        relative_position_index[0:, 0] = num_relative_distance - 2
+        relative_position_index[0, 0] = num_relative_distance - 1
+        relative_position_index = Tensor(relative_position_index.reshape(-1))
+
+        self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16)
+        self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False)
+
+    def construct(self):
+        out = ops.matmul(self.relative_position_index, self.relative_position_bias_table)
+        out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1))
+        out = ops.transpose(out, (2, 0, 1))
+        out = ops.expand_dims(out, 0)
+        return out
