
Commit 1739a2c

feat: change ops to mint interfaces

1 parent: d1e2dbe


49 files changed: +1410, -1235 lines
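
The pattern repeated across all 49 files is the same: legacy `mindspore.nn` cells and `mindspore.ops` calls are swapped for their `mindspore.mint` counterparts, which follow PyTorch-style names and keyword arguments (`nn.Dense` -> `mint.nn.Linear`, `has_bias`/`group` -> `bias`/`groups`, `axis`/`keep_dims` -> `dim`/`keepdim`, `ops.transpose` -> `mint.permute`, `ops.concat(..., axis=...)` -> `mint.concat(..., dim=...)`). A minimal sketch of the functional renames, assuming a MindSpore build that ships the `mint` namespace (2.3 or later):

import numpy as np
from mindspore import Tensor, mint, ops

x = Tensor(np.random.rand(1, 4, 2, 2).astype(np.float32))

# ops uses axis/keep_dims; mint mirrors torch with dim/keepdim.
y_ops = ops.mean(x, axis=1, keep_dims=True)
y_mint = mint.mean(x, dim=1, keepdim=True)

# ops.transpose takes a permutation tuple; mint.permute is the torch-style name.
z_ops = ops.transpose(x, (0, 2, 3, 1))
z_mint = mint.permute(x, (0, 2, 3, 1))

print(np.allclose(y_ops.asnumpy(), y_mint.asnumpy()))  # True
print(np.allclose(z_ops.asnumpy(), z_mint.asnumpy()))  # True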

mindcv/models/convnext.py

Lines changed: 20 additions & 21 deletions
@@ -10,7 +10,7 @@
 import mindspore.common.initializer as init
 from mindspore import Parameter, Tensor
 from mindspore import dtype as mstype
-from mindspore import nn, ops
+from mindspore import mint, nn, ops

 from .helpers import build_model_with_cfg
 from .layers.drop_path import DropPath
@@ -69,15 +69,14 @@ def __init__(self, dim: int):
         super().__init__()
         self.gamma = Parameter(Tensor(np.zeros([1, 1, 1, dim]), mstype.float32))
         self.beta = Parameter(Tensor(np.zeros([1, 1, 1, dim]), mstype.float32))
-        self.norm = ops.LpNorm(axis=[1, 2], p=2, keep_dims=True)

     def construct(self, x: Tensor) -> Tensor:
-        gx = self.norm(x)
-        nx = gx / (ops.mean(gx, axis=-1, keep_dims=True) + 1e-6)
+        gx = mint.norm(x, p=2, dim=(1, 2), keepdim=True)
+        nx = gx / (mint.mean(gx, dim=-1, keepdim=True) + 1e-6)
         return self.gamma * (x * nx) + self.beta + x


-class ConvNextLayerNorm(nn.LayerNorm):
+class ConvNextLayerNorm(mint.nn.LayerNorm):
     """
     LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
     """
@@ -88,17 +87,17 @@ def __init__(
         epsilon: float,
         norm_axis: int = -1,
     ) -> None:
-        super().__init__(normalized_shape=normalized_shape, epsilon=epsilon)
+        super().__init__(normalized_shape=normalized_shape, eps=epsilon)
         assert norm_axis in (-1, 1), "ConvNextLayerNorm's norm_axis must be 1 or -1."
         self.norm_axis = norm_axis

     def construct(self, input_x: Tensor) -> Tensor:
         if self.norm_axis == -1:
-            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
+            y = ops.layer_norm(input_x, self.normalized_shape, self.weight, self.bias, self.eps)
         else:
-            input_x = ops.transpose(input_x, (0, 2, 3, 1))
-            y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
-            y = ops.transpose(y, (0, 3, 1, 2))
+            input_x = mint.permute(input_x, (0, 2, 3, 1))
+            y = ops.layer_norm(input_x, self.normalized_shape, self.weight, self.bias, self.eps)
+            y = mint.permute(y, (0, 3, 1, 2))
         return y


@@ -124,22 +123,22 @@ def __init__(
         use_grn: bool = False,
     ) -> None:
         super().__init__()
-        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, group=dim, has_bias=True)  # depthwise conv
+        self.dwconv = mint.nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True)  # depthwise conv
         self.norm = ConvNextLayerNorm((dim,), epsilon=1e-6)
-        self.pwconv1 = nn.Dense(dim, 4 * dim)  # pointwise/1x1 convs, implemented with Dense layers
-        self.act = nn.GELU()
+        self.pwconv1 = mint.nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with Dense layers
+        self.act = mint.nn.GELU()
         self.use_grn = use_grn
         if use_grn:
             self.grn = GRN(4 * dim)
-        self.pwconv2 = nn.Dense(4 * dim, dim)
+        self.pwconv2 = mint.nn.Linear(4 * dim, dim)
         self.gamma_ = Parameter(Tensor(layer_scale_init_value * np.ones((dim)), dtype=mstype.float32),
                                 requires_grad=True) if layer_scale_init_value > 0 else None
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()

     def construct(self, x: Tensor) -> Tensor:
         downsample = x
         x = self.dwconv(x)
-        x = ops.transpose(x, (0, 2, 3, 1))
+        x = mint.permute(x, (0, 2, 3, 1))
         x = self.norm(x)
         x = self.pwconv1(x)
         x = self.act(x)
@@ -148,7 +147,7 @@ def construct(self, x: Tensor) -> Tensor:
         x = self.pwconv2(x)
         if self.gamma_ is not None:
             x = self.gamma_ * x
-        x = ops.transpose(x, (0, 3, 1, 2))
+        x = mint.permute(x, (0, 3, 1, 2))
         x = downsample + self.drop_path(x)
         return x

@@ -184,14 +183,14 @@ def __init__(

         downsample_layers = []  # stem and 3 intermediate down_sampling conv layers
         stem = nn.SequentialCell(
-            nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4, has_bias=True),
+            mint.nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4, bias=True),
             ConvNextLayerNorm((dims[0],), epsilon=1e-6, norm_axis=1),
         )
         downsample_layers.append(stem)
         for i in range(3):
             downsample_layer = nn.SequentialCell(
                 ConvNextLayerNorm((dims[i],), epsilon=1e-6, norm_axis=1),
-                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, has_bias=True),
+                mint.nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, bias=True),
             )
             downsample_layers.append(downsample_layer)

@@ -226,18 +225,18 @@ def __init__(
             stages[3]
         ])
         self.norm = ConvNextLayerNorm((dims[-1],), epsilon=1e-6)  # final norm layer
-        self.classifier = nn.Dense(dims[-1], num_classes)  # classifier
+        self.classifier = mint.nn.Linear(dims[-1], num_classes)  # classifier
         self.head_init_scale = head_init_scale
         self._initialize_weights()

     def _initialize_weights(self) -> None:
         """Initialize weights for cells."""
         for _, cell in self.cells_and_names():
-            if isinstance(cell, (nn.Dense, nn.Conv2d)):
+            if isinstance(cell, (mint.nn.Linear, mint.nn.Conv2d)):
                 cell.weight.set_data(
                     init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)
                 )
-                if isinstance(cell, nn.Dense) and cell.bias is not None:
+                if isinstance(cell, mint.nn.Linear) and cell.bias is not None:
                     cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype))
         self.classifier.weight.set_data(self.classifier.weight * self.head_init_scale)
         self.classifier.bias.set_data(self.classifier.bias * self.head_init_scale)
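
In the GRN block above, the `ops.LpNorm` primitive (constructed once, then called) is replaced by a single functional `mint.norm` call, and `ConvNextLayerNorm` now subclasses `mint.nn.LayerNorm`, whose affine parameters are exposed as `weight`/`bias` with an `eps` attribute rather than `gamma`/`beta` with `epsilon`; that is why the channels-first branch now routes through `ops.layer_norm(input_x, self.normalized_shape, self.weight, self.bias, self.eps)`. A small check of the norm substitution (a sketch only; it assumes `mint.norm` reduces a numeric `p` over a `dim` tuple element-wise, which is what this migration relies on):

import numpy as np
from mindspore import Tensor, mint, ops

x = Tensor(np.random.rand(2, 8, 4, 4).astype(np.float32))

# Old GRN: an ops.LpNorm primitive held as an attribute and called in construct().
lp_norm = ops.LpNorm(axis=[1, 2], p=2, keep_dims=True)
gx_old = lp_norm(x)

# New GRN: the functional mint.norm call from the updated construct().
gx_new = mint.norm(x, p=2, dim=(1, 2), keepdim=True)

# Expected to be ~0 if the two reductions agree, as the commit assumes.
print(np.abs(gx_old.asnumpy() - gx_new.asnumpy()).max())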

mindcv/models/densenet.py

Lines changed: 27 additions & 29 deletions
@@ -8,10 +8,9 @@
 from typing import Tuple

 import mindspore.common.initializer as init
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, nn

 from .helpers import load_pretrained
-from .layers.compatibility import Dropout
 from .layers.pooling import GlobalAvgPooling
 from .registry import register_model

@@ -53,16 +52,17 @@ def __init__(
         drop_rate: float,
     ) -> None:
         super().__init__()
-        self.norm1 = nn.BatchNorm2d(num_input_features)
-        self.relu1 = nn.ReLU()
-        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1)
+        self.norm1 = mint.nn.BatchNorm2d(num_input_features)
+        self.relu1 = mint.nn.ReLU()
+        self.conv1 = mint.nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)

-        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
-        self.relu2 = nn.ReLU()
-        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, pad_mode="pad", padding=1)
+        self.norm2 = mint.nn.BatchNorm2d(bn_size * growth_rate)
+        self.relu2 = mint.nn.ReLU()
+        self.conv2 = mint.nn.Conv2d(
+            bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

         self.drop_rate = drop_rate
-        self.dropout = Dropout(p=self.drop_rate)
+        self.dropout = mint.nn.Dropout(p=self.drop_rate)

     def construct(self, features: Tensor) -> Tensor:
         bottleneck = self.conv1(self.relu1(self.norm1(features)))
@@ -98,7 +98,7 @@ def construct(self, init_features: Tensor) -> Tensor:
         features = init_features
         for layer in self.cell_list:
             new_features = layer(features)
-            features = ops.concat((features, new_features), axis=1)
+            features = mint.concat((features, new_features), dim=1)
         return features


@@ -112,10 +112,10 @@ def __init__(
     ) -> None:
         super().__init__()
         self.features = nn.SequentialCell(OrderedDict([
-            ("norm", nn.BatchNorm2d(num_input_features)),
-            ("relu", nn.ReLU()),
-            ("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1)),
-            ("pool", nn.AvgPool2d(kernel_size=2, stride=2))
+            ("norm", mint.nn.BatchNorm2d(num_input_features)),
+            ("relu", mint.nn.ReLU()),
+            ("conv", mint.nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)),
+            ("pool", mint.nn.AvgPool2d(kernel_size=2, stride=2))
         ]))

     def construct(self, x: Tensor) -> Tensor:
@@ -152,13 +152,11 @@ def __init__(
         layers = OrderedDict()
         # first Conv2d
         num_features = num_init_features
-        layers["conv0"] = nn.Conv2d(in_channels, num_features, kernel_size=7, stride=2, pad_mode="pad", padding=3)
-        layers["norm0"] = nn.BatchNorm2d(num_features)
-        layers["relu0"] = nn.ReLU()
-        layers["pool0"] = nn.SequentialCell([
-            nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
-            nn.MaxPool2d(kernel_size=3, stride=2),
-        ])
+        layers["conv0"] = mint.nn.Conv2d(
+            in_channels, num_features, kernel_size=7, stride=2, padding=3, bias=False)
+        layers["norm0"] = mint.nn.BatchNorm2d(num_features)
+        layers["relu0"] = mint.nn.ReLU()
+        layers["pool0"] = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

         # DenseBlock
         for i, num_layers in enumerate(block_config):
@@ -177,30 +175,30 @@ def __init__(
             num_features = num_features // 2

         # final bn+ReLU
-        layers["norm5"] = nn.BatchNorm2d(num_features)
-        layers["relu5"] = nn.ReLU()
+        layers["norm5"] = mint.nn.BatchNorm2d(num_features)
+        layers["relu5"] = mint.nn.ReLU()

         self.num_features = num_features
         self.features = nn.SequentialCell(layers)
         self.pool = GlobalAvgPooling()
-        self.classifier = nn.Dense(self.num_features, num_classes)
+        self.classifier = mint.nn.Linear(self.num_features, num_classes)
         self._initialize_weights()

     def _initialize_weights(self) -> None:
         """Initialize weights for cells."""
         for _, cell in self.cells_and_names():
-            if isinstance(cell, nn.Conv2d):
+            if isinstance(cell, mint.nn.Conv2d):
                 cell.weight.set_data(
                     init.initializer(init.HeNormal(math.sqrt(5), mode="fan_out", nonlinearity="relu"),
                                      cell.weight.shape, cell.weight.dtype))
                 if cell.bias is not None:
                     cell.bias.set_data(
                         init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"),
                                          cell.bias.shape, cell.bias.dtype))
-            elif isinstance(cell, nn.BatchNorm2d):
-                cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype))
-                cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype))
-            elif isinstance(cell, nn.Dense):
+            elif isinstance(cell, mint.nn.BatchNorm2d):
+                cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype))
+                cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype))
+            elif isinstance(cell, mint.nn.Linear):
                 cell.weight.set_data(
                     init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"),
                                      cell.weight.shape, cell.weight.dtype))
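
One structural simplification in the DenseNet stem: the old `nn.Pad` + `nn.MaxPool2d` pair collapses into a single `mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)`, since the mint pooling layer accepts `padding` directly. A shape check for a 112x112 feature map (a sketch; border values can differ slightly because the old path pads with zeros while the new layer pads inside the pooling op):

import numpy as np
from mindspore import Tensor, mint, nn

x = Tensor(np.random.rand(1, 64, 112, 112).astype(np.float32))

# Old stem pooling: explicit constant padding followed by a "valid" max-pool.
old_pool = nn.SequentialCell([
    nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
    nn.MaxPool2d(kernel_size=3, stride=2),
])

# New stem pooling: padding handled by mint.nn.MaxPool2d itself.
new_pool = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

print(old_pool(x).shape, new_pool(x).shape)  # both (1, 64, 56, 56)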

mindcv/models/googlenet.py

Lines changed: 32 additions & 30 deletions
@@ -7,10 +7,10 @@
 from typing import Tuple, Union

 import mindspore.common.initializer as init
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, nn

 from .helpers import load_pretrained
-from .layers.compatibility import Dropout
+from .layers.flatten import Flatten
 from .layers.pooling import GlobalAvgPooling
 from .registry import register_model

@@ -45,12 +45,12 @@ def __init__(
         kernel_size: int = 1,
         stride: int = 1,
         padding: int = 0,
-        pad_mode: str = "same",
+        pad_mode: str = "zeros",
     ) -> None:
         super().__init__()
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
-                              padding=padding, pad_mode=pad_mode)
-        self.relu = nn.ReLU()
+        self.conv = mint.nn.Conv2d(
+            in_channels, out_channels, kernel_size, stride, padding=padding, padding_mode=pad_mode, bias=False)
+        self.relu = mint.nn.ReLU()

     def construct(self, x: Tensor) -> Tensor:
         x = self.conv(x)
@@ -75,14 +75,14 @@ def __init__(
         self.b1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
         self.b2 = nn.SequentialCell([
             BasicConv2d(in_channels, ch3x3red, kernel_size=1),
-            BasicConv2d(ch3x3red, ch3x3, kernel_size=3),
+            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
         ])
         self.b3 = nn.SequentialCell([
             BasicConv2d(in_channels, ch5x5red, kernel_size=1),
-            BasicConv2d(ch5x5red, ch5x5, kernel_size=5),
+            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
         ])
         self.b4 = nn.SequentialCell([
-            nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same"),
+            mint.nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
             BasicConv2d(in_channels, pool_proj, kernel_size=1),
         ])

@@ -91,7 +91,7 @@ def construct(self, x: Tensor) -> Tensor:
         branch2 = self.b2(x)
         branch3 = self.b3(x)
         branch4 = self.b4(x)
-        return ops.concat((branch1, branch2, branch3, branch4), axis=1)
+        return mint.concat((branch1, branch2, branch3, branch4), dim=1)


 class InceptionAux(nn.Cell):
@@ -104,13 +104,13 @@ def __init__(
         drop_rate: float = 0.7,
     ) -> None:
         super().__init__()
-        self.avg_pool = nn.AvgPool2d(kernel_size=5, stride=3)
+        self.avg_pool = mint.nn.AvgPool2d(kernel_size=5, stride=3)
         self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
-        self.fc1 = nn.Dense(2048, 1024)
-        self.fc2 = nn.Dense(1024, num_classes)
-        self.flatten = nn.Flatten()
-        self.relu = nn.ReLU()
-        self.dropout = Dropout(p=drop_rate)
+        self.fc1 = mint.nn.Linear(2048, 1024)
+        self.fc2 = mint.nn.Linear(1024, num_classes)
+        self.flatten = Flatten()
+        self.relu = mint.nn.ReLU()
+        self.dropout = mint.nn.Dropout(p=drop_rate)

     def construct(self, x: Tensor) -> Tensor:
         x = self.avg_pool(x)
@@ -145,23 +145,23 @@ def __init__(
     ) -> None:
         super().__init__()
         self.aux_logits = aux_logits
-        self.conv1 = BasicConv2d(in_channels, 64, kernel_size=7, stride=2)
-        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+        self.conv1 = BasicConv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)
+        self.maxpool1 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

         self.conv2 = BasicConv2d(64, 64, kernel_size=1)
-        self.conv3 = BasicConv2d(64, 192, kernel_size=3)
-        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
+        self.maxpool2 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

         self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
         self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
-        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
+        self.maxpool3 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

         self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
         self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
         self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
         self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
         self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
-        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
+        self.maxpool4 = mint.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

         self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
         self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
@@ -171,22 +171,24 @@ def __init__(
         self.aux2 = InceptionAux(528, num_classes, drop_rate=drop_rate_aux)

         self.pool = GlobalAvgPooling()
-        self.dropout = Dropout(p=drop_rate)
-        self.classifier = nn.Dense(1024, num_classes)
+        self.dropout = mint.nn.Dropout(p=drop_rate)
+        self.classifier = mint.nn.Linear(1024, num_classes)
         self._initialize_weights()

     def _initialize_weights(self):
         for _, cell in self.cells_and_names():
-            if isinstance(cell, nn.Conv2d):
+            if isinstance(cell, mint.nn.Conv2d):
                 cell.weight.set_data(init.initializer(init.HeNormal(0, mode='fan_in', nonlinearity='leaky_relu'),
                                                       cell.weight.shape, cell.weight.dtype))
                 if cell.bias is not None:
                     cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))
-            elif isinstance(cell, nn.BatchNorm2d) or isinstance(cell, nn.BatchNorm1d):
-                cell.gamma.set_data(init.initializer(init.Constant(1), cell.gamma.shape, cell.gamma.dtype))
-                if cell.beta is not None:
-                    cell.beta.set_data(init.initializer(init.Constant(0), cell.beta.shape, cell.gamma.dtype))
-            elif isinstance(cell, nn.Dense):
+            elif isinstance(cell, mint.nn.BatchNorm2d) or isinstance(cell, mint.nn.BatchNorm1d):
+                cell.weight.set_data(
+                    init.initializer(init.Constant(1), cell.weight.shape, cell.weight.dtype))
+                if cell.bias is not None:
+                    cell.bias.set_data(
+                        init.initializer(init.Constant(0), cell.bias.shape, cell.weight.dtype))
+            elif isinstance(cell, mint.nn.Linear):
                 cell.weight.set_data(
                     init.initializer(init.HeUniform(math.sqrt(5), mode='fan_in', nonlinearity='leaky_relu'),
                                      cell.weight.shape, cell.weight.dtype))
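
GoogLeNet relied heavily on `pad_mode="same"`, which `mint.nn.Conv2d` and `mint.nn.MaxPool2d` do not offer; the diff substitutes explicit padding on the convolutions (3 for 7x7, 2 for 5x5, 1 for 3x3) and `ceil_mode=True` on the stride-2 max-pools so the feature-map sizes of the 224x224 pipeline are preserved, and the BatchNorm initializers are renamed from `gamma`/`beta` to `weight`/`bias`. A quick size check for the stride-2 pooling (a sketch; with an even input size both layers return ceil(H / 2), though border handling differs):

import numpy as np
from mindspore import Tensor, mint, nn

x = Tensor(np.random.rand(1, 64, 112, 112).astype(np.float32))

# Old: "same" padding computed implicitly by the pooling layer.
old_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
# New: no implicit padding; ceil_mode=True keeps the output at ceil(112 / 2) = 56.
new_pool = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

print(old_pool(x).shape, new_pool(x).shape)  # both (1, 64, 56, 56)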
