feat: change ops to mint interfaces #820

Open · wants to merge 1 commit into main
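Scope of the change: legacy nn and ops calls in the model definitions are replaced by their mint counterparts, which use PyTorch-style argument names (dim, keepdim, bias, groups) instead of the legacy ones (axis, keep_dims, has_bias, group). The sketch below is illustrative only and not part of the diff; it assumes a MindSpore build that provides the mint package.

# Illustrative sketch of the interface swap this PR applies. Not part of the diff;
# assumes a MindSpore version that ships the mint package.
import numpy as np
from mindspore import Tensor, mint, nn, ops

x = Tensor(np.random.randn(1, 3, 8, 8).astype(np.float32))

# Functional ops: ops.transpose(x, perm) becomes mint.permute(x, perm).
y_old = ops.transpose(x, (0, 2, 3, 1))
y_new = mint.permute(x, (0, 2, 3, 1))
assert y_old.shape == y_new.shape == (1, 8, 8, 3)

# Layers: nn.Dense becomes mint.nn.Linear, and nn.Conv2d(has_bias=...) becomes
# mint.nn.Conv2d(bias=...) with explicit integer padding instead of pad_mode.
fc_old, fc_new = nn.Dense(16, 32), mint.nn.Linear(16, 32)
conv_old = nn.Conv2d(3, 8, kernel_size=3, has_bias=True)
conv_new = mint.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)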
41 changes: 20 additions & 21 deletions mindcv/models/convnext.py
@@ -10,7 +10,7 @@
import mindspore.common.initializer as init
from mindspore import Parameter, Tensor
from mindspore import dtype as mstype
from mindspore import nn, ops
from mindspore import mint, nn, ops

from .helpers import build_model_with_cfg
from .layers.drop_path import DropPath
@@ -69,15 +69,14 @@ def __init__(self, dim: int):
super().__init__()
self.gamma = Parameter(Tensor(np.zeros([1, 1, 1, dim]), mstype.float32))
self.beta = Parameter(Tensor(np.zeros([1, 1, 1, dim]), mstype.float32))
self.norm = ops.LpNorm(axis=[1, 2], p=2, keep_dims=True)

def construct(self, x: Tensor) -> Tensor:
gx = self.norm(x)
nx = gx / (ops.mean(gx, axis=-1, keep_dims=True) + 1e-6)
gx = mint.norm(x, p=2, dim=(1, 2), keepdim=True)
nx = gx / (mint.mean(gx, dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * nx) + self.beta + x


class ConvNextLayerNorm(nn.LayerNorm):
class ConvNextLayerNorm(mint.nn.LayerNorm):
"""
LayerNorm for channels_first tensors with 2d spatial dimensions (i.e., N, C, H, W).
"""
@@ -88,17 +87,17 @@ def __init__(
epsilon: float,
norm_axis: int = -1,
) -> None:
super().__init__(normalized_shape=normalized_shape, epsilon=epsilon)
super().__init__(normalized_shape=normalized_shape, eps=epsilon)
assert norm_axis in (-1, 1), "ConvNextLayerNorm's norm_axis must be 1 or -1."
self.norm_axis = norm_axis

def construct(self, input_x: Tensor) -> Tensor:
if self.norm_axis == -1:
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
y = ops.layer_norm(input_x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
input_x = ops.transpose(input_x, (0, 2, 3, 1))
y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
y = ops.transpose(y, (0, 3, 1, 2))
input_x = mint.permute(input_x, (0, 2, 3, 1))
y = ops.layer_norm(input_x, self.normalized_shape, self.weight, self.bias, self.eps)
y = mint.permute(y, (0, 3, 1, 2))
return y


@@ -124,22 +123,22 @@ def __init__(
use_grn: bool = False,
) -> None:
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, group=dim, has_bias=True) # depthwise conv
self.dwconv = mint.nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True) # depthwise conv
self.norm = ConvNextLayerNorm((dim,), epsilon=1e-6)
self.pwconv1 = nn.Dense(dim, 4 * dim) # pointwise/1x1 convs, implemented with Dense layers
self.act = nn.GELU()
self.pwconv1 = mint.nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with Dense layers
self.act = mint.nn.GELU()
self.use_grn = use_grn
if use_grn:
self.grn = GRN(4 * dim)
self.pwconv2 = nn.Dense(4 * dim, dim)
self.pwconv2 = mint.nn.Linear(4 * dim, dim)
self.gamma_ = Parameter(Tensor(layer_scale_init_value * np.ones((dim)), dtype=mstype.float32),
requires_grad=True) if layer_scale_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()

def construct(self, x: Tensor) -> Tensor:
downsample = x
x = self.dwconv(x)
x = ops.transpose(x, (0, 2, 3, 1))
x = mint.permute(x, (0, 2, 3, 1))
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
@@ -148,7 +147,7 @@ def construct(self, x: Tensor) -> Tensor:
x = self.pwconv2(x)
if self.gamma_ is not None:
x = self.gamma_ * x
x = ops.transpose(x, (0, 3, 1, 2))
x = mint.permute(x, (0, 3, 1, 2))
x = downsample + self.drop_path(x)
return x

@@ -184,14 +183,14 @@ def __init__(

downsample_layers = [] # stem and 3 intermediate down_sampling conv layers
stem = nn.SequentialCell(
nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4, has_bias=True),
mint.nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4, bias=True),
ConvNextLayerNorm((dims[0],), epsilon=1e-6, norm_axis=1),
)
downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.SequentialCell(
ConvNextLayerNorm((dims[i],), epsilon=1e-6, norm_axis=1),
nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, has_bias=True),
mint.nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, bias=True),
)
downsample_layers.append(downsample_layer)

@@ -226,18 +225,18 @@ def __init__(
stages[3]
])
self.norm = ConvNextLayerNorm((dims[-1],), epsilon=1e-6) # final norm layer
self.classifier = nn.Dense(dims[-1], num_classes) # classifier
self.classifier = mint.nn.Linear(dims[-1], num_classes) # classifier
self.head_init_scale = head_init_scale
self._initialize_weights()

def _initialize_weights(self) -> None:
"""Initialize weights for cells."""
for _, cell in self.cells_and_names():
if isinstance(cell, (nn.Dense, nn.Conv2d)):
if isinstance(cell, (mint.nn.Linear, mint.nn.Conv2d)):
cell.weight.set_data(
init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)
)
if isinstance(cell, nn.Dense) and cell.bias is not None:
if isinstance(cell, mint.nn.Linear) and cell.bias is not None:
cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype))
self.classifier.weight.set_data(self.classifier.weight * self.head_init_scale)
self.classifier.bias.set_data(self.classifier.bias * self.head_init_scale)
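A note on the GRN rewrite above: the ops.LpNorm operator object built in __init__ is dropped in favour of a functional mint.norm call inside construct. The snippet below is an illustrative equivalence check, not part of the PR; it assumes both code paths compute the same L2 norm over the spatial axes.

# Illustrative check only (not in the PR): old op object vs. new functional call.
import numpy as np
from mindspore import Tensor, mint, ops

x = Tensor(np.random.randn(2, 7, 7, 64).astype(np.float32))  # NHWC, as inside GRN
gx_old = ops.LpNorm(axis=[1, 2], p=2, keep_dims=True)(x)      # legacy: operator created up front
gx_new = mint.norm(x, p=2, dim=(1, 2), keepdim=True)          # mint: plain functional call
assert np.allclose(gx_old.asnumpy(), gx_new.asnumpy(), atol=1e-5)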
56 changes: 27 additions & 29 deletions mindcv/models/densenet.py
@@ -8,10 +8,9 @@
from typing import Tuple

import mindspore.common.initializer as init
from mindspore import Tensor, nn, ops
from mindspore import Tensor, mint, nn

from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.pooling import GlobalAvgPooling
from .registry import register_model

@@ -53,16 +52,17 @@ def __init__(
drop_rate: float,
) -> None:
super().__init__()
self.norm1 = nn.BatchNorm2d(num_input_features)
self.relu1 = nn.ReLU()
self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1)
self.norm1 = mint.nn.BatchNorm2d(num_input_features)
self.relu1 = mint.nn.ReLU()
self.conv1 = mint.nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)

self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
self.relu2 = nn.ReLU()
self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, pad_mode="pad", padding=1)
self.norm2 = mint.nn.BatchNorm2d(bn_size * growth_rate)
self.relu2 = mint.nn.ReLU()
self.conv2 = mint.nn.Conv2d(
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

self.drop_rate = drop_rate
self.dropout = Dropout(p=self.drop_rate)
self.dropout = mint.nn.Dropout(p=self.drop_rate)

def construct(self, features: Tensor) -> Tensor:
bottleneck = self.conv1(self.relu1(self.norm1(features)))
@@ -98,7 +98,7 @@ def construct(self, init_features: Tensor) -> Tensor:
features = init_features
for layer in self.cell_list:
new_features = layer(features)
features = ops.concat((features, new_features), axis=1)
features = mint.concat((features, new_features), dim=1)
return features


@@ -112,10 +112,10 @@ def __init__(
) -> None:
super().__init__()
self.features = nn.SequentialCell(OrderedDict([
("norm", nn.BatchNorm2d(num_input_features)),
("relu", nn.ReLU()),
("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1)),
("pool", nn.AvgPool2d(kernel_size=2, stride=2))
("norm", mint.nn.BatchNorm2d(num_input_features)),
("relu", mint.nn.ReLU()),
("conv", mint.nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)),
("pool", mint.nn.AvgPool2d(kernel_size=2, stride=2))
]))

def construct(self, x: Tensor) -> Tensor:
@@ -152,13 +152,11 @@ def __init__(
layers = OrderedDict()
# first Conv2d
num_features = num_init_features
layers["conv0"] = nn.Conv2d(in_channels, num_features, kernel_size=7, stride=2, pad_mode="pad", padding=3)
layers["norm0"] = nn.BatchNorm2d(num_features)
layers["relu0"] = nn.ReLU()
layers["pool0"] = nn.SequentialCell([
nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
nn.MaxPool2d(kernel_size=3, stride=2),
])
layers["conv0"] = mint.nn.Conv2d(
in_channels, num_features, kernel_size=7, stride=2, padding=3, bias=False)
layers["norm0"] = mint.nn.BatchNorm2d(num_features)
layers["relu0"] = mint.nn.ReLU()
layers["pool0"] = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

# DenseBlock
for i, num_layers in enumerate(block_config):
@@ -177,30 +175,30 @@ def __init__(
num_features = num_features // 2

# final bn+ReLU
layers["norm5"] = nn.BatchNorm2d(num_features)
layers["relu5"] = nn.ReLU()
layers["norm5"] = mint.nn.BatchNorm2d(num_features)
layers["relu5"] = mint.nn.ReLU()

self.num_features = num_features
self.features = nn.SequentialCell(layers)
self.pool = GlobalAvgPooling()
self.classifier = nn.Dense(self.num_features, num_classes)
self.classifier = mint.nn.Linear(self.num_features, num_classes)
self._initialize_weights()

def _initialize_weights(self) -> None:
"""Initialize weights for cells."""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
if isinstance(cell, mint.nn.Conv2d):
cell.weight.set_data(
init.initializer(init.HeNormal(math.sqrt(5), mode="fan_out", nonlinearity="relu"),
cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(
init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"),
cell.bias.shape, cell.bias.dtype))
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype))
cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype))
elif isinstance(cell, nn.Dense):
elif isinstance(cell, mint.nn.BatchNorm2d):
cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype))
cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype))
elif isinstance(cell, mint.nn.Linear):
cell.weight.set_data(
init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"),
cell.weight.shape, cell.weight.dtype))
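A note on the DenseNet stem above: the explicit nn.Pad plus nn.MaxPool2d pair is collapsed into a single mint.nn.MaxPool2d with padding=1. The snippet below is a shape-only sanity check, not part of the PR; it assumes a 112x112 stem feature map (a 224x224 input after conv0).

# Illustrative shape check only (not in the PR).
import numpy as np
from mindspore import Tensor, mint, nn

x = Tensor(np.random.randn(1, 64, 112, 112).astype(np.float32))
pool_old = nn.SequentialCell([
    nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
    nn.MaxPool2d(kernel_size=3, stride=2),
])
pool_new = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
assert pool_old(x).shape == pool_new(x).shape == (1, 64, 56, 56)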
62 changes: 32 additions & 30 deletions mindcv/models/googlenet.py
@@ -7,10 +7,10 @@
from typing import Tuple, Union

import mindspore.common.initializer as init
from mindspore import Tensor, nn, ops
from mindspore import Tensor, mint, nn

from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.flatten import Flatten
from .layers.pooling import GlobalAvgPooling
from .registry import register_model

@@ -45,12 +45,12 @@ def __init__(
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
pad_mode: str = "same",
pad_mode: str = "zeros",
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
padding=padding, pad_mode=pad_mode)
self.relu = nn.ReLU()
self.conv = mint.nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding=padding, padding_mode=pad_mode, bias=False)
self.relu = mint.nn.ReLU()

def construct(self, x: Tensor) -> Tensor:
x = self.conv(x)
@@ -75,14 +75,14 @@ def __init__(
self.b1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
self.b2 = nn.SequentialCell([
BasicConv2d(in_channels, ch3x3red, kernel_size=1),
BasicConv2d(ch3x3red, ch3x3, kernel_size=3),
BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
])
self.b3 = nn.SequentialCell([
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
BasicConv2d(ch5x5red, ch5x5, kernel_size=5),
BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
])
self.b4 = nn.SequentialCell([
nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same"),
mint.nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
BasicConv2d(in_channels, pool_proj, kernel_size=1),
])

@@ -91,7 +91,7 @@ def construct(self, x: Tensor) -> Tensor:
branch2 = self.b2(x)
branch3 = self.b3(x)
branch4 = self.b4(x)
return ops.concat((branch1, branch2, branch3, branch4), axis=1)
return mint.concat((branch1, branch2, branch3, branch4), dim=1)


class InceptionAux(nn.Cell):
@@ -104,13 +104,13 @@ def __init__(
drop_rate: float = 0.7,
) -> None:
super().__init__()
self.avg_pool = nn.AvgPool2d(kernel_size=5, stride=3)
self.avg_pool = mint.nn.AvgPool2d(kernel_size=5, stride=3)
self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
self.fc1 = nn.Dense(2048, 1024)
self.fc2 = nn.Dense(1024, num_classes)
self.flatten = nn.Flatten()
self.relu = nn.ReLU()
self.dropout = Dropout(p=drop_rate)
self.fc1 = mint.nn.Linear(2048, 1024)
self.fc2 = mint.nn.Linear(1024, num_classes)
self.flatten = Flatten()
self.relu = mint.nn.ReLU()
self.dropout = mint.nn.Dropout(p=drop_rate)

def construct(self, x: Tensor) -> Tensor:
x = self.avg_pool(x)
@@ -145,23 +145,23 @@ def __init__(
) -> None:
super().__init__()
self.aux_logits = aux_logits
self.conv1 = BasicConv2d(in_channels, 64, kernel_size=7, stride=2)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.conv1 = BasicConv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)
self.maxpool1 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

self.conv2 = BasicConv2d(64, 64, kernel_size=1)
self.conv3 = BasicConv2d(64, 192, kernel_size=3)
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
self.maxpool2 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.maxpool3 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
self.maxpool4 = mint.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)
@@ -171,22 +171,24 @@ def __init__(
self.aux2 = InceptionAux(528, num_classes, drop_rate=drop_rate_aux)

self.pool = GlobalAvgPooling()
self.dropout = Dropout(p=drop_rate)
self.classifier = nn.Dense(1024, num_classes)
self.dropout = mint.nn.Dropout(p=drop_rate)
self.classifier = mint.nn.Linear(1024, num_classes)
self._initialize_weights()

def _initialize_weights(self):
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
if isinstance(cell, mint.nn.Conv2d):
cell.weight.set_data(init.initializer(init.HeNormal(0, mode='fan_in', nonlinearity='leaky_relu'),
cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))
elif isinstance(cell, nn.BatchNorm2d) or isinstance(cell, nn.BatchNorm1d):
cell.gamma.set_data(init.initializer(init.Constant(1), cell.gamma.shape, cell.gamma.dtype))
if cell.beta is not None:
cell.beta.set_data(init.initializer(init.Constant(0), cell.beta.shape, cell.gamma.dtype))
elif isinstance(cell, nn.Dense):
elif isinstance(cell, mint.nn.BatchNorm2d) or isinstance(cell, mint.nn.BatchNorm1d):
cell.weight.set_data(
init.initializer(init.Constant(1), cell.weight.shape, cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(
init.initializer(init.Constant(0), cell.bias.shape, cell.weight.dtype))
elif isinstance(cell, mint.nn.Linear):
cell.weight.set_data(
init.initializer(init.HeUniform(math.sqrt(5), mode='fan_in', nonlinearity='leaky_relu'),
cell.weight.shape, cell.weight.dtype))
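A note on the GoogLeNet changes above: convolutions and pools that relied on pad_mode="same" are rewritten with explicit padding, or with ceil_mode=True for the max pools. The snippet below is a shape-only check for one pool swap, not part of the PR; border values can still differ between the two padding schemes.

# Illustrative shape check only (not in the PR).
import numpy as np
from mindspore import Tensor, mint, nn

x = Tensor(np.random.randn(1, 64, 112, 112).astype(np.float32))
pool_old = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")       # legacy implicit "same" padding
pool_new = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)   # mint: no padding, ceiling division
assert pool_old(x).shape == pool_new(x).shape == (1, 64, 56, 56)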