diff --git a/mindcv/models/convnext.py b/mindcv/models/convnext.py index f62eca650..47cd3ac21 100644 --- a/mindcv/models/convnext.py +++ b/mindcv/models/convnext.py @@ -10,7 +10,7 @@ import mindspore.common.initializer as init from mindspore import Parameter, Tensor from mindspore import dtype as mstype -from mindspore import nn, ops +from mindspore import mint, nn, ops from .helpers import build_model_with_cfg from .layers.drop_path import DropPath @@ -69,15 +69,14 @@ def __init__(self, dim: int): super().__init__() self.gamma = Parameter(Tensor(np.zeros([1, 1, 1, dim]), mstype.float32)) self.beta = Parameter(Tensor(np.zeros([1, 1, 1, dim]), mstype.float32)) - self.norm = ops.LpNorm(axis=[1, 2], p=2, keep_dims=True) def construct(self, x: Tensor) -> Tensor: - gx = self.norm(x) - nx = gx / (ops.mean(gx, axis=-1, keep_dims=True) + 1e-6) + gx = mint.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (mint.mean(gx, dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * nx) + self.beta + x -class ConvNextLayerNorm(nn.LayerNorm): +class ConvNextLayerNorm(mint.nn.LayerNorm): """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). """ @@ -88,17 +87,17 @@ def __init__( epsilon: float, norm_axis: int = -1, ) -> None: - super().__init__(normalized_shape=normalized_shape, epsilon=epsilon) + super().__init__(normalized_shape=normalized_shape, eps=epsilon) assert norm_axis in (-1, 1), "ConvNextLayerNorm's norm_axis must be 1 or -1." self.norm_axis = norm_axis def construct(self, input_x: Tensor) -> Tensor: if self.norm_axis == -1: - y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) + y = ops.layer_norm(input_x, self.normalized_shape, self.weight, self.bias, self.eps) else: - input_x = ops.transpose(input_x, (0, 2, 3, 1)) - y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) - y = ops.transpose(y, (0, 3, 1, 2)) + input_x = mint.permute(input_x, (0, 2, 3, 1)) + y = ops.layer_norm(input_x, self.normalized_shape, self.weight, self.bias, self.eps) + y = mint.permute(y, (0, 3, 1, 2)) return y @@ -124,14 +123,14 @@ def __init__( use_grn: bool = False, ) -> None: super().__init__() - self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, group=dim, has_bias=True) # depthwise conv + self.dwconv = mint.nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True) # depthwise conv self.norm = ConvNextLayerNorm((dim,), epsilon=1e-6) - self.pwconv1 = nn.Dense(dim, 4 * dim) # pointwise/1x1 convs, implemented with Dense layers - self.act = nn.GELU() + self.pwconv1 = mint.nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with Dense layers + self.act = mint.nn.GELU() self.use_grn = use_grn if use_grn: self.grn = GRN(4 * dim) - self.pwconv2 = nn.Dense(4 * dim, dim) + self.pwconv2 = mint.nn.Linear(4 * dim, dim) self.gamma_ = Parameter(Tensor(layer_scale_init_value * np.ones((dim)), dtype=mstype.float32), requires_grad=True) if layer_scale_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity() @@ -139,7 +138,7 @@ def __init__( def construct(self, x: Tensor) -> Tensor: downsample = x x = self.dwconv(x) - x = ops.transpose(x, (0, 2, 3, 1)) + x = mint.permute(x, (0, 2, 3, 1)) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) @@ -148,7 +147,7 @@ def construct(self, x: Tensor) -> Tensor: x = self.pwconv2(x) if self.gamma_ is not None: x = self.gamma_ * x - x = ops.transpose(x, (0, 3, 1, 2)) + x = mint.permute(x, (0, 3, 1, 2)) x = downsample + self.drop_path(x) return x @@ -184,14 +183,14 @@ def __init__( downsample_layers = [] # 
stem and 3 intermediate down_sampling conv layers stem = nn.SequentialCell( - nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4, has_bias=True), + mint.nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4, bias=True), ConvNextLayerNorm((dims[0],), epsilon=1e-6, norm_axis=1), ) downsample_layers.append(stem) for i in range(3): downsample_layer = nn.SequentialCell( ConvNextLayerNorm((dims[i],), epsilon=1e-6, norm_axis=1), - nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, has_bias=True), + mint.nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, bias=True), ) downsample_layers.append(downsample_layer) @@ -226,18 +225,18 @@ def __init__( stages[3] ]) self.norm = ConvNextLayerNorm((dims[-1],), epsilon=1e-6) # final norm layer - self.classifier = nn.Dense(dims[-1], num_classes) # classifier + self.classifier = mint.nn.Linear(dims[-1], num_classes) # classifier self.head_init_scale = head_init_scale self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, (nn.Dense, nn.Conv2d)): + if isinstance(cell, (mint.nn.Linear, mint.nn.Conv2d)): cell.weight.set_data( init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype) ) - if isinstance(cell, nn.Dense) and cell.bias is not None: + if isinstance(cell, mint.nn.Linear) and cell.bias is not None: cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype)) self.classifier.weight.set_data(self.classifier.weight * self.head_init_scale) self.classifier.bias.set_data(self.classifier.bias * self.head_init_scale) diff --git a/mindcv/models/densenet.py b/mindcv/models/densenet.py index f8b294266..b4a73b630 100644 --- a/mindcv/models/densenet.py +++ b/mindcv/models/densenet.py @@ -8,10 +8,9 @@ from typing import Tuple import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import load_pretrained -from .layers.compatibility import Dropout from .layers.pooling import GlobalAvgPooling from .registry import register_model @@ -53,16 +52,17 @@ def __init__( drop_rate: float, ) -> None: super().__init__() - self.norm1 = nn.BatchNorm2d(num_input_features) - self.relu1 = nn.ReLU() - self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1) + self.norm1 = mint.nn.BatchNorm2d(num_input_features) + self.relu1 = mint.nn.ReLU() + self.conv1 = mint.nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False) - self.norm2 = nn.BatchNorm2d(bn_size * growth_rate) - self.relu2 = nn.ReLU() - self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, pad_mode="pad", padding=1) + self.norm2 = mint.nn.BatchNorm2d(bn_size * growth_rate) + self.relu2 = mint.nn.ReLU() + self.conv2 = mint.nn.Conv2d( + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False) self.drop_rate = drop_rate - self.dropout = Dropout(p=self.drop_rate) + self.dropout = mint.nn.Dropout(p=self.drop_rate) def construct(self, features: Tensor) -> Tensor: bottleneck = self.conv1(self.relu1(self.norm1(features))) @@ -98,7 +98,7 @@ def construct(self, init_features: Tensor) -> Tensor: features = init_features for layer in self.cell_list: new_features = layer(features) - features = ops.concat((features, new_features), axis=1) + features = mint.concat((features, new_features), dim=1) return features @@ -112,10 +112,10 @@ def __init__( ) -> None: 
super().__init__() self.features = nn.SequentialCell(OrderedDict([ - ("norm", nn.BatchNorm2d(num_input_features)), - ("relu", nn.ReLU()), - ("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1)), - ("pool", nn.AvgPool2d(kernel_size=2, stride=2)) + ("norm", mint.nn.BatchNorm2d(num_input_features)), + ("relu", mint.nn.ReLU()), + ("conv", mint.nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)), + ("pool", mint.nn.AvgPool2d(kernel_size=2, stride=2)) ])) def construct(self, x: Tensor) -> Tensor: @@ -152,13 +152,11 @@ def __init__( layers = OrderedDict() # first Conv2d num_features = num_init_features - layers["conv0"] = nn.Conv2d(in_channels, num_features, kernel_size=7, stride=2, pad_mode="pad", padding=3) - layers["norm0"] = nn.BatchNorm2d(num_features) - layers["relu0"] = nn.ReLU() - layers["pool0"] = nn.SequentialCell([ - nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"), - nn.MaxPool2d(kernel_size=3, stride=2), - ]) + layers["conv0"] = mint.nn.Conv2d( + in_channels, num_features, kernel_size=7, stride=2, padding=3, bias=False) + layers["norm0"] = mint.nn.BatchNorm2d(num_features) + layers["relu0"] = mint.nn.ReLU() + layers["pool0"] = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # DenseBlock for i, num_layers in enumerate(block_config): @@ -177,19 +175,19 @@ def __init__( num_features = num_features // 2 # final bn+ReLU - layers["norm5"] = nn.BatchNorm2d(num_features) - layers["relu5"] = nn.ReLU() + layers["norm5"] = mint.nn.BatchNorm2d(num_features) + layers["relu5"] = mint.nn.ReLU() self.num_features = num_features self.features = nn.SequentialCell(layers) self.pool = GlobalAvgPooling() - self.classifier = nn.Dense(self.num_features, num_classes) + self.classifier = mint.nn.Linear(self.num_features, num_classes) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer(init.HeNormal(math.sqrt(5), mode="fan_out", nonlinearity="relu"), cell.weight.shape, cell.weight.dtype)) @@ -197,10 +195,10 @@ def _initialize_weights(self) -> None: cell.bias.set_data( init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"), cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"), cell.weight.shape, cell.weight.dtype)) diff --git a/mindcv/models/googlenet.py b/mindcv/models/googlenet.py index 252566365..245ba2436 100644 --- a/mindcv/models/googlenet.py +++ b/mindcv/models/googlenet.py @@ -7,10 +7,10 @@ from typing import Tuple, Union import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import load_pretrained -from .layers.compatibility import Dropout +from .layers.flatten import Flatten from .layers.pooling import GlobalAvgPooling 
from .registry import register_model @@ -45,12 +45,12 @@ def __init__( kernel_size: int = 1, stride: int = 1, padding: int = 0, - pad_mode: str = "same", + pad_mode: str = "zeros", ) -> None: super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, - padding=padding, pad_mode=pad_mode) - self.relu = nn.ReLU() + self.conv = mint.nn.Conv2d( + in_channels, out_channels, kernel_size, stride, padding=padding, padding_mode=pad_mode, bias=False) + self.relu = mint.nn.ReLU() def construct(self, x: Tensor) -> Tensor: x = self.conv(x) @@ -75,14 +75,14 @@ def __init__( self.b1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) self.b2 = nn.SequentialCell([ BasicConv2d(in_channels, ch3x3red, kernel_size=1), - BasicConv2d(ch3x3red, ch3x3, kernel_size=3), + BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1), ]) self.b3 = nn.SequentialCell([ BasicConv2d(in_channels, ch5x5red, kernel_size=1), - BasicConv2d(ch5x5red, ch5x5, kernel_size=5), + BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2), ]) self.b4 = nn.SequentialCell([ - nn.MaxPool2d(kernel_size=3, stride=1, pad_mode="same"), + mint.nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), BasicConv2d(in_channels, pool_proj, kernel_size=1), ]) @@ -91,7 +91,7 @@ def construct(self, x: Tensor) -> Tensor: branch2 = self.b2(x) branch3 = self.b3(x) branch4 = self.b4(x) - return ops.concat((branch1, branch2, branch3, branch4), axis=1) + return mint.concat((branch1, branch2, branch3, branch4), dim=1) class InceptionAux(nn.Cell): @@ -104,13 +104,13 @@ def __init__( drop_rate: float = 0.7, ) -> None: super().__init__() - self.avg_pool = nn.AvgPool2d(kernel_size=5, stride=3) + self.avg_pool = mint.nn.AvgPool2d(kernel_size=5, stride=3) self.conv = BasicConv2d(in_channels, 128, kernel_size=1) - self.fc1 = nn.Dense(2048, 1024) - self.fc2 = nn.Dense(1024, num_classes) - self.flatten = nn.Flatten() - self.relu = nn.ReLU() - self.dropout = Dropout(p=drop_rate) + self.fc1 = mint.nn.Linear(2048, 1024) + self.fc2 = mint.nn.Linear(1024, num_classes) + self.flatten = Flatten() + self.relu = mint.nn.ReLU() + self.dropout = mint.nn.Dropout(p=drop_rate) def construct(self, x: Tensor) -> Tensor: x = self.avg_pool(x) @@ -145,23 +145,23 @@ def __init__( ) -> None: super().__init__() self.aux_logits = aux_logits - self.conv1 = BasicConv2d(in_channels, 64, kernel_size=7, stride=2) - self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.conv1 = BasicConv2d(in_channels, 64, kernel_size=7, stride=2, padding=3) + self.maxpool1 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) self.conv2 = BasicConv2d(64, 64, kernel_size=1) - self.conv3 = BasicConv2d(64, 192, kernel_size=3) - self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) + self.maxpool2 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) - self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.maxpool3 = mint.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) - self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2, 
pad_mode="same") + self.maxpool4 = mint.nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) @@ -171,22 +171,24 @@ def __init__( self.aux2 = InceptionAux(528, num_classes, drop_rate=drop_rate_aux) self.pool = GlobalAvgPooling() - self.dropout = Dropout(p=drop_rate) - self.classifier = nn.Dense(1024, num_classes) + self.dropout = mint.nn.Dropout(p=drop_rate) + self.classifier = mint.nn.Linear(1024, num_classes) self._initialize_weights() def _initialize_weights(self): for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data(init.initializer(init.HeNormal(0, mode='fan_in', nonlinearity='leaky_relu'), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.BatchNorm2d) or isinstance(cell, nn.BatchNorm1d): - cell.gamma.set_data(init.initializer(init.Constant(1), cell.gamma.shape, cell.gamma.dtype)) - if cell.beta is not None: - cell.beta.set_data(init.initializer(init.Constant(0), cell.beta.shape, cell.gamma.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d) or isinstance(cell, mint.nn.BatchNorm1d): + cell.weight.set_data( + init.initializer(init.Constant(1), cell.weight.shape, cell.weight.dtype)) + if cell.bias is not None: + cell.bias.set_data( + init.initializer(init.Constant(0), cell.bias.shape, cell.weight.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.HeUniform(math.sqrt(5), mode='fan_in', nonlinearity='leaky_relu'), cell.weight.shape, cell.weight.dtype)) diff --git a/mindcv/models/helpers.py b/mindcv/models/helpers.py index c7e02a033..16659357d 100644 --- a/mindcv/models/helpers.py +++ b/mindcv/models/helpers.py @@ -147,6 +147,24 @@ def load_model_checkpoint(model: nn.Cell, checkpoint_path: str = "", ema: bool = if os.path.exists(checkpoint_path): checkpoint_param = load_checkpoint(checkpoint_path) + rename_map = { + "beta": "bias", + "gamma": "weight", + "moving_mean": "running_mean", + "moving_variance": "running_var", + } + + # Rename parameters in-place + keys_to_rename = list(checkpoint_param.keys()) + for key in keys_to_rename: + for old_suffix, new_suffix in rename_map.items(): + if key.endswith(old_suffix): + # Replace the old suffix with the new one + new_key = key[: -len(old_suffix)] + new_suffix + print(f"Renaming {key} -> {new_key}") + checkpoint_param[new_key] = checkpoint_param.pop(key) + break # Exit loop after renaming + if auto_mapping: checkpoint_param = auto_map(model, checkpoint_param) diff --git a/mindcv/models/inceptionv3.py b/mindcv/models/inceptionv3.py index 8e15c5dfa..6899af181 100644 --- a/mindcv/models/inceptionv3.py +++ b/mindcv/models/inceptionv3.py @@ -3,13 +3,14 @@ Refer to Rethinking the Inception Architecture for Computer Vision. 
""" -from typing import Tuple, Union +from typing import Any, Tuple, Union import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import load_pretrained from .layers.compatibility import Dropout +from .layers.flatten import Flatten from .layers.pooling import GlobalAvgPooling from .registry import register_model @@ -37,20 +38,11 @@ def _cfg(url="", **kwargs): class BasicConv2d(nn.Cell): """A block for conv bn and relu""" - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple] = 1, - stride: int = 1, - padding: int = 0, - pad_mode: str = "same", - ) -> None: + def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, - padding=padding, pad_mode=pad_mode) - self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.9997) - self.relu = nn.ReLU() + self.conv = mint.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = mint.nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.0003) + self.relu = mint.nn.ReLU() def construct(self, x: Tensor) -> Tensor: x = self.conv(x) @@ -69,16 +61,16 @@ def __init__( self.branch0 = BasicConv2d(in_channels, 64, kernel_size=1) self.branch1 = nn.SequentialCell([ BasicConv2d(in_channels, 48, kernel_size=1), - BasicConv2d(48, 64, kernel_size=5) + BasicConv2d(48, 64, kernel_size=5, padding=2) ]) self.branch2 = nn.SequentialCell([ BasicConv2d(in_channels, 64, kernel_size=1), - BasicConv2d(64, 96, kernel_size=3), - BasicConv2d(96, 96, kernel_size=3) + BasicConv2d(64, 96, kernel_size=3, padding=1), + BasicConv2d(96, 96, kernel_size=3, padding=1) ]) self.branch_pool = nn.SequentialCell([ - nn.AvgPool2d(kernel_size=3, pad_mode="same"), + mint.nn.AvgPool2d(kernel_size=3, stride=1, padding=1), BasicConv2d(in_channels, pool_features, kernel_size=1) ]) @@ -87,27 +79,27 @@ def construct(self, x: Tensor) -> Tensor: x1 = self.branch1(x) x2 = self.branch2(x) branch_pool = self.branch_pool(x) - out = ops.concat((x0, x1, x2, branch_pool), axis=1) + out = mint.concat((x0, x1, x2, branch_pool), dim=1) return out class InceptionB(nn.Cell): def __init__(self, in_channels: int) -> None: super().__init__() - self.branch0 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2, pad_mode='valid') + self.branch0 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) self.branch1 = nn.SequentialCell([ BasicConv2d(in_channels, 64, kernel_size=1), - BasicConv2d(64, 96, kernel_size=3), - BasicConv2d(96, 96, kernel_size=3, stride=2, pad_mode="valid") + BasicConv2d(64, 96, kernel_size=3, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=2) ]) - self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + self.branch_pool = mint.nn.MaxPool2d(kernel_size=3, stride=2) def construct(self, x: Tensor) -> Tensor: x0 = self.branch0(x) x1 = self.branch1(x) branch_pool = self.branch_pool(x) - out = ops.concat((x0, x1, branch_pool), axis=1) + out = mint.concat((x0, x1, branch_pool), dim=1) return out @@ -121,18 +113,18 @@ def __init__( self.branch0 = BasicConv2d(in_channels, 192, kernel_size=1) self.branch1 = nn.SequentialCell([ BasicConv2d(in_channels, channels_7x7, kernel_size=1), - BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)), - BasicConv2d(channels_7x7, 192, kernel_size=(7, 1)) + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7), padding=(0, 3)), + BasicConv2d(channels_7x7, 192, kernel_size=(7, 1), padding=(3, 0)) ]) 
self.branch2 = nn.SequentialCell([ BasicConv2d(in_channels, channels_7x7, kernel_size=1), - BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)), - BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7)), - BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1)), - BasicConv2d(channels_7x7, 192, kernel_size=(1, 7)) + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1), padding=(3, 0)), + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(1, 7), padding=(0, 3)), + BasicConv2d(channels_7x7, channels_7x7, kernel_size=(7, 1), padding=(3, 0)), + BasicConv2d(channels_7x7, 192, kernel_size=(1, 7), padding=(0, 3)) ]) self.branch_pool = nn.SequentialCell([ - nn.AvgPool2d(kernel_size=3, pad_mode="same"), + mint.nn.AvgPool2d(kernel_size=3, stride=1, padding=1), BasicConv2d(in_channels, 192, kernel_size=1) ]) @@ -141,7 +133,7 @@ def construct(self, x: Tensor) -> Tensor: x1 = self.branch1(x) x2 = self.branch2(x) branch_pool = self.branch_pool(x) - out = ops.concat((x0, x1, x2, branch_pool), axis=1) + out = mint.concat((x0, x1, x2, branch_pool), dim=1) return out @@ -150,21 +142,21 @@ def __init__(self, in_channels: int) -> None: super().__init__() self.branch0 = nn.SequentialCell([ BasicConv2d(in_channels, 192, kernel_size=1), - BasicConv2d(192, 320, kernel_size=3, stride=2, pad_mode="valid") + BasicConv2d(192, 320, kernel_size=3, stride=2) ]) self.branch1 = nn.SequentialCell([ BasicConv2d(in_channels, 192, kernel_size=1), - BasicConv2d(192, 192, kernel_size=(1, 7)), # check - BasicConv2d(192, 192, kernel_size=(7, 1)), - BasicConv2d(192, 192, kernel_size=3, stride=2, pad_mode="valid") + BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)), # check + BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)), + BasicConv2d(192, 192, kernel_size=3, stride=2) ]) - self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2) + self.branch_pool = mint.nn.MaxPool2d(kernel_size=3, stride=2) def construct(self, x: Tensor) -> Tensor: x0 = self.branch0(x) x1 = self.branch1(x) branch_pool = self.branch_pool(x) - out = ops.concat((x0, x1, branch_pool), axis=1) + out = mint.concat((x0, x1, branch_pool), dim=1) return out @@ -173,27 +165,27 @@ def __init__(self, in_channels: int) -> None: super().__init__() self.branch0 = BasicConv2d(in_channels, 320, kernel_size=1) self.branch1 = BasicConv2d(in_channels, 384, kernel_size=1) - self.branch1a = BasicConv2d(384, 384, kernel_size=(1, 3)) - self.branch1b = BasicConv2d(384, 384, kernel_size=(3, 1)) + self.branch1a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch1b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) self.branch2 = nn.SequentialCell([ BasicConv2d(in_channels, 448, kernel_size=1), - BasicConv2d(448, 384, kernel_size=3) + BasicConv2d(448, 384, kernel_size=3, padding=1) ]) - self.branch2a = BasicConv2d(384, 384, kernel_size=(1, 3)) - self.branch2b = BasicConv2d(384, 384, kernel_size=(3, 1)) + self.branch2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) self.branch_pool = nn.SequentialCell([ - nn.AvgPool2d(kernel_size=3, pad_mode="same"), + mint.nn.AvgPool2d(kernel_size=3, stride=1, padding=1), BasicConv2d(in_channels, 192, kernel_size=1) ]) def construct(self, x: Tensor) -> Tensor: x0 = self.branch0(x) x1 = self.branch1(x) - x1 = ops.concat((self.branch1a(x1), self.branch1b(x1)), axis=1) + x1 = mint.concat((self.branch1a(x1), self.branch1b(x1)), dim=1) x2 = self.branch2(x) - x2 = ops.concat((self.branch2a(x2), 
self.branch2b(x2)), axis=1) + x2 = mint.concat((self.branch2a(x2), self.branch2b(x2)), dim=1) branch_pool = self.branch_pool(x) - out = ops.concat((x0, x1, x2, branch_pool), axis=1) + out = mint.concat((x0, x1, x2, branch_pool), dim=1) return out @@ -206,11 +198,11 @@ def __init__( num_classes: int, ) -> None: super().__init__() - self.avg_pool = nn.AvgPool2d(5, stride=3, pad_mode="valid") + self.avg_pool = mint.nn.AvgPool2d(5, stride=3) self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) - self.conv1 = BasicConv2d(128, 768, kernel_size=5, pad_mode="valid") - self.flatten = nn.Flatten() - self.fc = nn.Dense(in_channels, num_classes) + self.conv1 = BasicConv2d(128, 768, kernel_size=5) + self.flatten = Flatten() + self.fc = mint.nn.Linear(in_channels, num_classes) def construct(self, x: Tensor) -> Tensor: x = self.avg_pool(x) @@ -245,13 +237,13 @@ def __init__( ) -> None: super().__init__() self.aux_logits = aux_logits - self.conv1a = BasicConv2d(in_channels, 32, kernel_size=3, stride=2, pad_mode="valid") - self.conv2a = BasicConv2d(32, 32, kernel_size=3, stride=1, pad_mode="valid") - self.conv2b = BasicConv2d(32, 64, kernel_size=3, stride=1) - self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.conv1a = BasicConv2d(in_channels, 32, kernel_size=3, stride=2) + self.conv2a = BasicConv2d(32, 32, kernel_size=3) + self.conv2b = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.maxpool1 = mint.nn.MaxPool2d(kernel_size=3, stride=2) self.conv3b = BasicConv2d(64, 80, kernel_size=1) - self.conv4a = BasicConv2d(80, 192, kernel_size=3, pad_mode="valid") - self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.conv4a = BasicConv2d(80, 192, kernel_size=3) + self.maxpool2 = mint.nn.MaxPool2d(kernel_size=3, stride=2) self.inception5b = InceptionA(192, pool_features=32) self.inception5c = InceptionA(256, pool_features=64) self.inception5d = InceptionA(288, pool_features=64) @@ -269,13 +261,13 @@ def __init__( self.pool = GlobalAvgPooling() self.dropout = Dropout(p=drop_rate) self.num_features = 2048 - self.classifier = nn.Dense(self.num_features, num_classes) + self.classifier = mint.nn.Linear(self.num_features, num_classes) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer(init.XavierUniform(), cell.weight.shape, cell.weight.dtype)) diff --git a/mindcv/models/layers/__init__.py b/mindcv/models/layers/__init__.py index c3e4de210..e29c44a69 100644 --- a/mindcv/models/layers/__init__.py +++ b/mindcv/models/layers/__init__.py @@ -1,25 +1,37 @@ """layers init""" from . 
import ( activation, + compatibility, conv_norm_act, drop_path, + extend_bmm, + flatten, format, identity, + l2normalize, + pad, patch_dropout, pooling, pos_embed, selective_kernel, + sigmoid, squeeze_excite, ) from .activation import * +from .compatibility import * from .conv_norm_act import * from .drop_path import * +from .extend_bmm import * +from .flatten import * from .format import * from .identity import * +from .l2normalize import * +from .pad import * from .patch_dropout import * from .pooling import * from .pos_embed import * from .selective_kernel import * +from .sigmoid import * from .squeeze_excite import * __all__ = [] diff --git a/mindcv/models/layers/compatibility.py b/mindcv/models/layers/compatibility.py index 8aecbae77..a416a29b1 100644 --- a/mindcv/models/layers/compatibility.py +++ b/mindcv/models/layers/compatibility.py @@ -1,7 +1,7 @@ import inspect import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn, ops __all__ = [ "Dropout", @@ -83,7 +83,7 @@ def __init__(self, split_size_or_sections, output_num, axis=0): ) def construct(self, x): - return ops.split(x, **self.kwargs) + return mint.split(x, **self.kwargs) class ResizeBilinear(nn.Cell): diff --git a/mindcv/models/layers/conv_norm_act.py b/mindcv/models/layers/conv_norm_act.py index e141affc3..e494ff1fa 100644 --- a/mindcv/models/layers/conv_norm_act.py +++ b/mindcv/models/layers/conv_norm_act.py @@ -1,7 +1,7 @@ """ Conv2d + BN + Act""" from typing import Optional -from mindspore import nn +from mindspore import mint, nn class Conv2dNormActivation(nn.Cell): @@ -13,37 +13,32 @@ def __init__( out_channels: int, kernel_size: int = 3, stride: int = 1, - pad_mode: str = "pad", padding: Optional[int] = None, dilation: int = 1, groups: int = 1, - norm: Optional[nn.Cell] = nn.BatchNorm2d, - activation: Optional[nn.Cell] = nn.ReLU, + norm: Optional[nn.Cell] = mint.nn.BatchNorm2d, + activation: Optional[nn.Cell] = mint.nn.ReLU, has_bias: Optional[bool] = None, **kwargs ) -> None: super().__init__() - if pad_mode == "pad": - if padding is None: - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - else: - padding = 0 + if padding is None: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 if has_bias is None: has_bias = norm is None layers = [ - nn.Conv2d( + mint.nn.Conv2d( in_channels, out_channels, kernel_size, stride, - pad_mode=pad_mode, padding=padding, dilation=dilation, - group=groups, - has_bias=has_bias, + groups=groups, + bias=has_bias, **kwargs ) ] diff --git a/mindcv/models/layers/drop_path.py b/mindcv/models/layers/drop_path.py index ea0374734..ef3b61cd6 100644 --- a/mindcv/models/layers/drop_path.py +++ b/mindcv/models/layers/drop_path.py @@ -3,11 +3,9 @@ Papers: Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) """ -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from mindspore.numpy import ones -from .compatibility import Dropout - class DropPath(nn.Cell): """DropPath (Stochastic Depth) regularization layers""" @@ -20,7 +18,7 @@ def __init__( super().__init__() self.keep_prob = 1.0 - drop_prob self.scale_by_keep = scale_by_keep - self.dropout = Dropout(p=drop_prob) + self.dropout = mint.nn.Dropout(p=drop_prob) def construct(self, x: Tensor) -> Tensor: if self.keep_prob == 1.0 or not self.training: @@ -28,5 +26,5 @@ def construct(self, x: Tensor) -> Tensor: shape = (x.shape[0],) + (1,) * (x.ndim - 1) random_tensor = self.dropout(ones(shape)) if not self.scale_by_keep: - random_tensor = 
ops.mul(random_tensor, self.keep_prob) + random_tensor = mint.mul(random_tensor, self.keep_prob) return x * random_tensor diff --git a/mindcv/models/layers/extend_bmm.py b/mindcv/models/layers/extend_bmm.py new file mode 100644 index 000000000..052ca4f0c --- /dev/null +++ b/mindcv/models/layers/extend_bmm.py @@ -0,0 +1,27 @@ +""" Extended Batch MatMul Module""" +from mindspore import mint, nn + + +class ExtendBatchMatMul(nn.Cell): + """ + Extend Batch MatMul Module to deal with batch matrix multiplication between tensors with higher dimensions + """ + + def __init__(self, transpose_a=False, transpose_b=False) -> None: + super().__init__() + self.transpose_a = transpose_a + self.transpose_b = transpose_b + + def construct(self, a, b): + if self.transpose_a: + a = mint.transpose(a, -1, -2) + if self.transpose_b: + b = mint.transpose(b, -1, -2) + size = len(a.shape) + if size <= 3: + return mint.bmm(a, b) + output_shape = (*a.shape[:-2], a.shape[-2], b.shape[-1]) + a = mint.reshape(a, (-1, *a.shape[-2:])) + b = mint.reshape(b, (-1, *b.shape[-2:])) + + return mint.reshape(mint.bmm(a, b), output_shape) diff --git a/mindcv/models/layers/flatten.py b/mindcv/models/layers/flatten.py new file mode 100644 index 000000000..575a84ba9 --- /dev/null +++ b/mindcv/models/layers/flatten.py @@ -0,0 +1,15 @@ +""" Flatten Module""" +from mindspore import Tensor, nn + + +class Flatten(nn.Cell): + """ + Flattens a contiguous range of dims into a tensor. + """ + def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: + super().__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def construct(self, input: Tensor) -> Tensor: + return input.flatten(start_dim=self.start_dim, end_dim=self.end_dim) diff --git a/mindcv/models/layers/format.py b/mindcv/models/layers/format.py index 058a74517..3cfb8e998 100644 --- a/mindcv/models/layers/format.py +++ b/mindcv/models/layers/format.py @@ -2,6 +2,7 @@ from typing import Union import mindspore +from mindspore import mint class Format(str, Enum): @@ -16,19 +17,19 @@ class Format(str, Enum): def nchw_to(x: mindspore.Tensor, fmt: Format): if fmt == Format.NHWC: - x = x.permute(0, 2, 3, 1) + x = mint.permute(x, (0, 2, 3, 1)) elif fmt == Format.NLC: - x = x.flatten(start_dim=2).transpose((0, 2, 1)) + x = mint.permute(mint.flatten(x, start_dim=2), (0, 2, 1)) elif fmt == Format.NCL: - x = x.flatten(start_dim=2) + x = mint.flatten(x, start_dim=2) return x def nhwc_to(x: mindspore.Tensor, fmt: Format): if fmt == Format.NCHW: - x = x.permute(0, 3, 1, 2) + x = mint.permute(x, (0, 3, 1, 2)) elif fmt == Format.NLC: - x = x.flatten(start_dim=1, end_dim=2) + x = mint.flatten(x, start_dim=1, end_dim=2) elif fmt == Format.NCL: - x = x.flatten(start_dim=1, end_dim=2).transpose((0, 2, 1)) + x = mint.permute(mint.flatten(x, start_dim=1, end_dim=2), (0, 2, 1)) return x diff --git a/mindcv/models/layers/l2normalize.py b/mindcv/models/layers/l2normalize.py new file mode 100644 index 000000000..1eeb6ed26 --- /dev/null +++ b/mindcv/models/layers/l2normalize.py @@ -0,0 +1,25 @@ +""" L2Normalize Module""" +from mindspore import mint + + +class L2Normalize: + def __init__(self, axis=-1, epsilon=1e-12): + """ + Initializes the L2Normalize class. + + :param axis: Specifies the axis along which normalization is applied, default is -1 (last axis). + :param epsilon: A small value added to the norm to avoid division by zero, default is 1e-12. + """ + self.axis = axis + self.epsilon = epsilon + + def __call__(self, input_tensor): + """ + Applies L2 normalization to the input tensor. 
+ + :param input_tensor: The input tensor to be normalized. + :return: The L2 normalized tensor. + """ + norm = mint.sqrt(mint.sum(input_tensor ** 2, dim=self.axis, keepdim=True) + self.epsilon) + output_tensor = input_tensor / norm + return output_tensor diff --git a/mindcv/models/layers/mlp.py b/mindcv/models/layers/mlp.py index 7da27a4a2..485e9be06 100644 --- a/mindcv/models/layers/mlp.py +++ b/mindcv/models/layers/mlp.py @@ -2,9 +2,7 @@ """ from typing import Optional -from mindspore import Tensor, nn - -from .compatibility import Dropout +from mindspore import Tensor, mint, nn class Mlp(nn.Cell): @@ -13,16 +11,16 @@ def __init__( in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, - act_layer: Optional[nn.Cell] = nn.GELU, + act_layer: Optional[nn.Cell] = mint.nn.GELU, drop: float = 0.0, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Dense(in_channels=in_features, out_channels=hidden_features, has_bias=True) + self.fc1 = mint.nn.Linear(in_features, hidden_features, bias=True) self.act = act_layer() - self.fc2 = nn.Dense(in_channels=hidden_features, out_channels=out_features, has_bias=True) - self.drop = Dropout(p=drop) + self.fc2 = mint.nn.Linear(hidden_features, out_features, bias=True) + self.drop = mint.nn.Dropout(p=drop) def construct(self, x: Tensor) -> Tensor: x = self.fc1(x) diff --git a/mindcv/models/layers/pad.py b/mindcv/models/layers/pad.py new file mode 100644 index 000000000..a91c2e1a1 --- /dev/null +++ b/mindcv/models/layers/pad.py @@ -0,0 +1,18 @@ +""" Pad Module""" +import mindspore.mint.nn.functional as F +from mindspore import nn + + +class Pad(nn.Cell): + """ + Pad Module to pad the input tensor according to the paddings and mode. + """ + def __init__(self, pad, mode='constant', value=0.0) -> None: + super().__init__() + self.pad = pad + self.mode = mode + self.value = value + + def construct(self, x): + x = F.pad(x, self.pad, self.mode, self.value) + return x diff --git a/mindcv/models/layers/patch_dropout.py b/mindcv/models/layers/patch_dropout.py index ad854dbfc..9f1ec37d7 100644 --- a/mindcv/models/layers/patch_dropout.py +++ b/mindcv/models/layers/patch_dropout.py @@ -1,7 +1,7 @@ import numpy as np import mindspore as ms -from mindspore import nn, ops +from mindspore import mint, nn class PatchDropout(nn.Cell): @@ -21,7 +21,6 @@ def __init__( self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens) self.ordered = ordered self.return_indices = return_indices - self.sort = ops.Sort() def forward(self, x): if not self.training or self.prob == 0.: @@ -37,17 +36,17 @@ def forward(self, x): B = x.shape[0] L = x.shape[1] num_keep = max(1, int(L * (1. 
- self.prob))) - _, indices = self.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32)) + _, indices = mint.sort(ms.Tensor(np.random.rand(B, L)).astype(ms.float32)) keep_indices = indices[:, :num_keep] if self.ordered: # NOTE does not need to maintain patch order in typical transformer use, # but possibly useful for debug / visualization - keep_indices, _ = self.sort(keep_indices) - keep_indices = ops.broadcast_to(ops.expand_dims(keep_indices, axis=-1), (-1, -1, x.shape[2])) - x = ops.gather_elements(x, dim=1, index=keep_indices) + keep_indices, _ = mint.sort(keep_indices) + keep_indices = mint.broadcast_to(mint.unsqueeze(keep_indices, dim=-1), (-1, -1, x.shape[2])) + x = mint.gather(x, dim=1, index=keep_indices) if prefix_tokens is not None: - x = ops.concat((prefix_tokens, x), axis=1) + x = mint.concat((prefix_tokens, x), dim=1) if self.return_indices: return x, keep_indices diff --git a/mindcv/models/layers/patch_embed.py b/mindcv/models/layers/patch_embed.py index 661e07890..0cc1436c1 100644 --- a/mindcv/models/layers/patch_embed.py +++ b/mindcv/models/layers/patch_embed.py @@ -2,7 +2,8 @@ A convolution based approach to patchifying a 2D image w/ embedding projection.""" from typing import Optional -from mindspore import Tensor, nn, ops +import mindspore.mint.nn.functional as F +from mindspore import Tensor, mint, nn from .format import Format, nchw_to from .helpers import to_2tuple @@ -55,8 +56,9 @@ def __init__( self.dynamic_img_pad = dynamic_img_pad self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, - pad_mode='pad', has_bias=bias, weight_init="TruncatedNormal") + self.proj = mint.nn.Conv2d( + in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias + ) if norm_layer is not None: if isinstance(embed_dim, int): @@ -81,13 +83,13 @@ def construct(self, x: Tensor) -> Tensor: if self.dynamic_img_pad: pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] - x = ops.pad(x, (0, pad_w, 0, pad_h)) + x = F.pad(x, (0, pad_w, 0, pad_h)) # FIXME look at relaxing size constraints x = self.proj(x) if self.flatten: - x = ops.Reshape()(x, (B, self.embed_dim, -1)) # B Ph*Pw C - x = ops.Transpose()(x, (0, 2, 1)) + x = mint.reshape(x, (B, self.embed_dim, -1)) # B Ph*Pw C + x = mint.permute(x, (0, 2, 1)) elif self.output_fmt != "NCHW": x = nchw_to(x, self.output_fmt) if self.norm is not None: diff --git a/mindcv/models/layers/pooling.py b/mindcv/models/layers/pooling.py index 4ad486bc8..68ddfcdfc 100644 --- a/mindcv/models/layers/pooling.py +++ b/mindcv/models/layers/pooling.py @@ -1,5 +1,5 @@ """ GlobalAvgPooling Module""" -from mindspore import nn, ops +from mindspore import mint, nn class GlobalAvgPooling(nn.Cell): @@ -12,5 +12,5 @@ def __init__(self, keep_dims: bool = False) -> None: self.keep_dims = keep_dims def construct(self, x): - x = ops.mean(x, axis=(2, 3), keep_dims=self.keep_dims) + x = mint.mean(x, dim=(2, 3), keepdim=self.keep_dims) return x diff --git a/mindcv/models/layers/pos_embed.py b/mindcv/models/layers/pos_embed.py index ba4548580..0ceaefc5f 100644 --- a/mindcv/models/layers/pos_embed.py +++ b/mindcv/models/layers/pos_embed.py @@ -5,7 +5,8 @@ import numpy as np import mindspore as ms -from mindspore import Parameter, Tensor, nn, ops +import mindspore.mint.nn.functional as F +from mindspore import Parameter, Tensor, mint, nn from .compatibility import 
Interpolate @@ -36,15 +37,15 @@ def resample_abs_pos_embed( # do the interpolation embed_dim = posemb.shape[-1] orig_dtype = posemb.dtype - posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) + posemb = mint.permute(mint.reshape(posemb, (1, old_size[0], old_size[1], -1)), (0, 3, 1, 2)) interpolate = Interpolate(mode=interpolation, align_corners=True) posemb = interpolate(posemb, size=new_size) - posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) + posemb = mint.reshape(mint.permute(posemb, (0, 2, 3, 1)), (1, -1, embed_dim)) posemb = posemb.astype(orig_dtype) # add back extra (class, etc) prefix tokens if posemb_prefix is not None: - posemb = ops.concat((posemb_prefix, posemb), axis=1) + posemb = mint.concat((posemb_prefix, posemb), dim=1) return posemb @@ -82,12 +83,12 @@ def __init__( relative_position_index[0, 0] = num_relative_distance - 1 relative_position_index = Tensor(relative_position_index.reshape(-1)) - self.one_hot = nn.OneHot(axis=-1, depth=num_relative_distance, dtype=ms.float16) - self.relative_position_index = Parameter(self.one_hot(relative_position_index), requires_grad=False) + self.relative_position_index = Parameter( + F.one_hot(relative_position_index, num_relative_distance), requires_grad=False) def construct(self): - out = ops.matmul(self.relative_position_index, self.relative_position_bias_table) - out = ops.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) - out = ops.transpose(out, (2, 0, 1)) - out = ops.expand_dims(out, 0) + out = mint.matmul(self.relative_position_index, self.relative_position_bias_table) + out = mint.reshape(out, (self.num_tokens + 1, self.num_tokens + 1, -1)) + out = mint.permute(out, (2, 0, 1)) + out = mint.unsqueeze(out, 0) return out diff --git a/mindcv/models/layers/selective_kernel.py b/mindcv/models/layers/selective_kernel.py index ddf6ebcad..380a1f836 100644 --- a/mindcv/models/layers/selective_kernel.py +++ b/mindcv/models/layers/selective_kernel.py @@ -3,10 +3,9 @@ """ from typing import List, Optional, Union -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from ..helpers import make_divisible -from .compatibility import Split from .conv_norm_act import Conv2dNormActivation from .pooling import GlobalAvgPooling @@ -31,17 +30,17 @@ def __init__( channels: int, num_paths: int = 2, attn_channels: int = 32, - activation: Optional[nn.Cell] = nn.ReLU, - norm: Optional[nn.Cell] = nn.BatchNorm2d, + activation: Optional[nn.Cell] = mint.nn.ReLU, + norm: Optional[nn.Cell] = mint.nn.BatchNorm2d, ): super().__init__() self.num_paths = num_paths self.mean = GlobalAvgPooling(keep_dims=True) - self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, has_bias=False) + self.fc_reduce = mint.nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) self.bn = norm(attn_channels) self.act = activation() - self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1) - self.softmax = nn.Softmax(axis=1) + self.fc_select = mint.nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + self.softmax = mint.nn.Softmax(dim=1) def construct(self, x: Tensor) -> Tensor: x = self.mean((x.sum(1))) @@ -92,8 +91,8 @@ def __init__( rd_divisor: int = 8, keep_3x3: bool = True, split_input: bool = True, - activation: Optional[nn.Cell] = nn.ReLU, - norm: Optional[nn.Cell] = nn.BatchNorm2d, + activation: Optional[nn.Cell] = mint.nn.ReLU, + norm: Optional[nn.Cell] = mint.nn.BatchNorm2d, ): super().__init__() out_channels = out_channels or 
in_channels @@ -114,8 +113,6 @@ def __init__( assert in_channels % self.num_paths == 0 in_channels = in_channels // self.num_paths groups = min(out_channels, groups) - self.split = Split(split_size_or_sections=self.in_channels // self.num_paths, output_num=self.num_paths, axis=1) - self.paths = nn.CellList([ Conv2dNormActivation(in_channels, out_channels, kernel_size=k, stride=stride, groups=groups, dilation=d, activation=activation, norm=norm) @@ -128,14 +125,14 @@ def __init__( def construct(self, x: Tensor) -> Tensor: x_paths = [] if self.split_input: - x_split = self.split(x) + x_split = mint.split(x, split_size_or_sections=self.in_channels // self.num_paths, dim=1) for i, op in enumerate(self.paths): x_paths.append(op(x_split[i])) else: for op in self.paths: x_paths.append(op(x)) - x = ops.stack(x_paths, axis=1) + x = mint.stack(x_paths, dim=1) x_attn = self.attn(x) x = x * x_attn x = x.sum(1) diff --git a/mindcv/models/layers/sigmoid.py b/mindcv/models/layers/sigmoid.py new file mode 100644 index 000000000..7b260d576 --- /dev/null +++ b/mindcv/models/layers/sigmoid.py @@ -0,0 +1,8 @@ +""" Sigmoid Module""" +from mindspore import mint, nn + + +class Sigmoid(nn.Cell): + def construct(self, x): + x = mint.sigmoid(x) + return x diff --git a/mindcv/models/layers/squeeze_excite.py b/mindcv/models/layers/squeeze_excite.py index b43445ab8..8c8ca0401 100644 --- a/mindcv/models/layers/squeeze_excite.py +++ b/mindcv/models/layers/squeeze_excite.py @@ -5,10 +5,11 @@ """ from typing import Optional -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from ..helpers import make_divisible from .pooling import GlobalAvgPooling +from .sigmoid import Sigmoid class SqueezeExcite(nn.Cell): @@ -27,8 +28,8 @@ def __init__( rd_channels: Optional[int] = None, rd_divisor: int = 8, norm: Optional[nn.Cell] = None, - act_layer: nn.Cell = nn.ReLU, - gate_layer: nn.Cell = nn.Sigmoid, + act_layer: nn.Cell = mint.nn.ReLU, + gate_layer: nn.Cell = Sigmoid, ) -> None: super().__init__() self.norm = norm @@ -37,19 +38,19 @@ def __init__( if not rd_channels: rd_channels = make_divisible(in_channels * rd_ratio, rd_divisor) - self.conv_reduce = nn.Conv2d( + self.conv_reduce = mint.nn.Conv2d( in_channels=in_channels, out_channels=rd_channels, kernel_size=1, - has_bias=True, + bias=True, ) if self.norm: - self.bn = nn.BatchNorm2d(rd_channels) - self.conv_expand = nn.Conv2d( + self.bn = mint.nn.BatchNorm2d(rd_channels) + self.conv_expand = mint.nn.Conv2d( in_channels=rd_channels, out_channels=in_channels, kernel_size=1, - has_bias=True, + bias=True, ) self.pool = GlobalAvgPooling(keep_dims=True) @@ -77,8 +78,8 @@ def __init__( rd_channels: Optional[int] = None, rd_divisor: int = 8, norm: Optional[nn.Cell] = None, - act_layer: nn.Cell = nn.ReLU, - gate_layer: nn.Cell = nn.Sigmoid, + act_layer: nn.Cell = mint.nn.ReLU, + gate_layer: nn.Cell = Sigmoid, ) -> None: super().__init__() self.norm = norm @@ -87,17 +88,17 @@ def __init__( if not rd_channels: rd_channels = make_divisible(in_channels * rd_ratio, rd_divisor) - self.conv_reduce = nn.Dense( - in_channels=in_channels, - out_channels=rd_channels, - has_bias=True, + self.conv_reduce = mint.nn.Linear( + in_features=in_channels, + out_features=rd_channels, + bias=True, ) if self.norm: - self.bn = nn.BatchNorm2d(rd_channels) - self.conv_expand = nn.Dense( - in_channels=rd_channels, - out_channels=in_channels, - has_bias=True, + self.bn = mint.nn.BatchNorm2d(rd_channels) + self.conv_expand = mint.nn.Linear( + in_features=rd_channels, + 
out_features=in_channels, + bias=True, ) self.pool = GlobalAvgPooling(keep_dims=False) @@ -109,7 +110,7 @@ def construct(self, x: Tensor) -> Tensor: x_se = self.act(x_se) x_se = self.conv_expand(x_se) x_se = self.gate(x_se) - x_se = ops.expand_dims(x_se, -1) - x_se = ops.expand_dims(x_se, -1) + x_se = mint.unsqueeze(x_se, -1) + x_se = mint.unsqueeze(x_se, -1) x = x * x_se return x diff --git a/mindcv/models/mobilenetv3.py b/mindcv/models/mobilenetv3.py index 6d911d4e8..2fb8f5495 100644 --- a/mindcv/models/mobilenetv3.py +++ b/mindcv/models/mobilenetv3.py @@ -6,10 +6,9 @@ import math import mindspore.common.initializer as init -from mindspore import Tensor, nn +from mindspore import Tensor, mint, nn from .helpers import build_model_with_cfg, make_divisible -from .layers.compatibility import Dropout from .layers.pooling import GlobalAvgPooling from .layers.squeeze_excite import SqueezeExcite from .registry import register_model @@ -62,32 +61,32 @@ def __init__( self.use_se = use_se self.use_res_connect = stride == 1 and in_channels == out_channels assert activation in ["relu", "hswish"] - self.activation = nn.HSwish if activation == "hswish" else nn.ReLU + self.activation = mint.nn.Hardswish if activation == "hswish" else mint.nn.ReLU layers = [] # Expand. if in_channels != mid_channels: layers.extend([ - nn.Conv2d(in_channels, mid_channels, 1, 1, pad_mode="pad", padding=0, has_bias=False), - nn.BatchNorm2d(mid_channels), + mint.nn.Conv2d(in_channels, mid_channels, 1, 1, padding=0, bias=False), + mint.nn.BatchNorm2d(mid_channels), self.activation(), ]) # DepthWise. layers.extend([ - nn.Conv2d(mid_channels, mid_channels, kernel_size, stride, - pad_mode="same", group=mid_channels, has_bias=False), - nn.BatchNorm2d(mid_channels), + mint.nn.Conv2d(mid_channels, mid_channels, kernel_size, stride, + padding=kernel_size // 2, groups=mid_channels, bias=False), + mint.nn.BatchNorm2d(mid_channels), self.activation(), ]) # SqueezeExcitation. if use_se: layers.append( - SqueezeExcite(mid_channels, 1.0 / 4, act_layer=nn.ReLU, gate_layer=nn.HSigmoid) + SqueezeExcite(mid_channels, 1.0 / 4, act_layer=mint.nn.ReLU, gate_layer=mint.nn.Hardsigmoid) ) # Project. layers.extend([ - nn.Conv2d(mid_channels, out_channels, 1, 1, pad_mode="pad", padding=0, has_bias=False), - nn.BatchNorm2d(out_channels), + mint.nn.Conv2d(mid_channels, out_channels, 1, 1, padding=0, bias=False), + mint.nn.BatchNorm2d(out_channels), ]) self.layers = nn.SequentialCell(layers) @@ -165,9 +164,9 @@ def __init__( # Building stem conv layer. features = [ - nn.Conv2d(in_channels, input_channels, 3, 2, pad_mode="pad", padding=1, has_bias=False), - nn.BatchNorm2d(input_channels), - nn.HSwish(), + mint.nn.Conv2d(in_channels, input_channels, 3, 2, padding=1, bias=False), + mint.nn.BatchNorm2d(input_channels), + mint.nn.Hardswish(), ] total_reduction = 2 @@ -188,9 +187,9 @@ def __init__( # Building last point-wise conv layers. 
output_channels = input_channels * 6 features.extend([ - nn.Conv2d(input_channels, output_channels, 1, 1, pad_mode="pad", padding=0, has_bias=False), - nn.BatchNorm2d(output_channels), - nn.HSwish(), + mint.nn.Conv2d(input_channels, output_channels, 1, 1, padding=0, bias=False), + mint.nn.BatchNorm2d(output_channels), + mint.nn.Hardswish(), ]) self.feature_info.append(dict(chs=output_channels, reduction=total_reduction, @@ -201,27 +200,27 @@ def __init__( self.pool = GlobalAvgPooling() self.classifier = nn.SequentialCell([ - nn.Dense(output_channels, last_channels), - nn.HSwish(), - Dropout(p=0.2), - nn.Dense(last_channels, num_classes), + mint.nn.Linear(output_channels, last_channels), + mint.nn.Hardswish(), + mint.nn.Dropout(p=0.2), + mint.nn.Linear(last_channels, num_classes), ]) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): n = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels cell.weight.set_data( init.initializer(init.Normal(sigma=math.sqrt(2. / n), mean=0.0), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): + cell.weight.set_data( + init.initializer(init.Normal(sigma=0.01, mean=0.0), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: diff --git a/mindcv/models/pvtv2.py b/mindcv/models/pvtv2.py index c4091b7e8..ebd524bd8 100644 --- a/mindcv/models/pvtv2.py +++ b/mindcv/models/pvtv2.py @@ -7,14 +7,14 @@ import numpy as np +import mindspore.mint as mint import mindspore.nn as nn -import mindspore.ops as ops from mindspore import Tensor from mindspore.common import initializer as weight_init from .helpers import load_pretrained from .layers import DropPath, Identity -from .layers.compatibility import Dropout +from .layers.extend_bmm import ExtendBatchMatMul from .registry import register_model __all__ = [ @@ -53,13 +53,13 @@ class DWConv(nn.Cell): def __init__(self, dim=768): super(DWConv, self).__init__() - self.dwconv = nn.Conv2d(dim, dim, 3, 1, has_bias=True, group=dim) + self.dwconv = mint.nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) def construct(self, x, H, W): B, N, C = x.shape - x = ops.transpose(x, (0, 2, 1)).view((B, C, H, W)) + x = mint.permute(x, (0, 2, 1)).view((B, C, H, W)) x = self.dwconv(x) - x = ops.transpose(x.view((B, C, H * W)), (0, 2, 1)) + x = mint.permute(x.view((B, C, H * W)), (0, 2, 1)) return x @@ -67,18 +67,24 @@ def construct(self, x, H, W): class Mlp(nn.Cell): """MLP with depthwise separable convolution""" - def __init__(self, 
in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, linear=False): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=mint.nn.GELU, + drop=0.0, + linear=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Dense(in_features, hidden_features) + self.fc1 = mint.nn.Linear(in_features, hidden_features) self.dwconv = DWConv(hidden_features) self.act = act_layer() - self.fc2 = nn.Dense(hidden_features, out_features) - self.drop = Dropout(p=drop) + self.fc2 = mint.nn.Linear(hidden_features, out_features) + self.drop = mint.nn.Dropout(p=drop) self.linear = linear if self.linear: - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() def construct(self, x, H, W): x = self.fc1(x) @@ -105,56 +111,56 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0. head_dim = dim // num_heads self.scale = qk_scale or head_dim**-0.5 - self.q = nn.Dense(dim, dim, has_bias=qkv_bias) - self.kv = nn.Dense(dim, dim * 2, has_bias=qkv_bias) - self.attn_drop = Dropout(p=attn_drop) - self.proj = nn.Dense(dim, dim) - self.proj_drop = Dropout(p=proj_drop) - self.qk_batmatmul = ops.BatchMatMul(transpose_b=True) - self.batmatmul = ops.BatchMatMul() - self.softmax = nn.Softmax(axis=-1) + self.q = mint.nn.Linear(dim, dim, bias=qkv_bias) + self.kv = mint.nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = mint.nn.Dropout(p=attn_drop) + self.proj = mint.nn.Linear(dim, dim) + self.proj_drop = mint.nn.Dropout(p=proj_drop) + self.qk_batmatmul = ExtendBatchMatMul(transpose_b=True) + self.batmatmul = ExtendBatchMatMul() + self.softmax = mint.nn.Softmax(dim=-1) self.linear = linear self.sr_ratio = sr_ratio if not linear: if sr_ratio > 1: - self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio, has_bias=True) - self.norm = nn.LayerNorm([dim]) + self.sr = mint.nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = mint.nn.LayerNorm([dim]) else: - self.pool = nn.AdaptiveAvgPool2d(7) - self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1, has_bias=True) - self.norm = nn.LayerNorm([dim]) - self.act = nn.GELU() + self.pool = mint.nn.AdaptiveAvgPool2d(7) + self.sr = mint.nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.norm = mint.nn.LayerNorm([dim]) + self.act = mint.nn.GELU() def construct(self, x, H, W): B, N, C = x.shape q = self.q(x) - q = ops.reshape(q, (B, N, self.num_heads, C // self.num_heads)) - q = ops.transpose(q, (0, 2, 1, 3)) + q = mint.reshape(q, (B, N, self.num_heads, C // self.num_heads)) + q = mint.permute(q, (0, 2, 1, 3)) if not self.linear: if self.sr_ratio > 1: - x_ = ops.reshape(ops.transpose(x, (0, 2, 1)), (B, C, H, W)) + x_ = mint.reshape(mint.permute(x, (0, 2, 1)), (B, C, H, W)) x_ = self.sr(x_) - x_ = ops.transpose(ops.reshape(x_, (B, C, -1)), (0, 2, 1)) + x_ = mint.permute(mint.reshape(x_, (B, C, -1)), (0, 2, 1)) x_ = self.norm(x_) kv = self.kv(x_) - kv = ops.transpose(ops.reshape(kv, (B, -1, 2, self.num_heads, C // self.num_heads)), (2, 0, 3, 1, 4)) + kv = mint.permute(mint.reshape(kv, (B, -1, 2, self.num_heads, C // self.num_heads)), (2, 0, 3, 1, 4)) else: kv = self.kv(x) - kv = ops.transpose(ops.reshape(kv, (B, -1, 2, self.num_heads, C // self.num_heads)), (2, 0, 3, 1, 4)) + kv = mint.permute(mint.reshape(kv, (B, -1, 2, self.num_heads, C // self.num_heads)), (2, 0, 3, 1, 4)) else: - x_ = ops.reshape(ops.transpose(x, (0, 2, 1)), (B, C, H, W)) + x_ = mint.reshape(mint.permute(x, (0, 2, 1)), (B, 
C, H, W)) x_ = self.sr(self.pool(x_)) - x_ = ops.reshape(ops.transpose(x_, (0, 2, 1)), (B, C, -1)) + x_ = mint.reshape(mint.permute(x_, (0, 2, 1)), (B, C, -1)) x_ = self.norm(x_) x_ = self.act(x_) - kv = ops.transpose(ops.reshape(self.kv(x_), (B, -1, 2, self.num_heads, C // self.num_heads)), - (2, 0, 3, 1, 4)) + kv = mint.permute(mint.reshape(self.kv(x_), (B, -1, 2, self.num_heads, C // self.num_heads)), + (2, 0, 3, 1, 4)) k, v = kv[0], kv[1] attn = self.qk_batmatmul(q, k) * self.scale @@ -162,7 +168,7 @@ def construct(self, x, H, W): attn = self.attn_drop(attn) x = self.batmatmul(attn, v) - x = ops.reshape(ops.transpose(x, (0, 2, 1, 3)), (B, N, C)) + x = mint.reshape(mint.permute(x, (0, 2, 1, 3)), (B, N, C)) x = self.proj(x) x = self.proj_drop(x) @@ -172,8 +178,20 @@ def construct(self, x, H, W): class Block(nn.Cell): """Block with Linear Spatial Reduction Attention and Convolutional Feed-Forward""" - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False, block_id=0): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=mint.nn.GELU, + norm_layer=mint.nn.LayerNorm, + sr_ratio=1, + linear=False, + block_id=0): super().__init__() self.norm1 = norm_layer([dim]) @@ -212,13 +230,16 @@ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=7 self.patch_size = patch_size self.H, self.W = img_size[0] // stride, img_size[1] // stride self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, has_bias=True) - self.norm = nn.LayerNorm([embed_dim]) + self.proj = mint.nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2) + ) + self.norm = mint.nn.LayerNorm([embed_dim]) def construct(self, x): x = self.proj(x) B, C, H, W = x.shape - x = ops.transpose(ops.reshape(x, (B, C, H * W)), (0, 2, 1)) + x = mint.permute(mint.reshape(x, (B, C, H * W)), (0, 2, 1)) x = self.norm(x) return x, H, W @@ -241,7 +262,7 @@ class PyramidVisionTransformerV2(nn.Cell): drop_rate(float) : The drop rate for each block. Default: 0.0. attn_drop_rate(float) : The drop rate for attention. Default: 0.0. drop_path_rate(float) : The drop rate for drop path. Default: 0.0. - norm_layer(nn.Cell) : Norm layer that will be used in blocks. Default: nn.LayerNorm. + norm_layer(nn.Cell) : Norm layer that will be used in blocks. Default: mint.nn.LayerNorm. depths (list) : number of Blocks. sr_ratios(list) : stride and kernel size of each attention. num_stages(int) : number of stage. Default: 4. 
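The attention rewrite above swaps ops.BatchMatMul(transpose_b=True) for an ExtendBatchMatMul helper whose definition is not shown in this hunk. A minimal sketch of what such a wrapper could look like on top of mint follows; the class body here is an assumption for illustration, and the actual helper shipped with the patch may differ.

import mindspore.nn as nn
from mindspore import mint


class ExtendBatchMatMul(nn.Cell):
    # Assumed drop-in for ops.BatchMatMul: batched matmul with optional transposes.
    def __init__(self, transpose_a: bool = False, transpose_b: bool = False):
        super().__init__()
        self.transpose_a = transpose_a
        self.transpose_b = transpose_b

    def construct(self, a, b):
        # swap the two trailing axes when the corresponding flag is set
        if self.transpose_a:
            a = mint.transpose(a, -2, -1)
        if self.transpose_b:
            b = mint.transpose(b, -2, -1)
        # mint.matmul broadcasts over the leading batch/head dimensions
        return mint.matmul(a, b)

With transpose_b=True this reproduces the q times k-transposed product used for the attention logits.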
@@ -250,7 +271,7 @@ class PyramidVisionTransformerV2(nn.Cell): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., - attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + attn_drop_rate=0., drop_path_rate=0., norm_layer=mint.nn.LayerNorm, depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, linear=False): super().__init__() self.num_classes = num_classes @@ -289,7 +310,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em self.block_list = nn.CellList(block_list) self.norm_list = nn.CellList(norm_list) # classification head - self.head = nn.Dense(embed_dims[3], num_classes) if num_classes > 0 else Identity() + self.head = mint.nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else Identity() self._initialize_weights() def freeze_patch_emb(self): @@ -297,17 +318,17 @@ def freeze_patch_emb(self): def _initialize_weights(self): for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Dense): + if isinstance(cell, mint.nn.Linear): cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)) - if isinstance(cell, nn.Dense) and cell.bias is not None: + if isinstance(cell, mint.nn.Linear) and cell.bias is not None: cell.bias.set_data(weight_init.initializer(weight_init.Zero(), cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.LayerNorm): - cell.gamma.set_data(weight_init.initializer(weight_init.One(), cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(weight_init.initializer(weight_init.Zero(), cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Conv2d): + elif isinstance(cell, mint.nn.LayerNorm): + cell.weight.set_data(weight_init.initializer(weight_init.One(), cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(weight_init.initializer(weight_init.Zero(), cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Conv2d): fan_out = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels - fan_out //= cell.group + fan_out //= cell.groups cell.weight.set_data(weight_init.initializer(weight_init.Normal(sigma=math.sqrt(2.0 / fan_out)), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: @@ -318,7 +339,7 @@ def get_classifier(self): def reset_classifier(self, num_classes, global_pool=""): self.num_classes = num_classes - self.head = nn.Dense(self.embed_dim, num_classes) if num_classes > 0 else Identity() + self.head = mint.nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else Identity() def forward_features(self, x): B = x.shape[0] @@ -332,7 +353,7 @@ def forward_features(self, x): x = blk(x, H, W) x = norm(x) if i != self.num_stages - 1: - x = ops.transpose(ops.reshape(x, (B, H, W, -1)), (0, 3, 1, 2)) + x = mint.permute(mint.reshape(x, (B, H, W, -1)), (0, 3, 1, 2)) return x.mean(axis=1) @@ -348,7 +369,7 @@ def construct(self, x): @register_model def pvt_v2_b0( - pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs + pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs ) -> PyramidVisionTransformerV2: """Get PVTV2-b0 model Refer to the base class "models.PVTv2" for more details. 
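The rewritten _initialize_weights above relies on two renames that come with the mint layers: LayerNorm parameters are called weight and bias instead of gamma and beta, and Conv2d exposes groups instead of group. A short check along these lines (a sketch, assuming a MindSpore build that ships the mint API) makes the mapping explicit:

from mindspore import mint, nn

ln_old = nn.LayerNorm([64])
ln_new = mint.nn.LayerNorm([64])

# the legacy cell carries 'gamma'/'beta', the mint cell carries 'weight'/'bias'
print([name for name, _ in ln_old.parameters_and_names()])
print([name for name, _ in ln_new.parameters_and_names()])

# mint.nn.Conv2d follows the torch-style keyword and attribute names
conv = mint.nn.Conv2d(8, 16, kernel_size=3, padding=1, groups=8, bias=False)
print(conv.groups)  # 8; the fan_out computation above divides by this attribute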
@@ -357,7 +378,7 @@ def pvt_v2_b0( model = PyramidVisionTransformerV2( in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], **kwargs) + norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) @@ -367,7 +388,7 @@ def pvt_v2_b0( @register_model def pvt_v2_b1( - pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs + pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs ) -> PyramidVisionTransformerV2: """Get PVTV2-b1 model Refer to the base class "models.PVTv2" for more details. @@ -376,7 +397,7 @@ def pvt_v2_b1( model = PyramidVisionTransformerV2( in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], **kwargs) + norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) @@ -386,7 +407,7 @@ def pvt_v2_b1( @register_model def pvt_v2_b2( - pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs + pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs ) -> PyramidVisionTransformerV2: """Get PVTV2-b2 model Refer to the base class "models.PVTv2" for more details. @@ -395,7 +416,7 @@ def pvt_v2_b2( model = PyramidVisionTransformerV2( in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) @@ -405,7 +426,7 @@ def pvt_v2_b2( @register_model def pvt_v2_b3( - pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs + pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs ) -> PyramidVisionTransformerV2: """Get PVTV2-b3 model Refer to the base class "models.PVTv2" for more details. @@ -414,7 +435,7 @@ def pvt_v2_b3( model = PyramidVisionTransformerV2( in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) @@ -423,7 +444,7 @@ def pvt_v2_b3( @register_model def pvt_v2_b4( - pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs + pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs ) -> PyramidVisionTransformerV2: """Get PVTV2-b4 model Refer to the base class "models.PVTv2" for more details. 
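Only keyword names change in these constructors (epsilon becomes eps in the norm_layer partial), so a quick shape check of one registered variant is enough to confirm that the rewired blocks still compose. A sketch, assuming mindcv's create_model factory resolves the name registered above:

import numpy as np
import mindspore as ms
from mindcv import create_model

net = create_model("pvt_v2_b0", pretrained=False, num_classes=1000)
net.set_train(False)

x = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
y = net(x)
print(y.shape)  # expected (1, 1000) for a 224x224 input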
@@ -432,7 +453,7 @@ def pvt_v2_b4( model = PyramidVisionTransformerV2( in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) @@ -441,7 +462,7 @@ def pvt_v2_b4( @register_model def pvt_v2_b5( - pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs + pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs ) -> PyramidVisionTransformerV2: """Get PVTV2-b5 model Refer to the base class "models.PVTv2" for more details. @@ -450,7 +471,7 @@ def pvt_v2_b5( model = PyramidVisionTransformerV2( in_chans=in_channels, num_classes=num_classes, patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, - norm_layer=partial(nn.LayerNorm, epsilon=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], **kwargs) if pretrained: load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) diff --git a/mindcv/models/regnet.py b/mindcv/models/regnet.py index 5f42c5812..643c9d9b2 100644 --- a/mindcv/models/regnet.py +++ b/mindcv/models/regnet.py @@ -8,7 +8,7 @@ import numpy as np import mindspore.common.initializer as init -from mindspore import nn +from mindspore import mint, nn from .helpers import load_pretrained from .layers.pooling import GlobalAvgPooling @@ -85,21 +85,18 @@ def conv2d(w_in, w_out, k, *, stride=1, groups=1, bias=False): """Helper for building a conv2d layer.""" assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." s, p, g, b = stride, (k - 1) // 2, groups, bias - return nn.Conv2d(w_in, w_out, k, stride=s, pad_mode="pad", padding=p, group=g, has_bias=b) + return mint.nn.Conv2d(w_in, w_out, k, stride=s, padding=p, groups=g, bias=b) def norm2d(w_in, eps=1e-5, mom=0.9): """Helper for building a norm2d layer.""" - return nn.BatchNorm2d(num_features=w_in, eps=eps, momentum=mom) + return mint.nn.BatchNorm2d(num_features=w_in, eps=eps, momentum=1 - mom)  # mint uses the torch-style momentum, the complement of nn.BatchNorm2d's mom def pool2d(_w_in, k, *, stride=1): """Helper for building a pool2d layer.""" assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." 
- padding = (k - 1) // 2 - pad2d = nn.Pad(((0, 0), (0, 0), (padding, padding), (padding, padding)), mode="CONSTANT") - max_pool = nn.MaxPool2d(kernel_size=k, stride=stride, pad_mode="valid") - return nn.SequentialCell([pad2d, max_pool]) + return mint.nn.MaxPool2d(k, stride=stride, padding=(k - 1) // 2) def gap2d(keep_dims=False): @@ -109,12 +106,12 @@ def linear(w_in, w_out, *, bias=False): """Helper for building a linear layer.""" - return nn.Dense(w_in, w_out, has_bias=bias) + return mint.nn.Linear(w_in, w_out, bias=bias) def activation(): """Helper for building an activation layer.""" - return nn.ReLU() + return mint.nn.ReLU() class ResStemCifar(nn.Cell): @@ -394,15 +391,15 @@ def __init__(self, depths, stem_type, stem_w, block_type, widths, strides, bot_m def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): fan_out = cell.kernel_size[0] * cell.kernel_size[1] * cell.out_channels cell.weight.set_data( init.initializer(init.Normal(sigma=math.sqrt(2.0 / fan_out), mean=0.0), cell.weight.shape, cell.weight.dtype)) - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.Normal(sigma=0.01, mean=0.0), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: diff --git a/mindcv/models/repmlp.py b/mindcv/models/repmlp.py index 0fa9bd6e1..f071db59f 100644 --- a/mindcv/models/repmlp.py +++ b/mindcv/models/repmlp.py @@ -8,7 +8,7 @@ import numpy as np import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn, ops from .helpers import load_pretrained from .registry import register_model @@ -48,7 +48,7 @@ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, group=1, ha d = OrderedDict() conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, pad_mode="pad", padding=padding, group=group, has_bias=has_bias) - bn1 = nn.BatchNorm2d(num_features=out_channels) + bn1 = mint.nn.BatchNorm2d(num_features=out_channels) d["conv"] = conv1 d["bn"] = bn1 result = nn.SequentialCell(d) @@ -59,7 +59,7 @@ def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, group= d = OrderedDict() conv2 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, group=group, has_bias=False) - relu = nn.ReLU() + relu = mint.nn.ReLU() d["conv"] = conv2 d["relu"] = relu result = nn.SequentialCell(d) @@ -69,15 +69,15 @@ def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, group= def fuse_bn(conv_or_fc, bn): std = (bn.running_var + bn.eps).sqrt() t = bn.weight / std - t = t.reshape(-1, 1, 1, 1) + t = mint.reshape(t, (-1, 1, 1, 1)) if len(t) == conv_or_fc.weight.size(0): return conv_or_fc.weight * t, bn.bias - bn.running_mean * bn.weight / std else: repeat_times = conv_or_fc.weight.size(0) // len(t) - repeated = t.repeat_interleave(repeat_times, 0) - return 
conv_or_fc.weight * repeated, (bn.bias - bn.running_mean * bn.weight / std).repeat_interleave( - repeat_times, 0) + repeated = mint.repeat_interleave(t, repeat_times, 0) + return conv_or_fc.weight * repeated, mint.repeat_interleave((bn.bias - bn.running_mean * bn.weight / std), + repeat_times, 0) class GlobalPerceptron(nn.Cell): @@ -90,14 +90,13 @@ def __init__(self, input_channels, internal_neurons): self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=(1, 1), stride=1, has_bias=True) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() + # todo self.sigmoid = nn.Sigmoid() self.input_channels = input_channels - self.shape = ops.Shape() def construct(self, x): - shape = self.shape(x) - pool = nn.AvgPool2d(kernel_size=(shape[2], shape[3]), stride=1) + pool = mint.nn.AvgPool2d(kernel_size=(x.shape[2], x.shape[3]), stride=1) x = pool(x) x = self.fc1(x) x = self.relu(x) @@ -125,9 +124,6 @@ def __init__(self, in_channels, out_channels, self.h, self.w = h, w self.deploy = deploy - self.transpose = ops.Transpose() - self.shape = ops.Shape() - self.reshape = ops.Reshape() assert in_channels == out_channels self.gp = GlobalPerceptron(input_channels=in_channels, internal_neurons=in_channels // globalperceptron_reduce) @@ -152,7 +148,7 @@ def __init__(self, in_channels, out_channels, def partition(self, x, h_parts, w_parts): x = x.reshape(-1, self.C, h_parts, self.h, w_parts, self.w) input_perm = (0, 2, 4, 1, 3, 5) - x = self.transpose(x, input_perm) + x = mint.permute(x, input_perm) return x def partition_affine(self, x, h_parts, w_parts): @@ -167,7 +163,7 @@ def construct(self, inputs): # Global Perceptron global_vec = self.gp(inputs) - origin_shape = self.shape(inputs) + origin_shape = inputs.shape h_parts = origin_shape[2] // self.h w_parts = origin_shape[3] // self.w @@ -179,15 +175,15 @@ def construct(self, inputs): # Local Perceptron if self.reparam_conv_k is not None and not self.deploy: - conv_inputs = self.reshape(partitions, (-1, self.S, self.h, self.w)) + conv_inputs = mint.reshape(partitions, (-1, self.S, self.h, self.w)) conv_out = 0 for k in self.conv_branch_k: conv_out += k(conv_inputs) - conv_out = self.reshape(conv_out, (-1, h_parts, w_parts, self.S, self.h, self.w)) + conv_out = mint.reshape(conv_out, (-1, h_parts, w_parts, self.S, self.h, self.w)) fc3_out += conv_out input_perm = (0, 3, 1, 4, 2, 5) - fc3_out = self.transpose(fc3_out, input_perm) # N, O, h_parts, out_h, w_parts, out_w + fc3_out = mint.permute(fc3_out, input_perm) # N, O, h_parts, out_h, w_parts, out_w out = fc3_out.reshape(*origin_shape) out = out * global_vec return out @@ -223,15 +219,17 @@ def local_inject(self): self.__delattr__("fc3") self.__delattr__("fc3_bn") self.fc3 = nn.Conv2d(self.S * self.h * self.w, self.S * self.h * self.w, 1, 1, 0, has_bias=True, group=self.S) + # todo self.fc3_bn = ops.Identity() self.fc3.weight.data = fc3_weight self.fc3.bias.data = fc3_bias def _convert_conv_to_fc(self, conv_kernel, conv_bias): - I = ops.eye(self.h * self.w).repeat(1, self.S).reshape(self.h * self.w, self.S, self.h, self.w) # noqa: E741 - fc_k = ops.Conv2D(I, conv_kernel, pad=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2), group=self.S) - fc_k = fc_k.reshape(self.h * self.w, self.S * self.h * self.w).t() - fc_bias = conv_bias.repeat_interleave(self.h * self.w) + i = mint.reshape(mint.eye(self.h * self.w) + .repeat(1, self.S), (self.h * self.w, self.S, self.h, self.w)) # noqa: E741 + fc_k = ops.Conv2D(i, conv_kernel, pad=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 
2), group=self.S) + fc_k = mint.transpose(mint.reshape(fc_k, (self.h * self.w, self.S * self.h * self.w)), 1, 0) + fc_bias = mint.repeat_interleave(conv_bias, self.h * self.w) return fc_k, fc_bias @@ -263,8 +261,8 @@ def __init__(self, channels, h, w, reparam_conv_k, globalperceptron_reduce, ffn_ reparam_conv_k=reparam_conv_k, globalperceptron_reduce=globalperceptron_reduce, num_sharesets=num_sharesets, deploy=deploy) self.ffn_block = FFNBlock(channels, channels * ffn_expand) - self.prebn1 = nn.BatchNorm2d(channels).set_train() - self.prebn2 = nn.BatchNorm2d(channels).set_train() + self.prebn1 = mint.nn.BatchNorm2d(channels).set_train() + self.prebn2 = mint.nn.BatchNorm2d(channels).set_train() def construct(self, x): y = x + self.repmlp_block(self.prebn1(x)) @@ -329,12 +327,10 @@ def __init__(self, stride=2, padding=0)) self.stages = nn.CellList(stages) self.embeds = nn.CellList(embeds) - self.head_norm = nn.BatchNorm2d(channels[-1]).set_train() - self.head = nn.Dense(channels[-1], num_class) + self.head_norm = mint.nn.BatchNorm2d(channels[-1]).set_train() + self.head = mint.nn.Linear(channels[-1], num_class) self.use_checkpoint = use_checkpoint - self.shape = ops.Shape() - self.reshape = ops.Reshape() self._initialize_weights() def _initialize_weights(self): @@ -346,7 +342,7 @@ def _initialize_weights(self): cell.weight.set_data(init.initializer(init.Uniform(k), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data(init.initializer(init.Uniform(k), cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.Linear): k = 1 / cell.in_channels k = k ** 0.5 cell.weight.set_data(init.initializer(init.Uniform(k), cell.weight.shape, cell.weight.dtype)) @@ -364,8 +360,8 @@ def forward_features(self, x: Tensor) -> Tensor: embed = self.embeds[i] x = embed(x) x = self.head_norm(x) - shape = self.shape(x) - pool = nn.AvgPool2d(kernel_size=(shape[2], shape[3])) + shape = x.shape + pool = mint.nn.AvgPool2d(kernel_size=(shape[2], shape[3])) x = pool(x) return x.view(shape[0], -1) diff --git a/mindcv/models/repvgg.py b/mindcv/models/repvgg.py index 6aa89a3a3..e8cf3f7b9 100644 --- a/mindcv/models/repvgg.py +++ b/mindcv/models/repvgg.py @@ -8,7 +8,8 @@ import numpy as np import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops, save_checkpoint +import mindspore.mint.nn.functional as F +from mindspore import Tensor, mint, nn, save_checkpoint from .helpers import build_model_with_cfg from .layers import GlobalAvgPooling, Identity, SqueezeExcite @@ -54,12 +55,13 @@ def _cfg(url="", **kwargs): def conv_bn(in_channels: int, out_channels: int, kernel_size: int, - stride: int, padding: int, group: int = 1) -> nn.SequentialCell: + stride: int, padding: int, groups: int = 1) -> nn.SequentialCell: cell = nn.SequentialCell([ - nn.Conv2d(in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, group=group, pad_mode="pad", - has_bias=False), - nn.BatchNorm2d(num_features=out_channels) + mint.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=groups, bias=False + ), + mint.nn.BatchNorm2d(num_features=out_channels) ]) return cell @@ -69,11 +71,11 @@ class RepVGGBlock(nn.Cell): def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, - group: int = 1, padding_mode: str = "zeros", + groups: int = 1, padding_mode: str = 
"zeros", deploy: bool = False, use_se: bool = False) -> None: super().__init__() self.deploy = deploy - self.group = group + self.groups = groups self.in_channels = in_channels assert kernel_size == 3 @@ -81,7 +83,7 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding_11 = padding - kernel_size // 2 - self.nonlinearity = nn.ReLU() + self.nonlinearity = mint.nn.ReLU() if use_se: self.se = SqueezeExcite( @@ -90,18 +92,19 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, self.se = Identity() if deploy: - self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, - stride=stride, padding=padding, dilation=dilation, group=group, has_bias=True, - pad_mode=padding_mode) + self.rbr_reparam = mint.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode + ) else: self.rbr_reparam = None - self.rbr_identity = nn.BatchNorm2d( + self.rbr_identity = mint.nn.BatchNorm2d( num_features=in_channels) if out_channels == in_channels and stride == 1 else None self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, - stride=stride, padding=padding, group=group) + stride=stride, padding=padding, groups=groups) self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, - padding=padding_11, group=group) + padding=padding_11, groups=groups) def construct(self, inputs: Tensor) -> Tensor: if self.rbr_reparam is not None: @@ -120,16 +123,16 @@ def get_custom_l2(self): k1 = self.rbr_1x1.conv.weight t3 = self.rbr_dense.bn.weight / ( - ops.sqrt((self.rbr_dense.bn.moving_variance + self.rbr_dense.bn.eps))) - t3 = ops.reshape(t3, (-1, 1, 1, 1)) + mint.sqrt((self.rbr_dense.bn.moving_variance + self.rbr_dense.bn.eps))) + t3 = mint.reshape(t3, (-1, 1, 1, 1)) t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.moving_variance + self.rbr_1x1.bn.eps).sqrt())) - t1 = ops.reshape(t1, (-1, 1, 1, 1)) + t1 = mint.reshape(t1, (-1, 1, 1, 1)) - l2_loss_circle = ops.reduce_sum(k3 ** 2) - ops.reduce_sum(k3[:, :, 1:2, 1:2] ** 2) + l2_loss_circle = mint.sum(k3 ** 2) - mint.sum(k3[:, :, 1:2, 1:2] ** 2) eq_kernel = k3[:, :, 1:2, 1:2] * t3 + k1 * t1 - l2_loss_eq_kernel = ops.reduce_sum(eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)) + l2_loss_eq_kernel = mint.sum(eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)) return l2_loss_eq_kernel + l2_loss_circle # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way. 
@@ -145,7 +148,7 @@ def get_equivalent_kernel_bias(self): def _pad_1x1_to_3x3_tensor(self, kernel1x1): if kernel1x1 is None: return 0 - return ops.pad(kernel1x1, (1, 1, 1, 1)) + return F.pad(kernel1x1, (1, 1, 1, 1)) def _fuse_bn_tensor(self, branch): if branch is None: @@ -158,9 +161,9 @@ def _fuse_bn_tensor(self, branch): beta = branch.bn.beta eps = branch.bn.eps else: - assert isinstance(branch, (nn.BatchNorm2d, nn.SyncBatchNorm)) + assert isinstance(branch, (mint.nn.BatchNorm2d, mint.nn.SyncBatchNorm)) if not hasattr(self, "id_tensor"): - input_dim = self.in_channels // self.group + input_dim = self.in_channels // self.groups kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) for i in range(self.in_channels): kernel_value[i, i % input_dim, 1, 1] = 1 @@ -171,8 +174,8 @@ def _fuse_bn_tensor(self, branch): gamma = branch.gamma beta = branch.beta eps = branch.eps - std = ops.sqrt(moving_variance + eps) - t = ops.reshape(gamma / std, (-1, 1, 1, 1)) + std = mint.sqrt(moving_variance + eps) + t = mint.reshape(gamma / std, (-1, 1, 1, 1)) return kernel * t, beta - moving_mean * gamma / std def switch_to_deploy(self): @@ -180,11 +183,12 @@ def switch_to_deploy(self): if self.rbr_reparam is not None: return kernel, bias = self.get_equivalent_kernel_bias() - self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, - out_channels=self.rbr_dense.conv.out_channels, - kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride, - padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, - group=self.rbr_dense.conv.group, has_bias=True, pad_mode="pad") + self.rbr_reparam = mint.nn.Conv2d( + in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels, + kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride, + padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, + groups=self.rbr_dense.conv.groups, bias=True + ) self.rbr_reparam.weight.data = kernel self.rbr_reparam.bias.data = bias for para in self.parameters(): @@ -244,7 +248,7 @@ def __init__(self, num_blocks, num_classes=1000, in_channels=3, width_multiplier int(512 * width_multiplier[3]), num_blocks[3], stride=2) self.feature_info.append(dict(chs=int(512 * width_multiplier[3]), reduction=32, name="stage4")) self.gap = GlobalAvgPooling() - self.linear = nn.Dense(int(512 * width_multiplier[3]), num_classes) + self.linear = mint.nn.Linear(int(512 * width_multiplier[3]), num_classes) self._initialize_weights() def _make_stage(self, planes, num_blocks, stride): @@ -253,7 +257,7 @@ def _make_stage(self, planes, num_blocks, stride): for s in strides: cur_group = self.override_group_map.get(self.cur_layer_idx, 1) blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3, - stride=s, padding=1, group=cur_group, deploy=self.deploy, + stride=s, padding=1, groups=cur_group, deploy=self.deploy, use_se=self.use_se)) self.in_planes = planes self.cur_layer_idx += 1 @@ -263,17 +267,17 @@ def _make_stage(self, planes, num_blocks, stride): def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer(init.HeNormal(mode='fan_out', nonlinearity='relu'), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data( init.initializer('zeros', cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, 
nn.BatchNorm2d): - cell.gamma.set_data(init.initializer('ones', cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer('zeros', cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer('ones', cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer('zeros', cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.HeUniform(mode='fan_in', nonlinearity='sigmoid'), cell.weight.shape, cell.weight.dtype)) diff --git a/mindcv/models/res2net.py b/mindcv/models/res2net.py index 54a9990b9..3456fe171 100644 --- a/mindcv/models/res2net.py +++ b/mindcv/models/res2net.py @@ -7,10 +7,10 @@ from typing import List, Optional, Type import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import load_pretrained -from .layers.compatibility import Split +from .layers.pad import Pad from .layers.pooling import GlobalAvgPooling from .registry import register_model @@ -62,11 +62,11 @@ def __init__( ) -> None: super().__init__() if norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d width = int(math.floor(out_channels * (base_width / 64.0))) * groups - self.conv1 = nn.Conv2d(in_channels, width * scale, kernel_size=1) + self.conv1 = mint.nn.Conv2d(in_channels, width * scale, kernel_size=1, bias=False) self.bn1 = norm(width * scale) if scale == 1: @@ -75,26 +75,24 @@ def __init__( self.nums = scale - 1 if stype == "stage": self.pool = nn.SequentialCell([ - nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"), - nn.AvgPool2d(kernel_size=3, stride=stride), + Pad(pad=(1, 1, 1, 1), mode="constant"), + mint.nn.AvgPool2d(kernel_size=3, stride=stride), ]) self.convs = nn.CellList() self.bns = nn.CellList() for _ in range(self.nums): - self.convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, pad_mode="pad")) + self.convs.append(mint.nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) self.bns.append(norm(width)) - self.conv3 = nn.Conv2d(width * scale, out_channels * self.expansion, - kernel_size=1, stride=1) + self.conv3 = mint.nn.Conv2d(width * scale, out_channels * self.expansion, kernel_size=1, stride=1, bias=False) self.bn3 = norm(out_channels * self.expansion) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.down_sample = down_sample self.stype = stype self.scale = scale self.width = width - self.split = Split(split_size_or_sections=self.width, output_num=self.scale, axis=1) def construct(self, x: Tensor) -> Tensor: identity = x @@ -103,7 +101,7 @@ def construct(self, x: Tensor) -> Tensor: out = self.bn1(out) out = self.relu(out) - spx = self.split(out) + spx = mint.split(out, split_size_or_sections=self.width, dim=1) sp = self.convs[0](spx[0]) sp = self.bns[0](sp) @@ -120,12 +118,12 @@ def construct(self, x: Tensor) -> Tensor: sp = self.bns[i](sp) sp = self.relu(sp) - out = ops.concat((out, sp), axis=1) + out = mint.concat((out, sp), dim=1) if self.scale != 1 and self.stype == "normal": - out = ops.concat((out, spx[self.nums]), axis=1) + out = mint.concat((out, spx[self.nums]), dim=1) elif self.scale != 1 and self.stype == "stage": - out = ops.concat((out, self.pool(spx[self.nums])), axis=1) + out = mint.concat((out, self.pool(spx[self.nums])), dim=1) out = self.conv3(out) out = self.bn3(out) @@ -172,7 +170,7 @@ def __init__( self.version = version if 
norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d self.norm = norm self.num_classes = num_classes @@ -181,27 +179,51 @@ def __init__( self.base_width = base_width self.scale = scale if self.version == "res2net": - self.conv1 = nn.Conv2d(in_channels, self.input_channels, kernel_size=7, - stride=2, padding=3, pad_mode="pad") + self.conv1 = mint.nn.Conv2d( + in_channels, + self.input_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False + ) elif self.version == "res2net_v1b": self.conv1 = nn.SequentialCell([ - nn.Conv2d(in_channels, self.input_channels // 2, kernel_size=3, - stride=2, padding=1, pad_mode="pad"), + mint.nn.Conv2d( + in_channels, + self.input_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False + ), norm(self.input_channels // 2), - nn.ReLU(), - nn.Conv2d(self.input_channels // 2, self.input_channels // 2, kernel_size=3, - stride=1, padding=1, pad_mode="pad"), + mint.nn.ReLU(), + mint.nn.Conv2d( + self.input_channels // 2, + self.input_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False + ), norm(self.input_channels // 2), - nn.ReLU(), - nn.Conv2d(self.input_channels // 2, self.input_channels, kernel_size=3, - stride=1, padding=1, pad_mode="pad"), + mint.nn.ReLU(), + mint.nn.Conv2d( + self.input_channels // 2, + self.input_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False + ), ]) self.bn1 = norm(self.input_channels) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.max_pool = nn.SequentialCell([ - nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"), - nn.MaxPool2d(kernel_size=3, stride=2) + Pad(pad=(1, 1, 1, 1), mode="constant"), + mint.nn.MaxPool2d(kernel_size=3, stride=2) ]) self.layer1 = self._make_layer(block, 64, layer_nums[0]) self.layer2 = self._make_layer(block, 128, layer_nums[1], stride=2) @@ -210,13 +232,13 @@ def __init__( self.pool = GlobalAvgPooling() self.num_features = 512 * block.expansion - self.classifier = nn.Dense(self.num_features, num_classes) + self.classifier = mint.nn.Linear(self.num_features, num_classes) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer(init.HeNormal(math.sqrt(5), mode="fan_out", nonlinearity="relu"), cell.weight.shape, cell.weight.dtype)) @@ -224,10 +246,10 @@ def _initialize_weights(self) -> None: cell.bias.set_data( init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"), cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", nonlinearity="leaky_relu"), cell.weight.shape, cell.weight.dtype)) @@ -246,13 +268,29 @@ def _make_layer( if stride != 1 or self.input_channels != channels * block.expansion: if stride == 1 or self.version == "res2net": down_sample = nn.SequentialCell([ - nn.Conv2d(self.input_channels, channels * block.expansion, kernel_size=1, 
stride=stride), + mint.nn.Conv2d( + self.input_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), self.norm(channels * block.expansion) ]) else: down_sample = nn.SequentialCell([ - nn.AvgPool2d(kernel_size=stride, stride=stride, pad_mode="same"), - nn.Conv2d(self.input_channels, channels * block.expansion, kernel_size=1, stride=1), + mint.nn.AvgPool2d( + kernel_size=stride, + stride=stride, + padding=0 + ), + mint.nn.Conv2d( + self.input_channels, + channels * block.expansion, + kernel_size=1, + stride=1, + bias=False + ), self.norm(channels * block.expansion) ]) diff --git a/mindcv/models/resnest.py b/mindcv/models/resnest.py index 6a28e30e3..c0195ec5e 100644 --- a/mindcv/models/resnest.py +++ b/mindcv/models/resnest.py @@ -6,10 +6,9 @@ from typing import List, Optional, Type import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import build_model_with_cfg, make_divisible -from .layers.compatibility import Dropout from .layers.identity import Identity from .layers.pooling import GlobalAvgPooling from .registry import register_model @@ -54,18 +53,17 @@ def __init__( super(RadixSoftmax, self).__init__() self.radix = radix self.cardinality = cardinality - self.softmax = ops.Softmax(axis=1) - self.sigmoid = nn.Sigmoid() + self.softmax = mint.nn.Softmax(dim=1) def construct(self, x: Tensor) -> Tensor: batch = x.shape[0] if self.radix > 1: - x = ops.reshape(x, (batch, self.cardinality, self.radix, -1)) - x = ops.transpose(x, (0, 2, 1, 3)) + x = mint.reshape(x, (batch, self.cardinality, self.radix, -1)) + x = mint.permute(x, (0, 2, 1, 3)) x = self.softmax(x) - x = ops.reshape(x, (batch, -1)) + x = mint.reshape(x, (batch, -1)) else: - x = self.sigmoid() + x = mint.sigmoid(x) return x @@ -87,7 +85,7 @@ def __init__( rd_ratio: float = 0.25, rd_channels: Optional[int] = None, rd_divisor: int = 8, - act_layer: nn.Cell = nn.ReLU, + act_layer: nn.Cell = mint.nn.ReLU, norm_layer: Optional[nn.Cell] = None, ) -> None: super(SplitAttn, self).__init__() @@ -102,15 +100,16 @@ def __init__( padding = kernel_size // 2 if padding is None else padding - self.conv = nn.Conv2d(in_channels, mid_chs, kernel_size=kernel_size, stride=stride, - pad_mode="pad", padding=padding, dilation=dilation, - group=group * radix, has_bias=bias) + self.conv = mint.nn.Conv2d( + in_channels, mid_chs, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=group * radix, bias=bias + ) self.bn0 = norm_layer(mid_chs) if norm_layer else Identity() self.act0 = act_layer() - self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, group=group, has_bias=True) + self.fc1 = mint.nn.Conv2d(out_channels, attn_chs, 1, groups=group, bias=True) self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() self.act1 = act_layer() - self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, group=group, has_bias=True) + self.fc2 = mint.nn.Conv2d(attn_chs, mid_chs, 1, groups=group, bias=True) self.rsoftmax = RadixSoftmax(radix, group) self.pool = GlobalAvgPooling(keep_dims=True) @@ -121,8 +120,8 @@ def construct(self, x: Tensor) -> Tensor: B, RC, H, W = x.shape if self.radix > 1: - x = ops.reshape(x, (B, self.radix, RC // self.radix, H, W)) - x_gap = x.sum(axis=1) + x = mint.reshape(x, (B, self.radix, RC // self.radix, H, W)) + x_gap = mint.sum(x, dim=1) else: x_gap = x x_gap = self.pool(x_gap) @@ -132,10 +131,10 @@ def construct(self, x: Tensor) -> Tensor: x_attn = self.fc2(x_gap) x_attn = self.rsoftmax(x_attn) - x_attn 
= ops.reshape(x_attn, (B, -1, 1, 1)) + x_attn = mint.reshape(x_attn, (B, -1, 1, 1)) if self.radix > 1: - out = x * ops.reshape(x_attn, (B, self.radix, RC // self.radix, 1, 1)) - out = out.sum(axis=1) + out = x * mint.reshape(x_attn, (B, self.radix, RC // self.radix, 1, 1)) + out = mint.sum(out, dim=1) else: out = x * x_attn @@ -164,14 +163,14 @@ def __init__( ) -> None: super(Bottleneck, self).__init__() group_width = int(planes * (bottleneck_width / 64.0)) * cardinality - self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, has_bias=False) + self.conv1 = mint.nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) self.bn1 = norm_layer(group_width) self.radix = radix self.avd = avd and (stride > 1 or is_first) self.avd_first = avd_first if self.avd: - self.avd_layer = nn.AvgPool2d(3, stride, pad_mode="same") + self.avd_layer = mint.nn.AvgPool2d(3, stride, padding=1) stride = 1 if radix >= 1: @@ -179,15 +178,16 @@ def __init__( padding=dilation, dilation=dilation, group=cardinality, bias=False, radix=radix, norm_layer=norm_layer) else: - self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, - pad_mode="pad", padding=dilation, dilation=dilation, - group=cardinality, has_bias=False) + self.conv2 = mint.nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, groups=cardinality, bias=False + ) self.bn2 = norm_layer(group_width) - self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, has_bias=False) + self.conv3 = mint.nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) self.bn3 = norm_layer(planes * 4) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.downsample = downsample self.dilation = dilation self.stride = stride @@ -263,7 +263,7 @@ def __init__( avd: bool = False, avd_first: bool = False, drop_rate: float = 0.0, - norm_layer: nn.Cell = nn.BatchNorm2d, + norm_layer: nn.Cell = mint.nn.BatchNorm2d, ) -> None: super(ResNeSt, self).__init__() self.cardinality = group @@ -278,25 +278,21 @@ def __init__( if deep_stem: self.conv1 = nn.SequentialCell([ - nn.Conv2d(3, stem_width, kernel_size=3, stride=2, pad_mode="pad", - padding=1, has_bias=False), + mint.nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False), norm_layer(stem_width), - nn.ReLU(), - nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, pad_mode="pad", - padding=1, has_bias=False), + mint.nn.ReLU(), + mint.nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False), norm_layer(stem_width), - nn.ReLU(), - nn.Conv2d(stem_width, stem_width * 2, kernel_size=3, stride=1, pad_mode="pad", - padding=1, has_bias=False), + mint.nn.ReLU(), + mint.nn.Conv2d(stem_width, stem_width * 2, kernel_size=3, stride=1, padding=1, bias=False), ]) else: - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, pad_mode="pad", padding=3, - has_bias=False) + self.conv1 = mint.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.feature_info = [dict(chs=self.inplanes, reduction=2, name="relu")] - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.maxpool = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer, is_first=False) self.feature_info.append(dict(chs=block.expansion * 64, reduction=4, name='layer1')) @@ -320,15 +316,15 @@ def __init__( self.feature_info.append(dict(chs=block.expansion * 
512, reduction=32, name='layer4')) self.avgpool = GlobalAvgPooling() - self.drop = Dropout(p=drop_rate) if drop_rate > 0.0 else None - self.fc = nn.Dense(512 * block.expansion, num_classes) + self.drop = mint.nn.Dropout(p=drop_rate) if drop_rate > 0.0 else None + self.fc = mint.nn.Linear(512 * block.expansion, num_classes) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer( init.HeNormal(mode="fan_out", nonlinearity="relu"), cell.weight.shape, cell.weight.dtype @@ -336,10 +332,10 @@ def _initialize_weights(self) -> None: ) if cell.bias is not None: cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer( init.HeUniform(mode="fan_in", nonlinearity="sigmoid"), cell.weight.shape, cell.weight.dtype @@ -363,16 +359,17 @@ def _make_layer( down_layers = [] if self.avg_down: if dilation == 1: - down_layers.append(nn.AvgPool2d(kernel_size=stride, stride=stride, pad_mode="valid")) + down_layers.append(mint.nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0)) else: - down_layers.append(nn.AvgPool2d(kernel_size=1, stride=1, pad_mode="valid")) + down_layers.append(mint.nn.AvgPool2d(kernel_size=1, stride=1, padding=0)) - down_layers.append(nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, - stride=1, has_bias=False)) + down_layers.append( + mint.nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False) + ) else: down_layers.append( - nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, - has_bias=False)) + mint.nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False) + ) down_layers.append(norm_layer(planes * block.expansion)) downsample = nn.SequentialCell(down_layers) diff --git a/mindcv/models/resnet.py b/mindcv/models/resnet.py index 5649af875..b797aa26e 100644 --- a/mindcv/models/resnet.py +++ b/mindcv/models/resnet.py @@ -6,7 +6,7 @@ from typing import List, Optional, Type, Union import mindspore.common.initializer as init -from mindspore import Tensor, nn +from mindspore import Tensor, mint, nn from .helpers import build_model_with_cfg from .layers.pooling import GlobalAvgPooling @@ -71,16 +71,14 @@ def __init__( ) -> None: super().__init__() if norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d assert groups == 1, "BasicBlock only supports groups=1" assert base_width == 64, "BasicBlock only supports base_width=64" - self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, - stride=stride, padding=1, pad_mode="pad") + self.conv1 = mint.nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = norm(channels) - self.relu = nn.ReLU() - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, - stride=1, padding=1, pad_mode="pad") + self.relu = mint.nn.ReLU() + self.conv2 = 
mint.nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = norm(channels) self.down_sample = down_sample @@ -122,19 +120,17 @@ def __init__( ) -> None: super().__init__() if norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d width = int(channels * (base_width / 64.0)) * groups - self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1) + self.conv1 = mint.nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False) self.bn1 = norm(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, - padding=1, pad_mode="pad", group=groups) + self.conv2 = mint.nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) self.bn2 = norm(width) - self.conv3 = nn.Conv2d(width, channels * self.expansion, - kernel_size=1, stride=1) + self.conv3 = mint.nn.Conv2d(width, channels * self.expansion, kernel_size=1, stride=1, bias=False) self.bn3 = norm(channels * self.expansion) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.down_sample = down_sample def construct(self, x: Tensor) -> Tensor: @@ -186,19 +182,18 @@ def __init__( ) -> None: super().__init__() if norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d self.norm: nn.Cell = norm # add type hints to make pylint happy self.input_channels = 64 self.groups = groups self.base_with = base_width - self.conv1 = nn.Conv2d(in_channels, self.input_channels, kernel_size=7, - stride=2, pad_mode="pad", padding=3) + self.conv1 = mint.nn.Conv2d(in_channels, self.input_channels, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm(self.input_channels) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.feature_info = [dict(chs=self.input_channels, reduction=2, name="relu")] - self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.max_pool = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.feature_info.append(dict(chs=block.expansion * 64, reduction=4, name="layer1")) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) @@ -210,24 +205,24 @@ def __init__( self.pool = GlobalAvgPooling() self.num_features = 512 * block.expansion - self.classifier = nn.Dense(self.num_features, num_classes) + self.classifier = mint.nn.Linear(self.num_features, num_classes) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer(init.HeNormal(mode='fan_out', nonlinearity='relu'), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data( init.initializer('zeros', cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.set_data(init.initializer('ones', cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer('zeros', cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer('ones', cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer('zeros', cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.HeUniform(mode='fan_in', nonlinearity='sigmoid'), cell.weight.shape, cell.weight.dtype)) @@ -246,7 +241,13 @@ def _make_layer( if stride != 1 or self.input_channels != channels * block.expansion: 
down_sample = nn.SequentialCell([ - nn.Conv2d(self.input_channels, channels * block.expansion, kernel_size=1, stride=stride), + mint.nn.Conv2d( + self.input_channels, + channels * block.expansion, + kernel_size=1, + stride=stride, + bias=False + ), self.norm(channels * block.expansion) ]) diff --git a/mindcv/models/resnetv2.py b/mindcv/models/resnetv2.py index 144b1d580..a7566744b 100644 --- a/mindcv/models/resnetv2.py +++ b/mindcv/models/resnetv2.py @@ -5,7 +5,7 @@ from typing import Optional -from mindspore import Tensor, nn +from mindspore import Tensor, mint, nn from .helpers import load_pretrained from .registry import register_model @@ -47,22 +47,20 @@ def __init__(self, ) -> None: super().__init__() if norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d width = int(channels * (base_width / 64.0)) * groups self.bn1 = norm(in_channels) - self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1) + self.conv1 = mint.nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False) self.bn2 = norm(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, - padding=1, pad_mode='pad', group=groups) + self.conv2 = mint.nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) self.bn3 = norm(width) - self.conv3 = nn.Conv2d(width, channels * self.expansion, - kernel_size=1, stride=1) + self.conv3 = mint.nn.Conv2d(width, channels * self.expansion, kernel_size=1, stride=1, bias=False) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.down_sample = down_sample def construct(self, x: Tensor) -> Tensor: diff --git a/mindcv/models/rexnet.py b/mindcv/models/rexnet.py index ba2850649..e75a87cb3 100644 --- a/mindcv/models/rexnet.py +++ b/mindcv/models/rexnet.py @@ -7,11 +7,11 @@ from typing import Any import mindspore.common.initializer as init +import mindspore.mint as mint import mindspore.nn as nn from .helpers import build_model_with_cfg, make_divisible from .layers import Conv2dNormActivation, DropPath, GlobalAvgPooling, SqueezeExcite -from .layers.compatibility import Dropout from .registry import register_model __all__ = [ @@ -56,7 +56,8 @@ def __init__( use_se=True, se_ratio=1 / 12, ch_div=1, - act_layer=nn.SiLU, + act_layer=mint.nn.SiLU, + # todo dw_act_layer=nn.ReLU6, drop_path=None, **kwargs, @@ -79,7 +80,7 @@ def __init__( if use_se: self.se = SqueezeExcite(dw_channels, rd_channels=make_divisible(int(dw_channels * se_ratio), ch_div), - norm=nn.BatchNorm2d) + norm=mint.nn.BatchNorm2d) else: self.se = None self.act_dw = dw_act_layer() @@ -136,7 +137,8 @@ def __init__( drop_rate=0.2, drop_path_rate=0.0, ch_div=1, - act_layer=nn.SiLU, + act_layer=mint.nn.SiLU, + # todo dw_act_layer=nn.ReLU6, cls_useconv=False, ): @@ -213,18 +215,18 @@ def __init__( self.features = nn.SequentialCell(*features) if self.useconv: self.cls = nn.SequentialCell( - Dropout(p=drop_rate), - nn.Conv2d(pen_channels, num_classes, 1, has_bias=True)) + mint.nn.Dropout(p=drop_rate), + mint.nn.Conv2d(pen_channels, num_classes, 1, bias=True)) else: self.cls = nn.SequentialCell( - Dropout(p=drop_rate), - nn.Dense(pen_channels, num_classes)) + mint.nn.Dropout(p=drop_rate), + mint.nn.Linear(pen_channels, num_classes)) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, (nn.Conv2d, nn.Dense)): + if isinstance(cell, (mint.nn.Conv2d, mint.nn.Linear)): cell.weight.set_data( init.initializer(init.HeUniform(math.sqrt(5), mode="fan_in", 
nonlinearity="relu"), cell.weight.shape, cell.weight.dtype)) @@ -240,10 +242,10 @@ def forward_features(self, x): def forward_head(self, x): if not self.useconv: - x = x.reshape((x.shape[0], -1)) + x = mint.reshape(x, (x.shape[0], -1)) x = self.cls(x) else: - x = self.cls(x).reshape((x.shape[0], -1)) + x = mint.reshape(self.cls(x), (x.shape[0], -1)) return x def construct(self, x): diff --git a/mindcv/models/senet.py b/mindcv/models/senet.py index b35030ad4..0248bdb6b 100644 --- a/mindcv/models/senet.py +++ b/mindcv/models/senet.py @@ -7,10 +7,9 @@ from typing import List, Optional, Type, Union import mindspore.common.initializer as init -from mindspore import Tensor, nn +from mindspore import Tensor, mint, nn from .helpers import load_pretrained -from .layers.compatibility import Dropout from .layers.pooling import GlobalAvgPooling from .layers.squeeze_excite import SqueezeExciteV2 from .registry import register_model @@ -102,16 +101,15 @@ def __init__( downsample: Optional[nn.SequentialCell] = None, ) -> None: super(SEBottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_channels, channels * 2, kernel_size=1, pad_mode="pad", - padding=0, has_bias=False) - self.bn1 = nn.BatchNorm2d(channels * 2) - self.conv2 = nn.Conv2d(channels * 2, channels * 4, kernel_size=3, stride=stride, - pad_mode="pad", padding=1, group=group, has_bias=False) - self.bn2 = nn.BatchNorm2d(channels * 4) - self.conv3 = nn.Conv2d(channels * 4, channels * 4, kernel_size=1, pad_mode="pad", - padding=0, has_bias=False) - self.bn3 = nn.BatchNorm2d(channels * 4) - self.relu = nn.ReLU() + self.conv1 = mint.nn.Conv2d(in_channels, channels * 2, kernel_size=1, padding=0, bias=False) + self.bn1 = mint.nn.BatchNorm2d(channels * 2) + self.conv2 = mint.nn.Conv2d( + channels * 2, channels * 4, kernel_size=3, stride=stride, padding=1, groups=group, bias=False + ) + self.bn2 = mint.nn.BatchNorm2d(channels * 4) + self.conv3 = mint.nn.Conv2d(channels * 4, channels * 4, kernel_size=1, padding=0, bias=False) + self.bn3 = mint.nn.BatchNorm2d(channels * 4) + self.relu = mint.nn.ReLU() self.se_module = SqueezeExciteV2(channels * 4, rd_ratio=1.0 / reduction) self.downsample = downsample self.stride = stride @@ -135,16 +133,15 @@ def __init__( downsample: Optional[nn.SequentialCell] = None, ) -> None: super(SEResNetBottleneck, self).__init__() - self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=1, pad_mode="pad", - padding=0, has_bias=False) - self.bn1 = nn.BatchNorm2d(channels) - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride, pad_mode="pad", - padding=1, group=group, has_bias=False) - self.bn2 = nn.BatchNorm2d(channels) - self.conv3 = nn.Conv2d(channels, channels * 4, kernel_size=1, pad_mode="pad", padding=0, - has_bias=False) - self.bn3 = nn.BatchNorm2d(channels * 4) - self.relu = nn.ReLU() + self.conv1 = mint.nn.Conv2d(in_channels, channels, kernel_size=1, padding=0, bias=False) + self.bn1 = mint.nn.BatchNorm2d(channels) + self.conv2 = mint.nn.Conv2d( + channels, channels, kernel_size=3, stride=stride, padding=1, groups=group, bias=False + ) + self.bn2 = mint.nn.BatchNorm2d(channels) + self.conv3 = mint.nn.Conv2d(channels, channels * 4, kernel_size=1, padding=0, bias=False) + self.bn3 = mint.nn.BatchNorm2d(channels * 4) + self.relu = mint.nn.ReLU() self.se_module = SqueezeExciteV2(channels * 4, rd_ratio=1.0 / reduction) self.downsample = downsample self.stride = stride @@ -169,16 +166,13 @@ def __init__( ) -> None: super(SEResNeXtBottleneck, self).__init__() width = math.floor(channels * (base_width / 
64)) * group - self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, pad_mode="pad", - padding=0, has_bias=False) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, pad_mode="pad", - padding=1, group=group, has_bias=False) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, channels * 4, kernel_size=1, pad_mode="pad", padding=0, - has_bias=False) - self.bn3 = nn.BatchNorm2d(channels * 4) - self.relu = nn.ReLU() + self.conv1 = mint.nn.Conv2d(in_channels, width, kernel_size=1, stride=1, padding=0, bias=False) + self.bn1 = mint.nn.BatchNorm2d(width) + self.conv2 = mint.nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=group, bias=False) + self.bn2 = mint.nn.BatchNorm2d(width) + self.conv3 = mint.nn.Conv2d(width, channels * 4, kernel_size=1, padding=0, bias=False) + self.bn3 = mint.nn.BatchNorm2d(channels * 4) + self.relu = mint.nn.ReLU() self.se_module = SqueezeExciteV2(channels * 4, rd_ratio=1.0 / reduction) self.downsample = downsample self.stride = stride @@ -201,13 +195,11 @@ def __init__( downsample: Optional[nn.SequentialCell] = None, ) -> None: super(SEResNetBlock, self).__init__() - self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, pad_mode="pad", - padding=1, has_bias=False) - self.bn1 = nn.BatchNorm2d(channels) - self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, pad_mode="pad", padding=1, - group=group, has_bias=False) - self.bn2 = nn.BatchNorm2d(channels) - self.relu = nn.ReLU() + self.conv1 = mint.nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = mint.nn.BatchNorm2d(channels) + self.conv2 = mint.nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=group, bias=False) + self.bn2 = mint.nn.BatchNorm2d(channels) + self.relu = mint.nn.ReLU() self.se_module = SqueezeExciteV2(channels, rd_ratio=1.0 / reduction) self.downsample = downsample self.stride = stride @@ -269,24 +261,23 @@ def __init__( self.drop_rate = drop_rate if input3x3: self.layer0 = nn.SequentialCell([ - nn.Conv2d(in_channels, 64, 3, stride=2, pad_mode="pad", padding=1, has_bias=False), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.Conv2d(64, 64, 3, stride=1, pad_mode="pad", padding=1, has_bias=False), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.Conv2d(64, inplanes, 3, stride=1, pad_mode="pad", padding=1, has_bias=False), - nn.BatchNorm2d(inplanes), - nn.ReLU() + mint.nn.Conv2d(in_channels, 64, 3, stride=2, padding=1, bias=False), + mint.nn.BatchNorm2d(64), + mint.nn.ReLU(), + mint.nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False), + mint.nn.BatchNorm2d(64), + mint.nn.ReLU(), + mint.nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False), + mint.nn.BatchNorm2d(inplanes), + mint.nn.ReLU() ]) else: self.layer0 = nn.SequentialCell([ - nn.Conv2d(in_channels, inplanes, kernel_size=7, stride=2, pad_mode="pad", - padding=3, has_bias=False), - nn.BatchNorm2d(inplanes), - nn.ReLU() + mint.nn.Conv2d(in_channels, inplanes, kernel_size=7, stride=2, padding=3, bias=False), + mint.nn.BatchNorm2d(inplanes), + mint.nn.ReLU() ]) - self.pool0 = nn.MaxPool2d(3, stride=2, pad_mode="same") + self.pool0 = mint.nn.MaxPool2d(3, stride=2, padding=1) self.layer1 = self._make_layer(block, planes=64, blocks=layers[0], group=group, reduction=reduction, downsample_kernel_size=1, @@ -311,8 +302,8 @@ def __init__( self.pool = GlobalAvgPooling() if self.drop_rate > 0.: - self.dropout = Dropout(p=self.drop_rate) - self.classifier = nn.Dense(self.num_features,
self.num_classes) + self.dropout = mint.nn.Dropout(p=self.drop_rate) + self.classifier = mint.nn.Linear(self.num_features, self.num_classes) self._initialize_weights() @@ -330,9 +321,11 @@ def _make_layer( downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.SequentialCell([ - nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, - stride=stride, pad_mode="pad", padding=downsample_padding, has_bias=False), - nn.BatchNorm2d(planes * block.expansion) + mint.nn.Conv2d( + self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, + stride=stride, padding=downsample_padding, bias=False + ), + mint.nn.BatchNorm2d(planes * block.expansion) ]) layers = [block(self.inplanes, planes, group, reduction, stride, downsample)] @@ -345,17 +338,17 @@ def _make_layer( def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer(init.HeNormal(mode="fan_out", nonlinearity="relu"), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data( init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.BatchNorm2d): - cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.BatchNorm2d): + cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.HeUniform(mode="fan_in", nonlinearity="sigmoid"), cell.weight.shape, cell.weight.dtype)) diff --git a/mindcv/models/shufflenetv1.py b/mindcv/models/shufflenetv1.py index 1e316458d..b3c42538c 100644 --- a/mindcv/models/shufflenetv1.py +++ b/mindcv/models/shufflenetv1.py @@ -4,10 +4,9 @@ """ import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import load_pretrained -from .layers.compatibility import Split from .layers.pooling import GlobalAvgPooling from .registry import register_model @@ -61,37 +60,35 @@ class GroupConv(nn.Cell): out_channels (int): Output channels of feature map. kernel_size (int): Size of convolution kernel. stride (int): Stride size for the group convolution layer. - pad_mode (str): Specifies padding mode. pad (int): The number of padding on the height and width directions of the input. groups (int): Splits filter into groups, `in_channels` and `out_channels` must be divisible by `group`. has_bias (bool): Whether the Conv2d layer has a bias parameter. 
""" - def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False): + def __init__(self, in_channels, out_channels, kernel_size, stride, pad=0, groups=1, has_bias=False): super(GroupConv, self).__init__() assert in_channels % groups == 0 and out_channels % groups == 0 + self.in_channels = in_channels self.groups = groups self.convs = nn.CellList() - self.split = Split(split_size_or_sections=in_channels // groups, output_num=self.groups, axis=1) for _ in range(groups): self.convs.append( - nn.Conv2d( + mint.nn.Conv2d( in_channels // groups, out_channels // groups, kernel_size=kernel_size, stride=stride, - has_bias=has_bias, - padding=pad, - pad_mode=pad_mode, + bias=has_bias, + padding=pad ) ) def construct(self, x): - features = self.split(x) + features = mint.split(x, split_size_or_sections=self.in_channels // self.groups, dim=1) outputs = () for i in range(self.groups): outputs = outputs + (self.convs[i](features[i]),) - out = ops.concat(outputs, axis=1) + out = mint.concat(outputs, dim=1) return out @@ -122,37 +119,36 @@ def __init__( out_channels=mid_channels, kernel_size=1, stride=1, - pad_mode="pad", pad=0, groups=1 if first_group else group, ), - nn.BatchNorm2d(mid_channels), - nn.ReLU(), + mint.nn.BatchNorm2d(mid_channels), + mint.nn.ReLU(), ] branch_main_2 = [ # dw - nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride, pad_mode="pad", padding=1, - group=mid_channels), - nn.BatchNorm2d(mid_channels), + mint.nn.Conv2d( + mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, groups=mid_channels, bias=False + ), + mint.nn.BatchNorm2d(mid_channels), # pw-linear GroupConv( in_channels=mid_channels, out_channels=out_channels, kernel_size=1, stride=1, - pad_mode="pad", pad=0, groups=group, ), - nn.BatchNorm2d(out_channels), + mint.nn.BatchNorm2d(out_channels), ] self.branch_main_1 = nn.SequentialCell(branch_main_1) self.branch_main_2 = nn.SequentialCell(branch_main_2) if stride == 2: - self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode="same") + self.branch_proj = mint.nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() def construct(self, x: Tensor) -> Tensor: identify = x @@ -163,7 +159,7 @@ def construct(self, x: Tensor) -> Tensor: if self.stride == 1: out = self.relu(identify + x) else: - out = self.relu(ops.concat((self.branch_proj(identify), x), axis=1)) + out = self.relu(mint.concat((self.branch_proj(identify), x), dim=1)) return out @@ -171,9 +167,9 @@ def channel_shuffle(self, x: Tensor) -> Tensor: batch_size, num_channels, height, width = x.shape group_channels = num_channels // self.group - x = ops.reshape(x, (batch_size, group_channels, self.group, height, width)) - x = ops.transpose(x, (0, 2, 1, 3, 4)) - x = ops.reshape(x, (batch_size, num_channels, height, width)) + x = mint.reshape(x, (batch_size, group_channels, self.group, height, width)) + x = mint.permute(x, (0, 2, 1, 3, 4)) + x = mint.reshape(x, (batch_size, num_channels, height, width)) return x @@ -224,11 +220,11 @@ def __init__( # building first layer input_channel = self.stage_out_channels[1] self.first_conv = nn.SequentialCell( - nn.Conv2d(in_channels, input_channel, kernel_size=3, stride=2, pad_mode="pad", padding=1), - nn.BatchNorm2d(input_channel), - nn.ReLU(), + mint.nn.Conv2d(in_channels, input_channel, kernel_size=3, stride=2, padding=1, bias=False), + mint.nn.BatchNorm2d(input_channel), + mint.nn.ReLU(), ) - self.max_pool = nn.MaxPool2d(kernel_size=3, 
stride=2, pad_mode="same") + self.max_pool = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) features = [] for idxstage, numrepeat in enumerate(self.stage_repeats): @@ -243,13 +239,13 @@ def __init__( self.features = nn.SequentialCell(features) self.global_pool = GlobalAvgPooling() - self.classifier = nn.Dense(self.stage_out_channels[-1], num_classes, has_bias=False) + self.classifier = mint.nn.Linear(self.stage_out_channels[-1], num_classes, bias=False) self._initialize_weights() def _initialize_weights(self): """Initialize weights for cells.""" for name, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): if "first" in name: cell.weight.set_data( init.initializer(init.Normal(0.01, 0), cell.weight.shape, cell.weight.dtype)) @@ -260,7 +256,7 @@ def _initialize_weights(self): if cell.bias is not None: cell.bias.set_data( init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.Normal(0.01, 0), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: diff --git a/mindcv/models/shufflenetv2.py b/mindcv/models/shufflenetv2.py index bc49fc2ff..5b290edb6 100644 --- a/mindcv/models/shufflenetv2.py +++ b/mindcv/models/shufflenetv2.py @@ -6,7 +6,7 @@ from typing import Tuple import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import load_pretrained from .layers.pooling import GlobalAvgPooling @@ -65,30 +65,34 @@ def __init__( out_channels = out_channels - in_channels branch_main = [ # pw - nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1), - nn.BatchNorm2d(mid_channels), - nn.ReLU(), + mint.nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, bias=False), + mint.nn.BatchNorm2d(mid_channels), + mint.nn.ReLU(), # dw - nn.Conv2d(mid_channels, mid_channels, kernel_size=kernel_size, stride=stride, - pad_mode="pad", padding=pad, group=mid_channels), - nn.BatchNorm2d(mid_channels), + mint.nn.Conv2d( + mid_channels, mid_channels, kernel_size=kernel_size, stride=stride, + padding=pad, groups=mid_channels, bias=False + ), + mint.nn.BatchNorm2d(mid_channels), # pw-linear - nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1), - nn.BatchNorm2d(out_channels), - nn.ReLU(), + mint.nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1, bias=False), + mint.nn.BatchNorm2d(out_channels), + mint.nn.ReLU(), ] self.branch_main = nn.SequentialCell(branch_main) if stride == 2: branch_proj = [ # dw - nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, - pad_mode="pad", padding=pad, group=in_channels), - nn.BatchNorm2d(in_channels), + mint.nn.Conv2d( + in_channels, in_channels, kernel_size=kernel_size, stride=stride, + padding=pad, groups=in_channels, bias=False + ), + mint.nn.BatchNorm2d(in_channels), # pw-linear - nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1), - nn.BatchNorm2d(in_channels), - nn.ReLU(), + mint.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, bias=False), + mint.nn.BatchNorm2d(in_channels), + mint.nn.ReLU(), ] self.branch_proj = nn.SequentialCell(branch_proj) else: @@ -97,20 +101,20 @@ def __init__( def construct(self, old_x: Tensor) -> Tensor: if self.stride == 1: x_proj, x = self.channel_shuffle(old_x) - return ops.concat((x_proj, self.branch_main(x)), axis=1) + return mint.concat((x_proj, self.branch_main(x)), dim=1) if self.stride == 2: x_proj 
= old_x x = old_x - return ops.concat((self.branch_proj(x_proj), self.branch_main(x)), axis=1) + return mint.concat((self.branch_proj(x_proj), self.branch_main(x)), dim=1) return None @staticmethod def channel_shuffle(x: Tensor) -> Tuple[Tensor, Tensor]: batch_size, num_channels, height, width = x.shape - x = ops.reshape(x, (batch_size * num_channels // 2, 2, height * width,)) - x = ops.transpose(x, (1, 0, 2,)) - x = ops.reshape(x, (2, -1, num_channels // 2, height, width,)) + x = mint.reshape(x, (batch_size * num_channels // 2, 2, height * width,)) + x = mint.permute(x, (1, 0, 2,)) + x = mint.reshape(x, (2, -1, num_channels // 2, height, width,)) return x[0], x[1] @@ -148,12 +152,11 @@ def __init__( # building first layer input_channel = self.stage_out_channels[1] self.first_conv = nn.SequentialCell([ - nn.Conv2d(in_channels, input_channel, kernel_size=3, stride=2, - pad_mode="pad", padding=1), - nn.BatchNorm2d(input_channel), - nn.ReLU(), + mint.nn.Conv2d(in_channels, input_channel, kernel_size=3, stride=2, padding=1, bias=False), + mint.nn.BatchNorm2d(input_channel), + mint.nn.ReLU(), ]) - self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.max_pool = mint.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.features = [] for idxstage, numrepeat in enumerate(self.stage_repeats): @@ -170,18 +173,18 @@ def __init__( self.features = nn.SequentialCell(self.features) self.conv_last = nn.SequentialCell([ - nn.Conv2d(input_channel, self.stage_out_channels[-1], kernel_size=1, stride=1), - nn.BatchNorm2d(self.stage_out_channels[-1]), - nn.ReLU() + mint.nn.Conv2d(input_channel, self.stage_out_channels[-1], kernel_size=1, stride=1, bias=False), + mint.nn.BatchNorm2d(self.stage_out_channels[-1]), + mint.nn.ReLU() ]) self.pool = GlobalAvgPooling() - self.classifier = nn.Dense(self.stage_out_channels[-1], num_classes, has_bias=False) + self.classifier = mint.nn.Linear(self.stage_out_channels[-1], num_classes, bias=False) self._initialize_weights() def _initialize_weights(self): """Initialize weights for cells.""" for name, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): if "first" in name: cell.weight.set_data( init.initializer(init.Normal(0.01, 0), cell.weight.shape, cell.weight.dtype)) @@ -192,7 +195,7 @@ def _initialize_weights(self): if cell.bias is not None: cell.bias.set_data( init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.Normal(0.01, 0), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: diff --git a/mindcv/models/sknet.py b/mindcv/models/sknet.py index c1500c2cb..7689fd3c3 100644 --- a/mindcv/models/sknet.py +++ b/mindcv/models/sknet.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Type, Union -from mindspore import Tensor, nn +from mindspore import Tensor, mint, nn from .helpers import load_pretrained from .layers.selective_kernel import SelectiveKernel @@ -58,7 +58,7 @@ def __init__( ): super().__init__() if norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d if sk_kwargs is None: sk_kwargs = {} @@ -69,11 +69,11 @@ def __init__( self.conv1 = SelectiveKernel( in_channels, out_channels, stride=stride, **sk_kwargs) self.conv2 = nn.SequentialCell([ - nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=3, padding=1, pad_mode="pad"), + mint.nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=3, padding=1, 
bias=False), norm(out_channels * self.expansion) ]) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.down_sample = down_sample def construct(self, x: Tensor) -> Tensor: @@ -107,24 +107,24 @@ def __init__( ): super().__init__() if norm is None: - norm = nn.BatchNorm2d + norm = mint.nn.BatchNorm2d if sk_kwargs is None: sk_kwargs = {} width = int(out_channels * (base_width / 64.0)) * groups self.conv1 = nn.SequentialCell([ - nn.Conv2d(in_channels, width, kernel_size=1), + mint.nn.Conv2d(in_channels, width, kernel_size=1, bias=False), norm(width) ]) self.conv2 = SelectiveKernel( width, width, stride=stride, groups=groups, **sk_kwargs) self.conv3 = nn.SequentialCell([ - nn.Conv2d(width, out_channels * self.expansion, kernel_size=1), + mint.nn.Conv2d(width, out_channels * self.expansion, kernel_size=1, bias=False), norm(out_channels * self.expansion) ]) - self.relu = nn.ReLU() + self.relu = mint.nn.ReLU() self.down_sample = down_sample def construct(self, x: Tensor) -> Tensor: @@ -181,7 +181,9 @@ def _make_layer( if stride != 1 or self.input_channels != channels * block.expansion: down_sample = nn.SequentialCell([ - nn.Conv2d(self.input_channels, channels * block.expansion, kernel_size=1, stride=stride), + mint.nn.Conv2d( + self.input_channels, channels * block.expansion, kernel_size=1, stride=stride, bias=False + ), self.norm(channels * block.expansion) ]) diff --git a/mindcv/models/squeezenet.py b/mindcv/models/squeezenet.py index b70a21644..f79b3dea4 100644 --- a/mindcv/models/squeezenet.py +++ b/mindcv/models/squeezenet.py @@ -4,10 +4,9 @@ """ import mindspore.common.initializer as init -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from .helpers import load_pretrained -from .layers.compatibility import Dropout from .layers.pooling import GlobalAvgPooling from .registry import register_model @@ -46,16 +45,16 @@ def __init__( ) -> None: super().__init__() self.squeeze = nn.Conv2d(in_channels, squeeze_channels, kernel_size=1, has_bias=True) - self.squeeze_activation = nn.ReLU() + self.squeeze_activation = mint.nn.ReLU() self.expand1x1 = nn.Conv2d(squeeze_channels, expand1x1_channels, kernel_size=1, has_bias=True) - self.expand1x1_activation = nn.ReLU() + self.expand1x1_activation = mint.nn.ReLU() self.expand3x3 = nn.Conv2d(squeeze_channels, expand3x3_channels, kernel_size=3, pad_mode="same", has_bias=True) - self.expand3x3_activation = nn.ReLU() + self.expand3x3_activation = mint.nn.ReLU() def construct(self, x: Tensor) -> Tensor: x = self.squeeze_activation(self.squeeze(x)) - return ops.concat((self.expand1x1_activation(self.expand1x1(x)), - self.expand3x3_activation(self.expand3x3(x))), axis=1) + return mint.concat((self.expand1x1_activation(self.expand1x1(x)), + self.expand3x3_activation(self.expand3x3(x))), dim=1) class SqueezeNet(nn.Cell): @@ -84,30 +83,30 @@ def __init__( if version == "1_0": self.features = nn.SequentialCell([ nn.Conv2d(in_channels, 96, kernel_size=7, stride=2, pad_mode="valid", has_bias=True), - nn.ReLU(), - nn.MaxPool2d(kernel_size=3, stride=2), + mint.nn.ReLU(), + mint.nn.MaxPool2d(kernel_size=3, stride=2), Fire(96, 16, 64, 64), Fire(128, 16, 64, 64), Fire(128, 32, 128, 128), - nn.MaxPool2d(kernel_size=3, stride=2), + mint.nn.MaxPool2d(kernel_size=3, stride=2), Fire(256, 32, 128, 128), Fire(256, 48, 192, 192), Fire(384, 48, 192, 192), Fire(384, 64, 256, 256), - nn.MaxPool2d(kernel_size=3, stride=2), + mint.nn.MaxPool2d(kernel_size=3, stride=2), Fire(512, 64, 256, 256), ]) elif version == "1_1": self.features = 
nn.SequentialCell([ nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1, pad_mode="pad", has_bias=True), - nn.ReLU(), - nn.MaxPool2d(kernel_size=3, stride=2), + mint.nn.ReLU(), + mint.nn.MaxPool2d(kernel_size=3, stride=2), Fire(64, 16, 64, 64), Fire(128, 16, 64, 64), - nn.MaxPool2d(kernel_size=3, stride=2), + mint.nn.MaxPool2d(kernel_size=3, stride=2), Fire(128, 32, 128, 128), Fire(256, 32, 128, 128), - nn.MaxPool2d(kernel_size=3, stride=2), + mint.nn.MaxPool2d(kernel_size=3, stride=2), Fire(256, 48, 192, 192), Fire(384, 48, 192, 192), Fire(384, 64, 256, 256), @@ -118,9 +117,9 @@ def __init__( self.final_conv = nn.Conv2d(512, num_classes, kernel_size=1, has_bias=True) self.classifier = nn.SequentialCell([ - Dropout(p=drop_rate), + mint.nn.Dropout(p=drop_rate), self.final_conv, - nn.ReLU(), + mint.nn.ReLU(), GlobalAvgPooling() ]) self._initialize_weights() diff --git a/mindcv/models/swintransformer.py b/mindcv/models/swintransformer.py index 237f1c7ee..6b0b65caa 100644 --- a/mindcv/models/swintransformer.py +++ b/mindcv/models/swintransformer.py @@ -4,13 +4,14 @@ import numpy as np import mindspore.common.initializer as init +import mindspore.mint.nn.functional as F from mindspore import Parameter, Tensor from mindspore import dtype as mstype -from mindspore import nn, numpy, ops +from mindspore import mint, nn, numpy from .helpers import _ntuple, load_pretrained from .layers import DropPath, Identity -from .layers.compatibility import Dropout +from .layers.extend_bmm import ExtendBatchMatMul from .registry import register_model __all__ = [ @@ -38,20 +39,20 @@ def _cfg(url="", **kwargs): class Mlp(nn.Cell): def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Optional[nn.Cell] = nn.GELU, - drop: float = 0.0, + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Optional[nn.Cell] = mint.nn.GELU, + drop: float = 0.0, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Dense(in_channels=in_features, out_channels=hidden_features, has_bias=True) + self.fc1 = mint.nn.Linear(in_features=in_features, out_features=hidden_features, bias=True) self.act = act_layer() - self.fc2 = nn.Dense(in_channels=hidden_features, out_channels=out_features, has_bias=True) - self.drop = Dropout(p=drop) + self.fc2 = mint.nn.Linear(in_features=hidden_features, out_features=out_features, bias=True) + self.drop = mint.nn.Dropout(p=drop) def construct(self, x: Tensor) -> Tensor: x = self.fc1(x) @@ -79,8 +80,8 @@ def window_partition(x, window_size: int): class WindowPartition(nn.Cell): def __init__( - self, - window_size: int, + self, + window_size: int, ) -> None: super(WindowPartition, self).__init__() @@ -96,20 +97,20 @@ def construct(self, x: Tensor) -> Tensor: windows: Tensor(num_windows*b, window_size, window_size, c) """ b, h, w, c = x.shape - x = ops.reshape(x, (b, h // self.window_size, self.window_size, w // self.window_size, self.window_size, c)) - x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) - x = ops.reshape(x, (b * h * w // (self.window_size**2), self.window_size, self.window_size, c)) + x = mint.reshape(x, (b, h // self.window_size, self.window_size, w // self.window_size, self.window_size, c)) + x = mint.permute(x, (0, 1, 3, 2, 4, 5)) + x = mint.reshape(x, (b * h * w // (self.window_size ** 2), self.window_size, self.window_size, c)) return x class WindowReverse(nn.Cell): 
def construct( - self, - windows: Tensor, - window_size: int, - h: int, - w: int, + self, + windows: Tensor, + window_size: int, + h: int, + w: int, ) -> Tensor: """ Args: @@ -122,17 +123,17 @@ def construct( x: (B, H, W, C) """ b = windows.shape[0] // (h * w // window_size // window_size) - x = ops.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1)) - x = ops.transpose(x, (0, 1, 3, 2, 4, 5)) - x = ops.reshape(x, (b, h, w, -1)) + x = mint.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1)) + x = mint.permute(x, (0, 1, 3, 2, 4, 5)) + x = mint.reshape(x, (b, h, w, -1)) return x class RelativeBias(nn.Cell): def __init__( - self, - window_size: int, - num_heads: int, + self, + window_size: int, + num_heads: int, ) -> None: super().__init__() self.window_size = window_size @@ -151,18 +152,16 @@ def __init__( self.relative_position_bias_table = Parameter( Tensor(np.random.randn((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads), dtype=mstype.float32)) # 2*Wh-1 * 2*Ww-1, nH - self.one_hot = ops.OneHot() - self.index = Parameter(self.one_hot(self.relative_position_index, - (2 * window_size[0] - 1) * (2 * window_size[1] - 1), - Tensor(1.0), Tensor(0.0)), + self.index = Parameter(F.one_hot(self.relative_position_index, + (2 * window_size[0] - 1) * (2 * window_size[1] - 1)), requires_grad=False) def construct(self) -> Tensor: - out = ops.matmul(self.index, self.relative_position_bias_table) - out = ops.reshape(out, (self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1)) - out = ops.transpose(out, (2, 0, 1)) - out = ops.expand_dims(out, 0) + out = mint.matmul(self.index.to(mstype.float32), self.relative_position_bias_table) + out = mint.reshape(out, (self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1)) + out = mint.permute(out, (2, 0, 1)) + out = mint.unsqueeze(out, 0) return out @@ -181,14 +180,14 @@ class WindowAttention(nn.Cell): """ def __init__( - self, - dim: int, - window_size: int, - num_heads: int, - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - attn_drop: float = 0.0, - proj_drop: float = 0.0, + self, + dim: int, + window_size: int, + num_heads: int, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, ) -> None: super().__init__() if isinstance(dim, tuple) and len(dim) == 1: @@ -197,19 +196,19 @@ def __init__( self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = Tensor(qk_scale or head_dim**-0.5, mstype.float32) + self.scale = Tensor(qk_scale or head_dim ** -0.5, mstype.float32) self.relative_bias = RelativeBias(self.window_size, num_heads) # get pair-wise relative position index for each token inside the window - self.q = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias) - self.k = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias) - self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias) + self.q = mint.nn.Linear(in_features=dim, out_features=dim, bias=qkv_bias) + self.k = mint.nn.Linear(in_features=dim, out_features=dim, bias=qkv_bias) + self.v = mint.nn.Linear(in_features=dim, out_features=dim, bias=qkv_bias) - self.attn_drop = Dropout(p=attn_drop) - self.proj = nn.Dense(in_channels=dim, out_channels=dim, has_bias=True) - self.proj_drop = Dropout(p=proj_drop) - self.softmax = nn.Softmax(axis=-1) - self.batch_matmul = ops.BatchMatMul() + self.attn_drop = 
mint.nn.Dropout(p=attn_drop) + self.proj = mint.nn.Linear(in_features=dim, out_features=dim, bias=True) + self.proj_drop = mint.nn.Dropout(p=proj_drop) + self.softmax = mint.nn.Softmax(dim=-1) + self.batch_matmul = ExtendBatchMatMul() def construct(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: """ @@ -218,25 +217,25 @@ def construct(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """ b_, n, c = x.shape - q = ops.reshape(self.q(x), (b_, n, self.num_heads, c // self.num_heads)) * self.scale - q = ops.transpose(q, (0, 2, 1, 3)) - k = ops.reshape(self.k(x), (b_, n, self.num_heads, c // self.num_heads)) - k = ops.transpose(k, (0, 2, 3, 1)) - v = ops.reshape(self.v(x), (b_, n, self.num_heads, c // self.num_heads)) - v = ops.transpose(v, (0, 2, 1, 3)) + q = mint.reshape(self.q(x), (b_, n, self.num_heads, c // self.num_heads)) * self.scale + q = mint.permute(q, (0, 2, 1, 3)) + k = mint.reshape(self.k(x), (b_, n, self.num_heads, c // self.num_heads)) + k = mint.permute(k, (0, 2, 3, 1)) + v = mint.reshape(self.v(x), (b_, n, self.num_heads, c // self.num_heads)) + v = mint.permute(v, (0, 2, 1, 3)) attn = self.batch_matmul(q, k) attn = attn + self.relative_bias() if mask is not None: nw = mask.shape[1] - attn = ops.reshape(attn, (b_ // nw, nw, self.num_heads, n, n,)) + mask - attn = ops.reshape(attn, (-1, self.num_heads, n, n,)) + attn = mint.reshape(attn, (b_ // nw, nw, self.num_heads, n, n,)) + mask + attn = mint.reshape(attn, (-1, self.num_heads, n, n,)) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) - x = ops.reshape(ops.transpose(self.batch_matmul(attn, v), (0, 2, 1, 3)), (b_, n, c)) + x = mint.reshape(mint.permute(self.batch_matmul(attn, v), (0, 2, 1, 3)), (b_, n, c)) x = self.proj(x) x = self.proj_drop(x) return x @@ -265,20 +264,20 @@ class SwinTransformerBlock(nn.Cell): """ def __init__( - self, - dim: int, - input_resolution: Tuple[int], - num_heads: int, - window_size: int = 7, - shift_size: int = 0, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - drop: float = 0.0, - attn_drop: float = 0.0, - drop_path: float = 0.0, - act_layer: Optional[nn.Cell] = nn.GELU, - norm_layer: Optional[nn.Cell] = nn.LayerNorm, + self, + dim: int, + input_resolution: Tuple[int], + num_heads: int, + window_size: int = 7, + shift_size: int = 0, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: Optional[nn.Cell] = mint.nn.GELU, + norm_layer: Optional[nn.Cell] = mint.nn.LayerNorm, ) -> None: super(SwinTransformerBlock, self).__init__() self.dim = dim @@ -343,7 +342,7 @@ def construct(self, x: Tensor) -> Tensor: shortcut = x x = self.norm1(x) - x = ops.reshape(x, (b, h, w, c,)) + x = mint.reshape(x, (b, h, w, c,)) # cyclic shift if self.shift_size > 0: @@ -354,14 +353,14 @@ def construct(self, x: Tensor) -> Tensor: # partition windows x_windows = self.window_partition(shifted_x) # nW*B, window_size, window_size, C - x_windows = ops.reshape(x_windows, - (-1, self.window_size * self.window_size, c,)) # nW*B, window_size*window_size, C + x_windows = mint.reshape(x_windows, + (-1, self.window_size * self.window_size, c,)) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows - attn_windows = ops.reshape(attn_windows, (-1, 
self.window_size, self.window_size, c,)) + attn_windows = mint.reshape(attn_windows, (-1, self.window_size, self.window_size, c,)) shifted_x = self.window_reverse(attn_windows, self.window_size, h, w) # B H' W' C # reverse cyclic shift @@ -370,7 +369,7 @@ def construct(self, x: Tensor) -> Tensor: else: x = shifted_x - x = ops.reshape(x, (b, h * w, c,)) + x = mint.reshape(x, (b, h * w, c,)) # FFN x = shortcut + self.drop_path(x) @@ -386,9 +385,9 @@ def extra_repr(self) -> str: class Roll(nn.Cell): def __init__( - self, - shift_size: int, - shift_axis: Tuple[int] = (1, 2), + self, + shift_size: int, + shift_axis: Tuple[int] = (1, 2), ) -> None: super().__init__() self.shift_size = to_2tuple(shift_size) @@ -409,16 +408,16 @@ class PatchMerging(nn.Cell): """ def __init__( - self, - input_resolution: Tuple[int], - dim: int, - norm_layer: Optional[nn.Cell] = nn.LayerNorm, + self, + input_resolution: Tuple[int], + dim: int, + norm_layer: Optional[nn.Cell] = mint.nn.LayerNorm, ) -> None: super().__init__() self.input_resolution = input_resolution self.dim = dim[0] if isinstance(dim, tuple) and len(dim) == 1 else dim # Default False - self.reduction = nn.Dense(in_channels=4 * dim, out_channels=2 * dim, has_bias=False) + self.reduction = mint.nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer([dim * 4, ]) self.H, self.W = self.input_resolution self.H_2, self.W_2 = self.H // 2, self.W // 2 @@ -431,9 +430,9 @@ def construct(self, x: Tensor) -> Tensor: x: B, H*W, C """ b = x.shape[0] - x = ops.reshape(x, (b, self.H_2, 2, self.W_2, 2, self.dim)) - x = ops.transpose(x, (0, 1, 3, 4, 2, 5)) - x = ops.reshape(x, (b, self.H2W2, self.dim_mul_4)) + x = mint.reshape(x, (b, self.H_2, 2, self.W_2, 2, self.dim)) + x = mint.permute(x, (0, 1, 3, 4, 2, 5)) + x = mint.reshape(x, (b, self.H2W2, self.dim_mul_4)) x = self.norm(x) x = self.reduction(x) @@ -463,20 +462,20 @@ class BasicLayer(nn.Cell): """ def __init__( - self, - dim: int, - input_resolution: Tuple[int], - depth: int, - num_heads: int, - window_size: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - qk_scale: Optional[float] = None, - drop: float = 0.0, - attn_drop: float = 0.0, - drop_path: Optional[float] = 0.0, - norm_layer: Optional[nn.Cell] = nn.LayerNorm, - downsample: Optional[nn.Cell] = None, + self, + dim: int, + input_resolution: Tuple[int], + depth: int, + num_heads: int, + window_size: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[float] = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: Optional[float] = 0.0, + norm_layer: Optional[nn.Cell] = mint.nn.LayerNorm, + downsample: Optional[nn.Cell] = None, ) -> None: super().__init__() self.dim = dim @@ -524,12 +523,12 @@ class PatchEmbed(nn.Cell): """ def __init__( - self, - image_size: int = 224, - patch_size: int = 4, - in_chans: int = 3, - embed_dim: int = 96, - norm_layer: Optional[nn.Cell] = None, + self, + image_size: int = 224, + patch_size: int = 4, + in_chans: int = 3, + embed_dim: int = 96, + norm_layer: Optional[nn.Cell] = None, ) -> None: super().__init__() image_size = to_2tuple(image_size) @@ -543,8 +542,9 @@ def __init__( self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, - pad_mode="pad", has_bias=True, weight_init="TruncatedNormal") + self.proj = mint.nn.Conv2d( + in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=True + ) if norm_layer is not None: if 
isinstance(embed_dim, int): @@ -556,8 +556,8 @@ def construct(self, x: Tensor) -> Tensor: b = x.shape[0] # FIXME look at relaxing size constraints - x = ops.reshape(self.proj(x), (b, self.embed_dim, -1)) # b Ph*Pw c - x = ops.transpose(x, (0, 2, 1)) + x = mint.reshape(self.proj(x), (b, self.embed_dim, -1)) # b Ph*Pw c + x = mint.permute(x, (0, 2, 1)) if self.norm is not None: x = self.norm(x) @@ -589,24 +589,24 @@ class SwinTransformer(nn.Cell): """ def __init__( - self, - image_size: int = 224, - patch_size: int = 4, - in_chans: int = 3, - num_classes: int = 1000, - embed_dim: int = 96, - depths: Optional[List[int]] = None, - num_heads: Optional[List[int]] = None, - window_size: int = 7, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - qk_scale: Optional[int] = None, - drop_rate: float = 0.0, - attn_drop_rate: float = 0.0, - drop_path_rate: float = 0.1, - norm_layer: Optional[nn.Cell] = nn.LayerNorm, - ape: bool = False, - patch_norm: bool = True, + self, + image_size: int = 224, + patch_size: int = 4, + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: int = 96, + depths: Optional[List[int]] = None, + num_heads: Optional[List[int]] = None, + window_size: int = 7, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: Optional[int] = None, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.1, + norm_layer: Optional[nn.Cell] = mint.nn.LayerNorm, + ape: bool = False, + patch_norm: bool = True, ) -> None: super().__init__() @@ -630,7 +630,7 @@ def __init__( if self.ape: self.absolute_pos_embed = Parameter(Tensor(np.zeros(1, num_patches, embed_dim), dtype=mstype.float32)) - self.pos_drop = Dropout(p=drop_rate) + self.pos_drop = mint.nn.Dropout(p=drop_rate) # stochastic depth dpr = [x for x in np.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule @@ -653,21 +653,21 @@ def __init__( self.layers.append(layer) - self.norm = norm_layer([self.num_features, ], epsilon=1e-5) + self.norm = norm_layer([self.num_features, ], eps=1e-5) - self.classifier = nn.Dense(in_channels=self.num_features, - out_channels=num_classes, has_bias=True) if num_classes > 0 else Identity() + self.classifier = mint.nn.Linear(in_features=self.num_features, + out_features=num_classes, bias=True) if num_classes > 0 else Identity() self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Dense): + if isinstance(cell, (mint.nn.Linear, mint.nn.Conv2d)): cell.weight.set_data(init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)) - if isinstance(cell, nn.Dense) and cell.bias is not None: + if isinstance(cell, mint.nn.Linear) and cell.bias is not None: cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.LayerNorm): - cell.gamma.set_data(init.initializer(init.One(), cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer(init.Zero(), cell.beta.shape, cell.beta.dtype)) + elif isinstance(cell, mint.nn.LayerNorm): + cell.weight.set_data(init.initializer(init.One(), cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer(init.Zero(), cell.bias.shape, cell.bias.dtype)) def no_weight_decay(self) -> None: return {"absolute_pos_embed"} @@ -687,7 +687,7 @@ def forward_features(self, x: Tensor) -> Tensor: for layer in self.layers: x = layer(x) x = self.norm(x) # B L C - x = ops.mean(ops.transpose(x, (0, 2, 1)), 2) # B C 1 + x = mint.mean(mint.permute(x, (0, 2, 1)), 2) # B C 1
return x def construct(self, x: Tensor) -> Tensor: diff --git a/mindcv/models/swintransformerv2.py b/mindcv/models/swintransformerv2.py index 649e2f780..52fd9d01f 100644 --- a/mindcv/models/swintransformerv2.py +++ b/mindcv/models/swintransformerv2.py @@ -11,11 +11,11 @@ import mindspore.common.initializer as init from mindspore import Parameter, Tensor from mindspore import dtype as mstype -from mindspore import nn, ops +from mindspore import mint, nn, ops from .helpers import _ntuple, load_pretrained -from .layers import DropPath, Identity -from .layers.compatibility import Dropout +from .layers import DropPath, Identity, Sigmoid +from .layers.extend_bmm import ExtendBatchMatMul from .registry import register_model __all__ = [ @@ -72,9 +72,9 @@ def __init__(self, window_size: int) -> None: def construct(self, x: Tensor) -> Tensor: b, h, w, c = x.shape - x = x.reshape(b, h // self.window_size, self.window_size, w // self.window_size, self.window_size, c) - x = x.transpose(0, 1, 3, 2, 4, 5) - x = x.reshape(b * h * w // (self.window_size**2), self.window_size, self.window_size, c) + x = mint.reshape(x, (b, h // self.window_size, self.window_size, w // self.window_size, self.window_size, c)) + x = mint.permute(x, (0, 1, 3, 2, 4, 5)) + x = mint.reshape(x, (b * h * w // (self.window_size**2), self.window_size, self.window_size, c)) return x @@ -92,9 +92,9 @@ def __init__(self) -> None: def construct(self, windows: Tensor, window_size: int, h: int, w: int) -> Tensor: b = windows.shape[0] // (h * w // window_size // window_size) - x = windows.reshape(b, h // window_size, w // window_size, window_size, window_size, -1) - x = x.transpose(0, 1, 3, 2, 4, 5) - x = x.reshape(b, h, w, -1) + x = mint.reshape(windows, (b, h // window_size, w // window_size, window_size, window_size, -1)) + x = mint.permute(x, (0, 1, 3, 2, 4, 5)) + x = mint.reshape(x, (b, h, w, -1)) return x @@ -109,9 +109,9 @@ def __init__( self.window_size = window_size # Wh, Ww # mlp to generate continuous relative position bias self.num_heads = num_heads - self.cpb_mlp0 = nn.Dense(2, 512, has_bias=True) - self.cpb_act1 = nn.ReLU() - self.cpb_mlp2 = nn.Dense(512, num_heads, has_bias=False) + self.cpb_mlp0 = mint.nn.Linear(2, 512, bias=True) + self.cpb_act1 = mint.nn.ReLU() + self.cpb_mlp2 = mint.nn.Linear(512, num_heads, bias=False) relative_coords_h = np.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=float) relative_coords_w = np.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=float) @@ -151,20 +151,21 @@ def __init__( Tensor(relative_position_index, mstype.int32), requires_grad=False ) - self.sigmoid = ops.Sigmoid() + # todo + self.sigmoid = Sigmoid() def construct(self) -> Tensor: x = self.cpb_mlp0(self.relative_coords_table) x = self.cpb_act1(x) x = self.cpb_mlp2(x) x = x.reshape(-1, self.num_heads) - relative_position_bias = x[ops.reshape(self.relative_position_index, (-1,))] - relative_position_bias = ops.reshape(relative_position_bias, ( + relative_position_bias = x[mint.reshape(self.relative_position_index, (-1,))] + relative_position_bias = mint.reshape(relative_position_bias, ( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)) relative_position_bias = relative_position_bias.transpose(2, 0, 1) relative_position_bias = 16 * self.sigmoid(relative_position_bias) - relative_position_bias = ops.expand_dims(relative_position_bias, axis=0) + relative_position_bias = mint.unsqueeze(relative_position_bias, dim=0) return relative_position_bias @@ -186,47 +187,48 @@ 
def __init__( self.window_size = window_size # Wh, Ww self.num_heads = num_heads - self.matmul = ops.BatchMatMul() + self.matmul = ExtendBatchMatMul() - logit_scale = Tensor((ops.log(10 * ops.ones((num_heads, 1, 1), mstype.float32)))) + logit_scale = Tensor((mint.log(10 * mint.ones((num_heads, 1, 1), dtype=mstype.float32)))) self.logit_scale = Parameter(logit_scale, requires_grad=True) max = Tensor(100, mstype.float32) - self.value_max = ops.log(max) + self.value_max = mint.log(max) self.value_min = Tensor((-1000), mstype.float32) - self.q = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias) - self.k = nn.Dense(in_channels=dim, out_channels=dim, has_bias=False) - self.v = nn.Dense(in_channels=dim, out_channels=dim, has_bias=qkv_bias) + self.q = mint.nn.Linear(in_features=dim, out_features=dim, bias=qkv_bias) + self.k = mint.nn.Linear(in_features=dim, out_features=dim, bias=False) + self.v = mint.nn.Linear(in_features=dim, out_features=dim, bias=qkv_bias) + # todo self.Normalize = ops.L2Normalize(axis=-1, epsilon=1e-12) # get pair-wise relative position index for each token inside the window self.relative_position_bias = LogSpacedCPB(self.window_size, num_heads, pretrained_window_size) - self.attn_drop = Dropout(p=attn_drop) - self.proj = nn.Dense(in_channels=dim, out_channels=dim, has_bias=True) - self.softmax = nn.Softmax(axis=-1) - self.proj_drop = Dropout(p=proj_drop) + self.attn_drop = mint.nn.Dropout(p=attn_drop) + self.proj = mint.nn.Linear(in_features=dim, out_features=dim, bias=True) + self.softmax = mint.nn.Softmax(dim=-1) + self.proj_drop = mint.nn.Dropout(p=proj_drop) def construct(self, x: Tensor, mask=None) -> Tensor: B_, N, C = x.shape - q = ops.reshape(self.q(x), (B_, N, self.num_heads, C // self.num_heads)) + q = mint.reshape(self.q(x), (B_, N, self.num_heads, C // self.num_heads)) q = self.Normalize(q) - q = ops.transpose(q, (0, 2, 1, 3)) + q = mint.permute(q, (0, 2, 1, 3)) - k = ops.reshape(self.k(x), (B_, N, self.num_heads, C // self.num_heads)) + k = mint.reshape(self.k(x), (B_, N, self.num_heads, C // self.num_heads)) k = self.Normalize(k) - k = ops.transpose(k, (0, 2, 3, 1)) + k = mint.permute(k, (0, 2, 3, 1)) - v = ops.reshape(self.v(x), (B_, N, self.num_heads, C // self.num_heads)) - v = ops.transpose(v, (0, 2, 1, 3)) + v = mint.reshape(self.v(x), (B_, N, self.num_heads, C // self.num_heads)) + v = mint.permute(v, (0, 2, 1, 3)) attn = self.matmul(q, k) - logit_scale = ops.clip_by_value(self.logit_scale, clip_value_min=self.value_min, clip_value_max=self.value_max) - logit_scale = ops.exp(logit_scale) + logit_scale = mint.clamp(self.logit_scale, min=self.value_min, max=self.value_max) + logit_scale = mint.exp(logit_scale) attn = attn * logit_scale @@ -234,16 +236,16 @@ def construct(self, x: Tensor, mask=None) -> Tensor: if mask is not None: nW, ws2, _ = mask.shape - mask = ops.reshape(mask, (1, -1, 1, ws2, ws2)) - attn = ops.reshape(attn, (B_ // nW, nW, self.num_heads, N, N,)) + mask - attn = ops.reshape(attn, (-1, self.num_heads, N, N,)) + mask = mint.reshape(mask, (1, -1, 1, ws2, ws2)) + attn = mint.reshape(attn, (B_ // nW, nW, self.num_heads, N, N,)) + mask + attn = mint.reshape(attn, (-1, self.num_heads, N, N,)) attn = self.softmax(attn) else: attn = self.softmax(attn) attn = self.attn_drop(attn) - x = ops.reshape(ops.transpose(self.matmul(attn, v), (0, 2, 1, 3)), (B_, N, C)) + x = mint.reshape(mint.permute(self.matmul(attn, v), (0, 2, 1, 3)), (B_, N, C)) x = self.proj(x) x = self.proj_drop(x) @@ -256,16 +258,16 @@ def __init__( in_features: int, 
hidden_features: Optional[int] = None, out_features: Optional[int] = None, - act_layer: Optional[nn.Cell] = nn.GELU, + act_layer: Optional[nn.Cell] = mint.nn.GELU, drop: float = 0.0, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Dense(in_channels=in_features, out_channels=hidden_features, has_bias=True) + self.fc1 = mint.nn.Linear(in_features=in_features, out_features=hidden_features, bias=True) self.act = act_layer() - self.fc2 = nn.Dense(in_channels=hidden_features, out_channels=out_features, has_bias=True) - self.drop = Dropout(p=drop) + self.fc2 = mint.nn.Linear(in_features=hidden_features, out_features=out_features, bias=True) + self.drop = mint.nn.Dropout(p=drop) def construct(self, x: Tensor) -> Tensor: x = self.fc1(x) @@ -289,8 +291,8 @@ def __init__( drop: float = 0.0, attn_drop: float = 0.0, drop_path: float = 0.0, - act_layer: nn.Cell = nn.GELU, - norm_layer: nn.Cell = nn.LayerNorm, + act_layer: nn.Cell = mint.nn.GELU, + norm_layer: nn.Cell = mint.nn.LayerNorm, pretrained_window_size: int = 0, ) -> None: super(SwinTransformerBlock, self).__init__() @@ -308,7 +310,7 @@ def __init__( if isinstance(dim, int): dim = (dim,) - self.norm1 = norm_layer(dim, epsilon=1e-6) + self.norm1 = norm_layer(dim, eps=1e-6) self.attn = WindowCosineAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, @@ -316,7 +318,7 @@ def __init__( pretrained_window_size=to_2tuple(pretrained_window_size)) self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() - self.norm2 = norm_layer(dim, epsilon=1e-6) + self.norm2 = norm_layer(dim, eps=1e-6) mlp_hidden_dim = int((dim[0] if isinstance(dim, tuple) else dim) * mlp_ratio) self.mlp = Mlp(in_features=dim[0] if isinstance(dim, tuple) else dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) @@ -360,7 +362,7 @@ def construct(self, x: Tensor) -> Tensor: shortcut = x # x = self.norm1(x) - x = ops.reshape(x, (B, H, W, C,)) + x = mint.reshape(x, (B, H, W, C,)) # cyclic shift if self.shift_size > 0: @@ -372,13 +374,13 @@ def construct(self, x: Tensor) -> Tensor: # nW*B, window_size, window_size, C x_windows = self.window_partition(shifted_x) # nW*B, window_size*window_size, C - x_windows = ops.reshape(x_windows, (-1, self.window_size * self.window_size, C,)) + x_windows = mint.reshape(x_windows, (-1, self.window_size * self.window_size, C,)) # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows - attn_windows = ops.reshape(attn_windows, (-1, self.window_size, self.window_size, C,)) + attn_windows = mint.reshape(attn_windows, (-1, self.window_size, self.window_size, C,)) shifted_x = self.window_reverse(attn_windows, self.window_size, H, W) # B H' W' C # reverse cyclic shift @@ -387,7 +389,7 @@ def construct(self, x: Tensor) -> Tensor: else: x = shifted_x - x = ops.reshape(x, (B, H * W, C,)) + x = mint.reshape(x, (B, H * W, C,)) # FFN post-res-norm x = shortcut + self.drop_path(self.norm1(x)) @@ -402,14 +404,14 @@ def __init__( self, input_resolution: Tuple[int, int], dim: int, - norm_layer: nn.Cell = nn.LayerNorm, + norm_layer: nn.Cell = mint.nn.LayerNorm, ) -> None: super().__init__() self.input_resolution = input_resolution self.dim = dim[0] if isinstance(dim, tuple) and len(dim) == 1 else dim # Default False - self.reduction = nn.Dense(in_channels=4 * dim, out_channels=2 * dim, has_bias=False) - self.norm = norm_layer([dim * 2], epsilon=1e-5) + self.reduction 
= mint.nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False) + self.norm = norm_layer([dim * 2], eps=1e-5) self.H, self.W = self.input_resolution self.H_2, self.W_2 = self.H // 2, self.W // 2 self.H2W2 = int(self.H * self.W // 4) @@ -418,9 +420,9 @@ def __init__( def construct(self, x: Tensor) -> Tensor: B = x.shape[0] - x = ops.reshape(x, (B, self.H_2, 2, self.W_2, 2, self.dim)) - x = ops.transpose(x, (0, 1, 3, 4, 2, 5)) - x = ops.reshape(x, (B, self.H2W2, self.dim_mul_4)) + x = mint.reshape(x, (B, self.H_2, 2, self.W_2, 2, self.dim)) + x = mint.permute(x, (0, 1, 3, 4, 2, 5)) + x = mint.reshape(x, (B, self.H2W2, self.dim_mul_4)) x = self.reduction(x) x = self.norm(x) @@ -440,7 +442,7 @@ def __init__( drop: float = 0.0, attn_drop: float = 0.0, drop_path: List[float] = 0.0, - norm_layer: nn.Cell = nn.LayerNorm, + norm_layer: nn.Cell = mint.nn.LayerNorm, downsample: Optional[nn.Cell] = None, pretrained_window_size: int = 0, ) -> None: @@ -497,20 +499,21 @@ def __init__( self.in_chans = in_chans self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, - pad_mode='pad', has_bias=True) + self.proj = mint.nn.Conv2d( + in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=True + ) if norm_layer is not None: if isinstance(embed_dim, int): embed_dim = (embed_dim,) - self.norm = norm_layer(embed_dim, epsilon=1e-6) + self.norm = norm_layer(embed_dim, eps=1e-6) else: self.norm = None def construct(self, x: Tensor) -> Tensor: B = x.shape[0] - x = ops.reshape(self.proj(x), (B, self.embed_dim, -1)) - x = ops.transpose(x, (0, 2, 1)) # B Ph*Pw C + x = mint.reshape(self.proj(x), (B, self.embed_dim, -1)) + x = mint.permute(x, (0, 2, 1)) # B Ph*Pw C if self.norm is not None: x = self.norm(x) @@ -555,7 +558,7 @@ def __init__( drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.1, - norm_layer: nn.Cell = nn.LayerNorm, + norm_layer: nn.Cell = mint.nn.LayerNorm, patch_norm: bool = True, pretrained_window_sizes: List[int] = [0, 0, 0, 0], ) -> None: @@ -578,7 +581,7 @@ def __init__( patches_resolution = self.patch_embed.patches_resolution self.patches_resolution = patches_resolution - self.pos_drop = Dropout(p=drop_rate) + self.pos_drop = mint.nn.Dropout(p=drop_rate) # stochastic depth dpr = [x for x in np.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule @@ -606,23 +609,21 @@ def __init__( if i_layer < self.num_layers - 1: self.final_seq = self.final_seq // 4 self.layers.append(layer) - self.head = nn.Dense(self.num_features, self.num_classes) - - self.norm = norm_layer([self.num_features, ], epsilon=1e-6) - self.avgpool = ops.ReduceMean(keep_dims=False) + self.head = mint.nn.Linear(self.num_features, self.num_classes) + self.norm = norm_layer([self.num_features, ], eps=1e-6) self._initialize_weights() def _initialize_weights(self): for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data(init.initializer(init.HeUniform(), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.LayerNorm): - cell.gamma.set_data(init.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) - cell.beta.set_data(init.initializer("zeros", cell.beta.shape, cell.beta.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.LayerNorm): + 
cell.weight.set_data(init.initializer("ones", cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype) ) @@ -635,7 +636,7 @@ def forward_features(self, x: Tensor) -> Tensor: for layer in self.layers: x = layer(x) x = self.norm(x) # B L C - x = self.avgpool(ops.transpose(x, (0, 2, 1)), 2) # B C 1 + x = mint.mean(mint.permute(x, (0, 2, 1)), 2) # B C 1 return x def forward_head(self, x: Tensor) -> Tensor: diff --git a/mindcv/models/vgg.py b/mindcv/models/vgg.py index 8c37d8596..681f13617 100644 --- a/mindcv/models/vgg.py +++ b/mindcv/models/vgg.py @@ -7,10 +7,10 @@ from typing import Dict, List, Union import mindspore.common.initializer as init -from mindspore import Tensor, nn +from mindspore import Tensor, mint, nn from .helpers import load_pretrained -from .layers.compatibility import Dropout +from .layers import Flatten from .registry import register_model __all__ = [ @@ -57,13 +57,13 @@ def _make_layers( layers = [] for v in cfg: if v == "M": - layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + layers += [mint.nn.MaxPool2d(kernel_size=2, stride=2)] else: - conv2d = nn.Conv2d(in_channels, v, kernel_size=3, pad_mode="pad", padding=1) + conv2d = mint.nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) if batch_norm: - layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] + layers += [conv2d, mint.nn.BatchNorm2d(v), mint.nn.ReLU()] else: - layers += [conv2d, nn.ReLU()] + layers += [conv2d, mint.nn.ReLU()] in_channels = v return nn.SequentialCell(layers) @@ -92,29 +92,29 @@ def __init__( super().__init__() cfg = cfgs[model_name] self.features = _make_layers(cfg, batch_norm=batch_norm, in_channels=in_channels) - self.flatten = nn.Flatten() + self.flatten = Flatten() self.classifier = nn.SequentialCell([ - nn.Dense(512 * 7 * 7, 4096), - nn.ReLU(), - Dropout(p=drop_rate), - nn.Dense(4096, 4096), - nn.ReLU(), - Dropout(p=drop_rate), - nn.Dense(4096, num_classes), + mint.nn.Linear(512 * 7 * 7, 4096), + mint.nn.ReLU(), + mint.nn.Dropout(p=drop_rate), + mint.nn.Linear(4096, 4096), + mint.nn.ReLU(), + mint.nn.Dropout(p=drop_rate), + mint.nn.Linear(4096, num_classes), ]) self._initialize_weights() def _initialize_weights(self) -> None: """Initialize weights for cells.""" for _, cell in self.cells_and_names(): - if isinstance(cell, nn.Conv2d): + if isinstance(cell, mint.nn.Conv2d): cell.weight.set_data( init.initializer(init.HeNormal(math.sqrt(5), mode="fan_out", nonlinearity="relu"), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: cell.bias.set_data( init.initializer("zeros", cell.bias.shape, cell.bias.dtype)) - elif isinstance(cell, nn.Dense): + elif isinstance(cell, mint.nn.Linear): cell.weight.set_data( init.initializer(init.Normal(0.01), cell.weight.shape, cell.weight.dtype)) if cell.bias is not None: diff --git a/mindcv/models/visformer.py b/mindcv/models/visformer.py index 1e120d403..a14a59664 100644 --- a/mindcv/models/visformer.py +++ b/mindcv/models/visformer.py @@ -8,12 +8,11 @@ import numpy as np import mindspore -from mindspore import Tensor, nn, ops +from mindspore import Tensor, mint, nn from mindspore.common.initializer import Constant, HeNormal, TruncatedNormal, initializer from .helpers import _ntuple, load_pretrained from .layers import DropPath, GlobalAvgPooling, Identity -from .layers.compatibility import Dropout from .registry import 
diff --git a/mindcv/models/visformer.py b/mindcv/models/visformer.py
index 1e120d403..a14a59664 100644
--- a/mindcv/models/visformer.py
+++ b/mindcv/models/visformer.py
@@ -8,12 +8,11 @@
 import numpy as np

 import mindspore
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, nn
 from mindspore.common.initializer import Constant, HeNormal, TruncatedNormal, initializer

 from .helpers import _ntuple, load_pretrained
 from .layers import DropPath, GlobalAvgPooling, Identity
-from .layers.compatibility import Dropout
 from .registry import register_model

 __all__ = [
@@ -56,7 +55,7 @@ def __init__(
         in_features: int,
         hidden_features: int = None,
         out_features: int = None,
-        act_layer: nn.Cell = nn.GELU,
+        act_layer: nn.Cell = mint.nn.GELU,
         drop: float = 0.0,
         group: int = 8,
         spatial_conv: bool = False,
@@ -74,13 +73,15 @@ def __init__(
             hidden_features = in_features * 2
         self.hidden_features = hidden_features
         self.group = group
-        self.drop = Dropout(p=drop)
-        self.conv1 = nn.Conv2d(in_features, hidden_features, 1, 1, pad_mode="pad", padding=0)
+        self.drop = mint.nn.Dropout(p=drop)
+        self.conv1 = mint.nn.Conv2d(in_features, hidden_features, 1, 1, padding=0, bias=False)
         self.act1 = act_layer()
         if self.spatial_conv:
-            self.conv2 = nn.Conv2d(hidden_features, hidden_features, 3, 1, pad_mode="pad", padding=1, group=self.group)
+            self.conv2 = mint.nn.Conv2d(
+                hidden_features, hidden_features, 3, 1, padding=1, groups=self.group, bias=False
+            )
             self.act2 = act_layer()
-        self.conv3 = nn.Conv2d(hidden_features, out_features, 1, 1, pad_mode="pad", padding=0)
+        self.conv3 = mint.nn.Conv2d(hidden_features, out_features, 1, 1, padding=0, bias=False)

     def construct(self, x: Tensor) -> Tensor:
         x = self.conv1(x)
@@ -118,23 +119,24 @@ def __init__(
         qk_scale_factor = qk_scale if qk_scale is not None else -0.25
         self.scale = head_dim**qk_scale_factor

-        self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, 1, pad_mode="pad", padding=0, has_bias=qkv_bias)
-        self.attn_drop = Dropout(p=attn_drop)
-        self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, 1, pad_mode="pad", padding=0)
-        self.proj_drop = Dropout(p=proj_drop)
+        self.qkv = mint.nn.Conv2d(dim, head_dim * num_heads * 3, 1, 1, padding=0, bias=qkv_bias)
+        self.attn_drop = mint.nn.Dropout(p=attn_drop)
+        self.proj = mint.nn.Conv2d(self.head_dim * self.num_heads, dim, 1, 1, padding=0, bias=False)
+        self.proj_drop = mint.nn.Dropout(p=proj_drop)

     def construct(self, x: Tensor) -> Tensor:
         B, C, H, W = x.shape
         x = self.qkv(x)
-        qkv = ops.reshape(x, (B, 3, self.num_heads, self.head_dim, H * W))
-        qkv = qkv.transpose((1, 0, 2, 4, 3))
+        qkv = mint.reshape(x, (B, 3, self.num_heads, self.head_dim, H * W))
+        qkv = mint.permute(qkv, (1, 0, 2, 4, 3))
         q, k, v = qkv[0], qkv[1], qkv[2]

-        attn = ops.matmul(q * self.scale, k.transpose(0, 1, 3, 2) * self.scale)
-        attn = ops.Softmax(axis=-1)(attn)
+        attn = mint.matmul(q * self.scale, mint.permute(k, (0, 1, 3, 2)) * self.scale)
+        attn = mint.nn.Softmax(dim=-1)(attn)
         attn = self.attn_drop(attn)
-        x = ops.matmul(attn, v)
+        x = mint.matmul(attn, v)

-        x = x.transpose((0, 1, 3, 2)).reshape((B, -1, H, W))
+        x = mint.permute(x, (0, 1, 3, 2))
+        x = mint.reshape(x, (B, -1, H, W))
         x = self.proj(x)
         x = self.proj_drop(x)

@@ -155,7 +157,7 @@ def __init__(
         drop: float = 0.0,
         attn_drop: float = 0.0,
         drop_path: float = 0.0,
-        act_layer: nn.Cell = nn.GELU,
+        act_layer: nn.Cell = mint.nn.GELU,
         group: int = 8,
         attn_disabled: bool = False,
         spatial_conv: bool = False,
@@ -165,11 +167,11 @@ def __init__(
         self.spatial_conv = spatial_conv
         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
         if not attn_disabled:
-            self.norm1 = nn.BatchNorm2d(dim)
+            self.norm1 = mint.nn.BatchNorm2d(dim)
             self.attn = Attention(dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, qkv_bias=qkv_bias,
                                   qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-        self.norm2 = nn.BatchNorm2d(dim)
+        self.norm2 = mint.nn.BatchNorm2d(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, group=group,
                        spatial_conv=spatial_conv)
@@ -197,9 +199,10 @@ def __init__(
         self.img_size = img_size
         self.patch_size = patch_size
         self.num_patches = num_patches
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, pad_mode="pad", padding=0,
-                              has_bias=True)
-        self.norm = nn.BatchNorm2d(embed_dim)
+        self.proj = mint.nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0, bias=True
+        )
+        self.norm = mint.nn.BatchNorm2d(embed_dim)

     def construct(self, x: Tensor) -> Tensor:
         x = self.proj(x)
@@ -270,13 +273,13 @@ def __init__(
         dpr = np.linspace(0, drop_path_rate, sum(depth)).tolist()

         self.stem = nn.SequentialCell([
-            nn.Conv2d(3, self.init_channels, 7, 2, pad_mode="pad", padding=3),
-            nn.BatchNorm2d(self.init_channels),
-            nn.ReLU()
+            mint.nn.Conv2d(3, self.init_channels, 7, 2, padding=3, bias=False),
+            mint.nn.BatchNorm2d(self.init_channels),
+            mint.nn.ReLU()
         ])
         img_size //= 2

-        self.pos_drop = Dropout(p=drop_rate)
+        self.pos_drop = mint.nn.Dropout(p=drop_rate)
         # stage0
         if depth[0]:
             self.patch_embed0 = PatchEmbed(img_size=img_size, patch_size=2, in_chans=self.init_channels,
@@ -284,7 +287,7 @@ def __init__(
             img_size //= 2
             if self.pos_embed:
                 self.pos_embed0 = mindspore.Parameter(
-                    ops.zeros((1, embed_dim // 4, img_size, img_size), mindspore.float32))
+                    mint.zeros((1, embed_dim // 4, img_size, img_size), dtype=mindspore.float32))
             self.stage0 = nn.CellList([
                 Block(dim=embed_dim // 4, num_heads=num_heads[0], head_dim_ratio=0.25, mlp_ratio=mlp_ratio,
                       qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
@@ -303,7 +306,9 @@ def __init__(
         img_size //= 4
         if self.pos_embed:
-            self.pos_embed1 = mindspore.Parameter(ops.zeros((1, embed_dim // 2, img_size, img_size), mindspore.float32))
+            self.pos_embed1 = mindspore.Parameter(
+                mint.zeros((1, embed_dim // 2, img_size, img_size), dtype=mindspore.float32)
+            )

         self.stage1 = nn.CellList([
             Block(
@@ -318,7 +323,9 @@ def __init__(
         self.patch_embed2 = PatchEmbed(img_size=img_size, patch_size=2, in_chans=embed_dim // 2, embed_dim=embed_dim)
         img_size //= 2
         if self.pos_embed:
-            self.pos_embed2 = mindspore.Parameter(ops.zeros((1, embed_dim, img_size, img_size), mindspore.float32))
+            self.pos_embed2 = mindspore.Parameter(
+                mint.zeros((1, embed_dim, img_size, img_size), dtype=mindspore.float32)
+            )
         self.stage2 = nn.CellList([
             Block(
                 dim=embed_dim, num_heads=num_heads[2], head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@@ -332,7 +339,8 @@ def __init__(
         self.patch_embed3 = PatchEmbed(img_size=img_size, patch_size=2, in_chans=embed_dim, embed_dim=embed_dim * 2)
         img_size //= 2
         if self.pos_embed:
-            self.pos_embed3 = mindspore.Parameter(ops.zeros((1, embed_dim * 2, img_size, img_size), mindspore.float32))
+            self.pos_embed3 = mindspore.Parameter(
+                mint.zeros((1, embed_dim * 2, img_size, img_size), dtype=mindspore.float32))
         self.stage3 = nn.CellList([
             Block(
                 dim=embed_dim * 2, num_heads=num_heads[3], head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@@ -346,8 +354,8 @@ def __init__(
         if self.pool:
             self.global_pooling = GlobalAvgPooling()

-        self.norm = nn.BatchNorm2d(embed_dim * 2)
-        self.head = nn.Dense(embed_dim * 2, num_classes)
+        self.norm = mint.nn.BatchNorm2d(embed_dim * 2)
+        self.head = mint.nn.Linear(embed_dim * 2, num_classes)

         # weight init
         if self.pos_embed:
@@ -364,17 +372,17 @@ def __init__(

     def _initialize_weights(self) -> None:
         for _, cell in self.cells_and_names():
-            if isinstance(cell, nn.Dense):
+            if isinstance(cell, mint.nn.Linear):
                 cell.weight.set_data(initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype))
                 if cell.bias is not None:
                     cell.bias.set_data(initializer(Constant(0), cell.bias.shape, cell.bias.dtype))
-            elif isinstance(cell, nn.LayerNorm):
-                cell.beta.set_data(initializer(Constant(0), cell.beta.shape, cell.beta.dtype))
-                cell.gamma.set_data(initializer(Constant(1), cell.gamma.shape, cell.gamma.dtype))
-            elif isinstance(cell, nn.BatchNorm2d):
-                cell.beta.set_data(initializer(Constant(0), cell.beta.shape, cell.beta.dtype))
-                cell.gamma.set_data(initializer(Constant(1), cell.gamma.shape, cell.gamma.dtype))
-            elif isinstance(cell, nn.Conv2d):
+            elif isinstance(cell, mint.nn.LayerNorm):
+                cell.bias.set_data(initializer(Constant(0), cell.bias.shape, cell.bias.dtype))
+                cell.weight.set_data(initializer(Constant(1), cell.weight.shape, cell.weight.dtype))
+            elif isinstance(cell, mint.nn.BatchNorm2d):
+                cell.bias.set_data(initializer(Constant(0), cell.bias.shape, cell.bias.dtype))
+                cell.weight.set_data(initializer(Constant(1), cell.weight.shape, cell.weight.dtype))
+            elif isinstance(cell, mint.nn.Conv2d):
                 if self.conv_init:
                     cell.weight.set_data(initializer(HeNormal(mode="fan_out", nonlinearity="relu"),
                                                      cell.weight.shape, cell.weight.dtype))
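The visformer attention rewrite swaps ops.transpose/ops.matmul/ops.Softmax for mint.permute/mint.matmul/mint.nn.Softmax one-to-one. A small numerical check of that equivalence (sketch, not part of the patch):

import numpy as np
import mindspore as ms
from mindspore import Tensor, mint, ops

q = Tensor(np.random.randn(2, 4, 16, 8), ms.float32)  # B, heads, N, head_dim
k = Tensor(np.random.randn(2, 4, 16, 8), ms.float32)

old_attn = ops.Softmax(axis=-1)(ops.matmul(q, k.transpose(0, 1, 3, 2)))
new_attn = mint.nn.Softmax(dim=-1)(mint.matmul(q, mint.permute(k, (0, 1, 3, 2))))
assert np.allclose(old_attn.asnumpy(), new_attn.asnumpy(), atol=1e-5)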
diff --git a/mindcv/models/vit.py b/mindcv/models/vit.py
index c8ee0967e..e77103c1c 100644
--- a/mindcv/models/vit.py
+++ b/mindcv/models/vit.py
@@ -5,12 +5,12 @@
 import numpy as np

 import mindspore as ms
-from mindspore import Parameter, Tensor, nn, ops
+from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.common.initializer import TruncatedNormal, XavierUniform, initializer

 from .helpers import load_pretrained
-from .layers.compatibility import Dropout
 from .layers.drop_path import DropPath
+from .layers.extend_bmm import ExtendBatchMatMul
 from .layers.mlp import Mlp
 from .layers.patch_dropout import PatchDropout
 from .layers.patch_embed import PatchEmbed
@@ -84,7 +84,7 @@ def __init__(
         qk_norm: bool = False,
         attn_drop: float = 0.0,
         proj_drop: float = 0.0,
-        norm_layer: nn.Cell = nn.LayerNorm,
+        norm_layer: nn.Cell = mint.nn.LayerNorm,
     ):
         super(Attention, self).__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -92,39 +92,37 @@ def __init__(
         self.head_dim = dim // num_heads
         self.scale = Tensor(self.head_dim ** -0.5)

-        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
-        self.q_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer((self.head_dim,)) if qk_norm else nn.Identity()
+        self.qkv = mint.nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer((self.head_dim,)) if qk_norm else mint.nn.Identity()
+        self.k_norm = norm_layer((self.head_dim,)) if qk_norm else mint.nn.Identity()

-        self.attn_drop = Dropout(attn_drop)
-        self.proj = nn.Dense(dim, dim)
-        self.proj_drop = Dropout(proj_drop)
+        self.attn_drop = mint.nn.Dropout(attn_drop)
+        self.proj = mint.nn.Linear(dim, dim)
+        self.proj_drop = mint.nn.Dropout(proj_drop)

-        self.mul = ops.Mul()
-        self.reshape = ops.Reshape()
-        self.transpose = ops.Transpose()
         self.unstack = ops.Unstack(axis=0)
-        self.attn_matmul_v = ops.BatchMatMul()
-        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
+        self.attn_matmul_v = ExtendBatchMatMul()
+        self.q_matmul_k = ExtendBatchMatMul(transpose_b=True)
+        self.softmax = mint.nn.Softmax(dim=-1)

     def construct(self, x):
         b, n, c = x.shape
         qkv = self.qkv(x)
-        qkv = self.reshape(qkv, (b, n, 3, self.num_heads, self.head_dim))
-        qkv = self.transpose(qkv, (2, 0, 3, 1, 4))
+        qkv = mint.reshape(qkv, (b, n, 3, self.num_heads, self.head_dim))
+        qkv = mint.permute(qkv, (2, 0, 3, 1, 4))
         q, k, v = self.unstack(qkv)
         q, k = self.q_norm(q), self.k_norm(k)

-        q = self.mul(q, self.scale**0.5)
-        k = self.mul(k, self.scale**0.5)
+        q = mint.mul(q, self.scale**0.5)
+        k = mint.mul(k, self.scale**0.5)
         attn = self.q_matmul_k(q, k)
-        attn = ops.softmax(attn.astype(ms.float32), axis=-1).astype(attn.dtype)
+        attn = self.softmax(attn.astype(ms.float32)).astype(attn.dtype)
         attn = self.attn_drop(attn)
         out = self.attn_matmul_v(attn, v)
-        out = self.transpose(out, (0, 2, 1, 3))
-        out = self.reshape(out, (b, n, c))
+        out = mint.permute(out, (0, 2, 1, 3))
+        out = mint.reshape(out, (b, n, c))
         out = self.proj(out)
         out = self.proj_drop(out)

@@ -192,8 +190,8 @@ def __init__(
         attn_drop: float = 0.,
         init_values: Optional[float] = None,
         drop_path: float = 0.,
-        act_layer: nn.Cell = nn.GELU,
-        norm_layer: nn.Cell = nn.LayerNorm,
+        act_layer: nn.Cell = mint.nn.GELU,
+        norm_layer: nn.Cell = mint.nn.LayerNorm,
         mlp_layer: Callable = Mlp,
     ):
         super(Block, self).__init__()
@@ -207,8 +205,8 @@ def __init__(
             proj_drop=proj_drop,
             norm_layer=norm_layer,
         )
-        self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.ls1 = LayerScale(dim=dim, init_values=init_values) if init_values else mint.nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else mint.nn.Identity()

         self.norm2 = norm_layer((dim,))
         self.mlp = mlp_layer(
@@ -217,8 +215,8 @@ def __init__(
             act_layer=act_layer,
             drop=proj_drop
         )
-        self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else nn.Identity()
-        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.ls2 = LayerScale(dim=dim, init_values=init_values) if init_values else mint.nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else mint.nn.Identity()

     def construct(self, x):
         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
@@ -255,9 +253,9 @@ def __init__(
         fc_norm: Optional[bool] = None,
         dynamic_img_size: bool = False,
         dynamic_img_pad: bool = False,
-        act_layer: nn.Cell = nn.GELU,
+        act_layer: nn.Cell = mint.nn.GELU,
         embed_layer: Callable = PatchEmbed,
-        norm_layer: nn.Cell = nn.LayerNorm,
+        norm_layer: nn.Cell = mint.nn.LayerNorm,
         mlp_layer: Callable = Mlp,
         class_token: bool = True,
         block_fn: Callable = Block,
@@ -295,16 +293,16 @@ def __init__(
         self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) if class_token else None
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
         self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), (1, embed_len, embed_dim)))
-        self.pos_drop = Dropout(pos_drop_rate)
+        self.pos_drop = mint.nn.Dropout(pos_drop_rate)
         if patch_drop_rate > 0:
             self.patch_drop = PatchDropout(
                 patch_drop_rate,
                 num_prefix_tokens=self.num_prefix_tokens,
             )
         else:
-            self.patch_drop = nn.Identity()
+            self.patch_drop = mint.nn.Identity()

-        self.norm_pre = norm_layer((embed_dim,)) if pre_norm else nn.Identity()
+        self.norm_pre = norm_layer((embed_dim,)) if pre_norm else mint.nn.Identity()
         dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]
         self.blocks = nn.CellList([
             block_fn(
@@ -315,10 +313,10 @@ def __init__(
             )
             for i in range(depth)
         ])
-        self.norm = norm_layer((embed_dim,)) if not use_fc_norm else nn.Identity()
-        self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else nn.Identity()
-        self.head_drop = Dropout(drop_rate)
-        self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.norm = norm_layer((embed_dim,)) if not use_fc_norm else mint.nn.Identity()
+        self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else mint.nn.Identity()
+        self.head_drop = mint.nn.Dropout(drop_rate)
+        self.head = mint.nn.Linear(embed_dim, num_classes) if num_classes > 0 else mint.nn.Identity()

         if weight_init:
             self._init_weights()
@@ -333,7 +331,13 @@ def _init_weights(self):
             w_value.init_data()
             w.set_data(w_value.reshape(w.shape))
         for _, cell in self.cells_and_names():
-            if isinstance(cell, nn.Dense):
+            # if isinstance(cell, mint.nn.Conv2d):
+            #     cell.weight.set_data(
+            #         initializer(TruncatedNormal(sigma=0.02), cell.weight.shape, cell.weight.dtype)
+            #     )
+            #     if isinstance(cell, mint.nn.Linear) and cell.bias is not None:
+            #         cell.bias.set_data(initializer(Zero(), cell.bias.shape, cell.bias.dtype))
+            if isinstance(cell, mint.nn.Linear):
                 cell.weight.set_data(
                     initializer(XavierUniform(), cell.weight.shape, cell.weight.dtype)
                 )
@@ -341,12 +345,12 @@ def _init_weights(self):
                     cell.bias.set_data(
                         initializer('zeros', cell.bias.shape, cell.bias.dtype)
                     )
-            elif isinstance(cell, nn.LayerNorm):
-                cell.gamma.set_data(
-                    initializer('ones', cell.gamma.shape, cell.gamma.dtype)
+            elif isinstance(cell, mint.nn.LayerNorm):
+                cell.weight.set_data(
+                    initializer('ones', cell.weight.shape, cell.weight.dtype)
                 )
-                cell.beta.set_data(
-                    initializer('zeros', cell.beta.shape, cell.beta.dtype)
+                cell.bias.set_data(
+                    initializer('zeros', cell.bias.shape, cell.bias.dtype)
                 )

     def _pos_embed(self, x):
@@ -358,7 +362,7 @@ def _pos_embed(self, x):
                 (H, W),
                 num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
             )
-            x = ops.reshape(x, (B, -1, C))
+            x = mint.reshape(x, (B, -1, C))
         else:
             pos_embed = self.pos_embed

@@ -367,16 +371,16 @@ def _pos_embed(self, x):
             # position embedding does not overlap with class token, add then concat
             x = x + pos_embed
             if self.cls_token is not None:
-                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
+                cls_tokens = mint.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                 cls_tokens = cls_tokens.astype(x.dtype)
-                x = ops.concat((cls_tokens, x), axis=1)
+                x = mint.concat((cls_tokens, x), dim=1)
         else:
             # original timm, JAX, and deit vit impl
             # pos_embed has entry for class token, concat then add
             if self.cls_token is not None:
-                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
+                cls_tokens = mint.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                 cls_tokens = cls_tokens.astype(x.dtype)
-                x = ops.concat((cls_tokens, x), axis=1)
+                x = mint.concat((cls_tokens, x), dim=1)
             x = x + pos_embed
         return self.pos_drop(x)
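The vit.py prefix-token path keeps its semantics: mint.broadcast_to and mint.concat mirror the removed ops calls, only the axis keyword becomes dim. Minimal sketch of the concat-then-add branch (shapes are illustrative):

import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor, mint

batch, num_patches, embed_dim = 2, 196, 192
cls_token = Parameter(Tensor(np.zeros((1, 1, embed_dim)), ms.float32))
x = Tensor(np.random.randn(batch, num_patches, embed_dim), ms.float32)

cls_tokens = mint.broadcast_to(cls_token, (x.shape[0], -1, -1)).astype(x.dtype)
x = mint.concat((cls_tokens, x), dim=1)
assert x.shape == (batch, num_patches + 1, embed_dim)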
diff --git a/mindcv/models/volo.py b/mindcv/models/volo.py
index 536f9f2c4..e237573c2 100644
--- a/mindcv/models/volo.py
+++ b/mindcv/models/volo.py
@@ -6,14 +6,16 @@
 import mindspore as ms
 import mindspore.common.initializer as init
+import mindspore.mint as mint
+import mindspore.mint.nn.functional as F
 import mindspore.nn as nn
 from mindspore import Parameter, Tensor
 from mindspore import dtype as mstype
 from mindspore import ops

 from .helpers import load_pretrained
-from .layers.compatibility import Dropout
 from .layers.drop_path import DropPath
+from .layers.extend_bmm import ExtendBatchMatMul
 from .layers.identity import Identity
 from .registry import register_model

@@ -84,6 +86,7 @@ def int2tuple(a):
                 init_weight[i, 0, x, y] = 1

         self.weight = ms.Tensor(init_weight, ms.float16)
+        # todo: ops.Conv2DTranspose is kept here for now since mint has no equivalent yet
         self.conv_transpose2d = ops.Conv2DTranspose(
             self.ck, self.kernel_size, pad_mode="pad",
             pad=(self.padding[0], self.padding[0], self.padding[1], self.padding[1]),
@@ -96,7 +99,7 @@ def construct(self, x: Tensor) -> Tensor:
         # assert l == self.h * self.w
         # print("construct-b", b, "construct-ck", ck, "construct-l", l)
         # print("self.h", self.h, "self.w", self.w)
-        x = ops.reshape(x, (b, ck, self.h, self.w))
+        x = mint.reshape(x, (b, ck, self.h, self.w))
         out = self.conv_transpose2d(x, self.weight, (b, self.c, self.output_size[0], self.output_size[1]))
         return out

@@ -131,45 +134,44 @@ def __init__(
         self.stride = stride
         self.scale = qk_scale or head_dim**-0.5

-        self.v = nn.Dense(dim, dim, has_bias=qkv_bias)
-        self.attn = nn.Dense(dim, kernel_size**4 * num_heads)
+        self.v = mint.nn.Linear(dim, dim, bias=qkv_bias)
+        self.attn = mint.nn.Linear(dim, kernel_size**4 * num_heads)

-        self.attn_drop = Dropout(p=attn_drop)
-        self.proj = nn.Dense(dim, dim)
-        self.proj_drop = Dropout(p=proj_drop)
+        self.attn_drop = mint.nn.Dropout(p=attn_drop)
+        self.proj = mint.nn.Linear(dim, dim)
+        self.proj_drop = mint.nn.Dropout(p=proj_drop)

-        self.unfold = nn.Unfold(ksizes=[1, kernel_size, kernel_size, 1], strides=[1, stride, stride, 1],
-                                rates=[1, 1, 1, 1])
-        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
-        self.softmax = nn.Softmax(axis=-1)
-        self.batch_mat_mul = ops.BatchMatMul()
+        self.unfold = mint.nn.Unfold(kernel_size=kernel_size, stride=stride, dilation=1, padding=0)
+        self.pool = mint.nn.AvgPool2d(kernel_size=stride, stride=stride)
+        self.softmax = mint.nn.Softmax(dim=-1)
+        self.batch_mat_mul = ExtendBatchMatMul()

     def construct(self, x: Tensor) -> Tensor:
         B, H, W, C = x.shape
-        v = ops.transpose(self.v(x), (0, 3, 1, 2))  # B, C, H, W
+        v = mint.permute(self.v(x), (0, 3, 1, 2))  # B, C, H, W

         h = int((H - 1) / self.stride + 1)
         w = int((W - 1) / self.stride + 1)
-        v = ops.pad(v, (1, 1, 1, 1))
+        v = F.pad(v, (1, 1, 1, 1))
         v = self.unfold(v)
-        v = ops.reshape(v, (B, self.num_heads, C // self.num_heads, self.kernel_size * self.kernel_size, h * w))
-        v = ops.transpose(v, (0, 1, 4, 3, 2))  # B,H,N,kxk,C/H
-
-        attn = self.pool(ops.transpose(x, (0, 3, 1, 2)))
-        attn = ops.transpose(attn, (0, 2, 3, 1))
-        attn = ops.reshape(self.attn(attn), (B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
-                                             self.kernel_size * self.kernel_size))
-        attn = ops.transpose(attn, (0, 2, 1, 3, 4))  # B,H,N,kxk,kxk
+        v = mint.reshape(v, (B, self.num_heads, C // self.num_heads, self.kernel_size * self.kernel_size, h * w))
+        v = mint.permute(v, (0, 1, 4, 3, 2))  # B,H,N,kxk,C/H
+
+        attn = self.pool(mint.permute(x, (0, 3, 1, 2)))
+        attn = mint.permute(attn, (0, 2, 3, 1))
+        attn = mint.reshape(self.attn(attn), (B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
+                                              self.kernel_size * self.kernel_size))
+        attn = mint.permute(attn, (0, 2, 1, 3, 4))  # B,H,N,kxk,kxk
         attn = attn * self.scale
         attn = self.softmax(attn)
         attn = self.attn_drop(attn)

-        x = ops.transpose(self.batch_mat_mul(attn, v), (0, 1, 4, 3, 2))
-        x = ops.reshape(x, (B, C * self.kernel_size * self.kernel_size, h * w))
+        x = mint.permute(self.batch_mat_mul(attn, v), (0, 1, 4, 3, 2))
+        x = mint.reshape(x, (B, C * self.kernel_size * self.kernel_size, h * w))
         fold = Fold(C, (H, W), self.kernel_size, padding=self.padding, stride=self.stride)
         x = fold(x)

-        x = self.proj(ops.transpose(x, (0, 2, 3, 1)))
+        x = self.proj(mint.permute(x, (0, 2, 3, 1)))
         x = self.proj_drop(x)

         return x
@@ -195,8 +197,8 @@ def __init__(
         mlp_ratio=3.,
         attn_drop=0.0,
         drop_path=0.0,
-        act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        act_layer=mint.nn.GELU,
+        norm_layer=mint.nn.LayerNorm,
         qkv_bias=False,
         qk_scale=None,
     ) -> None:
@@ -230,16 +232,16 @@ def __init__(
         in_features,
         hidden_features=None,
         out_features=None,
-        act_layer=nn.GELU,
+        act_layer=mint.nn.GELU,
         drop=0.0,
     ) -> None:
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = nn.Dense(in_features, hidden_features)
+        self.fc1 = mint.nn.Linear(in_features, hidden_features)
         self.act = act_layer()
-        self.fc2 = nn.Dense(hidden_features, out_features)
-        self.drop = Dropout(p=drop)
+        self.fc2 = mint.nn.Linear(hidden_features, out_features)
+        self.drop = mint.nn.Dropout(p=drop)

     def construct(self, x: Tensor) -> Tensor:
         x = self.fc1(x)
@@ -267,28 +269,28 @@ def __init__(
         head_dim = dim // num_heads
         self.scale = qk_scale or head_dim**-0.5

-        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
-        self.attn_drop = Dropout(p=attn_drop)
-        self.proj = nn.Dense(dim, dim)
-        self.proj_drop = Dropout(p=proj_drop)
-        self.softmax = nn.Softmax(axis=-1)
-        self.batch_mat_mul_transpose = ops.BatchMatMul(transpose_b=True)
-        self.batch_mat_mul = ops.BatchMatMul()
+        self.qkv = mint.nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = mint.nn.Dropout(p=attn_drop)
+        self.proj = mint.nn.Linear(dim, dim)
+        self.proj_drop = mint.nn.Dropout(p=proj_drop)
+        self.softmax = mint.nn.Softmax(dim=-1)
+        self.batch_mat_mul_transpose = ExtendBatchMatMul(transpose_b=True)
+        self.batch_mat_mul = ExtendBatchMatMul()

     def construct(self, x: Tensor) -> Tensor:
         B, H, W, C = x.shape

         qkv = self.qkv(x)
-        qkv = ops.reshape(qkv, (B, H * W, 3, self.num_heads, C // self.num_heads))
-        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
+        qkv = mint.reshape(qkv, (B, H * W, 3, self.num_heads, C // self.num_heads))
+        qkv = mint.permute(qkv, (2, 0, 3, 1, 4))
         q, k, v = qkv[0], qkv[1], qkv[2]

         attn = self.batch_mat_mul_transpose(q, k) * self.scale
         attn = self.softmax(attn)
         attn = self.attn_drop(attn)

-        x = ops.transpose(self.batch_mat_mul(attn, v), (0, 2, 1, 3))
-        x = ops.reshape(x, (B, H, W, C))
+        x = mint.permute(self.batch_mat_mul(attn, v), (0, 2, 1, 3))
+        x = mint.reshape(x, (B, H, W, C))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -308,8 +310,8 @@ def __init__(
         qk_scale=None,
         attn_drop=0.0,
         drop_path=0.0,
-        act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        act_layer=mint.nn.GELU,
+        norm_layer=mint.nn.LayerNorm,
     ) -> None:
         super().__init__()
         self.norm1 = norm_layer([dim])
@@ -356,31 +358,31 @@ def __init__(
         self.head_dim = head_dim
         self.scale = qk_scale or head_dim**-0.5

-        self.kv = nn.Dense(dim, self.head_dim * self.num_heads * 2, has_bias=qkv_bias)
-        self.q = nn.Dense(dim, self.head_dim * self.num_heads, has_bias=qkv_bias)
-        self.attn_drop = Dropout(p=attn_drop)
-        self.proj = nn.Dense(self.head_dim * self.num_heads, dim)
-        self.proj_drop = Dropout(p=proj_drop)
-        self.batch_mat_mul_transpose = ops.BatchMatMul(transpose_b=True)
-        self.batch_mat_mul = ops.BatchMatMul()
-        self.softmax = nn.Softmax(axis=-1)
+        self.kv = mint.nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias)
+        self.q = mint.nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias)
+        self.attn_drop = mint.nn.Dropout(p=attn_drop)
+        self.proj = mint.nn.Linear(self.head_dim * self.num_heads, dim)
+        self.proj_drop = mint.nn.Dropout(p=proj_drop)
+        self.batch_mat_mul_transpose = ExtendBatchMatMul(transpose_b=True)
+        self.batch_mat_mul = ExtendBatchMatMul()
+        self.softmax = mint.nn.Softmax(dim=-1)

     def construct(self, x: Tensor) -> Tensor:
         B, N, C = x.shape

         kv = self.kv(x)
-        kv = ops.reshape(kv, (B, N, 2, self.num_heads,
-                              self.head_dim))
-        kv = ops.transpose(kv, (2, 0, 3, 1, 4))
+        kv = mint.reshape(kv, (B, N, 2, self.num_heads,
+                               self.head_dim))
+        kv = mint.permute(kv, (2, 0, 3, 1, 4))
         k, v = kv[0], kv[1]

         q = self.q(x[:, :1, :])
-        q = ops.reshape(q, (B, self.num_heads, 1, self.head_dim))
+        q = mint.reshape(q, (B, self.num_heads, 1, self.head_dim))
         attn = self.batch_mat_mul_transpose(q * self.scale, k)

         attn = self.softmax(attn)
         attn = self.attn_drop(attn)

-        cls_embed = ops.transpose(self.batch_mat_mul(attn, v), (0, 2, 1, 3))
-        cls_embed = ops.reshape(cls_embed, (B, 1, self.head_dim * self.num_heads))
+        cls_embed = mint.permute(self.batch_mat_mul(attn, v), (0, 2, 1, 3))
+        cls_embed = mint.reshape(cls_embed, (B, 1, self.head_dim * self.num_heads))
         cls_embed = self.proj(cls_embed)
         cls_embed = self.proj_drop(cls_embed)
         return cls_embed
@@ -403,8 +405,8 @@ def __init__(
         drop=0.0,
         attn_drop=0.0,
         drop_path=0.0,
-        act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        act_layer=mint.nn.GELU,
+        norm_layer=mint.nn.LayerNorm,
     ) -> None:
         super().__init__()
         self.norm1 = norm_layer([dim])
@@ -425,7 +427,7 @@ def construct(self, x: Tensor) -> Tensor:
         cls_embed = x[:, :1]
         cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x)))
         cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed)))
-        x = ops.concat([cls_embed, x[:, 1:]], 1)
+        x = mint.concat([cls_embed, x[:, 1:]], 1)
         return x

@@ -459,24 +461,20 @@ def __init__(
         self.stem_conv = stem_conv
         if stem_conv:
             self.conv = nn.SequentialCell(
-                nn.Conv2d(in_channels, hidden_dim, 7, stem_stride,
-                          pad_mode='pad', padding=3),  # 112x112
-                nn.BatchNorm2d(hidden_dim),
-                nn.ReLU(),
-                nn.Conv2d(hidden_dim, hidden_dim, 3, 1,
-                          pad_mode='pad', padding=1),  # 112x112
-                nn.BatchNorm2d(hidden_dim),
-                nn.ReLU(),
-                nn.Conv2d(hidden_dim, hidden_dim, 3, 1,
-                          pad_mode='pad', padding=1),  # 112x112
-                nn.BatchNorm2d(hidden_dim),
-                nn.ReLU(),
+                mint.nn.Conv2d(in_channels, hidden_dim, 7, stem_stride, padding=3, bias=False),  # 112x112
+                mint.nn.BatchNorm2d(hidden_dim),
+                mint.nn.ReLU(),
+                mint.nn.Conv2d(hidden_dim, hidden_dim, 3, 1, padding=1, bias=False),  # 112x112
+                mint.nn.BatchNorm2d(hidden_dim),
+                mint.nn.ReLU(),
+                mint.nn.Conv2d(hidden_dim, hidden_dim, 3, 1, padding=1, bias=False),  # 112x112
+                mint.nn.BatchNorm2d(hidden_dim),
+                mint.nn.ReLU(),
             )
-
-        self.proj = nn.Conv2d(hidden_dim,
-                              embed_dim,
-                              kernel_size=patch_size // stem_stride,
-                              stride=patch_size // stem_stride, has_bias=True)
+        else:
+            self.conv = None
+        self.proj = mint.nn.Conv2d(
+            hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride)
         self.num_patches = (img_size // patch_size) * (img_size // patch_size)

     def construct(self, x: Tensor) -> Tensor:
@@ -492,13 +490,12 @@ class Downsample(nn.Cell):
     """
     def __init__(self, in_embed_dim, out_embed_dim, patch_size,) -> None:
         super().__init__()
-        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim,
-                              kernel_size=patch_size, stride=patch_size, has_bias=True)
+        self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

     def construct(self, x: Tensor) -> Tensor:
-        x = ops.transpose(x, (0, 3, 1, 2))
+        x = mint.permute(x, (0, 3, 1, 2))
         x = self.proj(x)  # B, C, H, W
-        x = ops.transpose(x, (0, 2, 3, 1))
+        x = mint.permute(x, (0, 2, 3, 1))
         return x

@@ -590,7 +587,7 @@ def __init__(
         drop_rate=0.0,
         attn_drop_rate=0.0,
         drop_path_rate=0.0,
-        norm_layer=nn.LayerNorm,
+        norm_layer=mint.nn.LayerNorm,
         post_layers=None,
         return_mean=False,
         return_dense=True,
@@ -608,11 +605,11 @@ def __init__(
                                       embed_dim=embed_dims[0])
         # inital positional encoding, we add positional encoding after outlooker blocks
         self.pos_embed = Parameter(
-            ops.zeros((1, img_size // patch_size // pooling_scale,
-                       img_size // patch_size // pooling_scale,
-                       embed_dims[-1]), mstype.float32))
+            mint.zeros((1, img_size // patch_size // pooling_scale,
+                        img_size // patch_size // pooling_scale,
+                        embed_dims[-1]), dtype=mstype.float32))

-        self.pos_drop = Dropout(p=drop_rate)
+        self.pos_drop = mint.nn.Dropout(p=drop_rate)

         # set the main block in network
         network = []
@@ -657,7 +654,7 @@ def __init__(
                     norm_layer=norm_layer)
                 for i in range(len(post_layers))
             ])
-            self.cls_token = Parameter(ops.zeros((1, 1, embed_dims[-1]), mstype.float32))
+            self.cls_token = Parameter(mint.zeros((1, 1, embed_dims[-1]), dtype=mstype.float32))
             self.cls_token.set_data(init.initializer(init.TruncatedNormal(sigma=.02), self.cls_token.data.shape))

         # set output type
@@ -671,13 +668,13 @@ def __init__(
             self.beta = 1.0
             assert return_dense, "return all tokens if mix_token is enabled"
         if return_dense:
-            self.aux_head = nn.Dense(
+            self.aux_head = mint.nn.Linear(
                 embed_dims[-1],
                 num_classes) if num_classes > 0 else Identity()
         self.norm = norm_layer([embed_dims[-1]])

         # Classifier head
-        self.head = nn.Dense(
+        self.head = mint.nn.Linear(
             embed_dims[-1], num_classes) if num_classes > 0 else Identity()

         self.pos_embed.set_data(init.initializer(init.TruncatedNormal(sigma=.02), self.pos_embed.data.shape))
@@ -685,19 +682,19 @@ def __init__(

     def _init_weights(self) -> None:
         for name, m in self.cells_and_names():
-            if isinstance(m, nn.Dense):
+            if isinstance(m, mint.nn.Linear):
                 m.weight.set_data(init.initializer(init.TruncatedNormal(sigma=.02), m.weight.data.shape))
                 if m.bias is not None:
                     m.bias.set_data(init.initializer(init.Constant(0), m.bias.shape))
-            elif isinstance(m, nn.LayerNorm):
-                m.gamma.set_data(init.initializer(init.Constant(1), m.gamma.shape))
-                m.beta.set_data(init.initializer(init.Constant(0), m.beta.shape))
+            elif isinstance(m, mint.nn.LayerNorm):
+                m.weight.set_data(init.initializer(init.Constant(1), m.weight.shape))
+                m.bias.set_data(init.initializer(init.Constant(0), m.bias.shape))

     def forward_embeddings(self, x: Tensor) -> Tensor:
         # patch embedding
         x = self.patch_embed(x)
         # B,C,H,W-> B,H,W,C
-        x = ops.transpose(x, (0, 2, 3, 1))
+        x = mint.permute(x, (0, 2, 3, 1))
         return x

     def forward_tokens(self, x: Tensor) -> Tensor:
@@ -708,14 +705,14 @@ def forward_tokens(self, x: Tensor) -> Tensor:
             x = block(x)

         B, H, W, C = x.shape
-        x = ops.reshape(x, (B, -1, C))
+        x = mint.reshape(x, (B, -1, C))
         return x

     def forward_cls(self, x: Tensor) -> Tensor:
         # B, N, C = x.shape
-        cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
-        x = ops.Cast()(x, cls_tokens.dtype)
-        x = ops.concat([cls_tokens, x], 1)
+        cls_tokens = mint.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
+        x = x.to(cls_tokens.dtype)
+        x = mint.concat([cls_tokens, x], 1)
         for block in self.post_network:
             x = block(x)
         return x
@@ -733,7 +730,7 @@ def construct(self, x: Tensor) -> Tensor:
         x = self.norm(x)
         if self.return_mean:  # if no class token, return mean
-            return self.head(ops.mean(x, 1))
+            return self.head(mint.mean(x, 1))

         x_cls = self.head(x[:, 0])
         if not self.return_dense:
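The OutlookAttention change above assumes mint.nn.Unfold returns the torch-style (B, C*k*k, L) layout (the old nn.Unfold produced NCHW-shaped patch tensors), which is what the new reshape relies on. Layout sketch under that assumption:

import numpy as np
import mindspore as ms
from mindspore import Tensor, mint

B, C, H, W = 2, 8, 14, 14
kernel_size, stride, num_heads = 3, 2, 2
v = Tensor(np.random.randn(B, C, H + 2, W + 2), ms.float32)  # input already padded by 1 on each side

h = int((H - 1) / stride + 1)
w = int((W - 1) / stride + 1)
unfold = mint.nn.Unfold(kernel_size=kernel_size, stride=stride, dilation=1, padding=0)
patches = unfold(v)                                  # expected (B, C*k*k, h*w)
assert patches.shape == (B, C * kernel_size * kernel_size, h * w)
patches = mint.reshape(patches, (B, num_heads, C // num_heads, kernel_size * kernel_size, h * w))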
diff --git a/mindcv/models/xception.py b/mindcv/models/xception.py
index 445cbcde4..277759240 100644
--- a/mindcv/models/xception.py
+++ b/mindcv/models/xception.py
@@ -4,11 +4,10 @@
 """

 import mindspore.common.initializer as init
-from mindspore import Tensor, nn, ops
+from mindspore import Tensor, mint, nn

 from .helpers import load_pretrained
 from .layers import GlobalAvgPooling
-from .layers.compatibility import Dropout
 from .registry import register_model

 __all__ = [
@@ -44,9 +43,10 @@ def __init__(
         padding: int = 0,
     ):
         super().__init__()
-        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, group=in_channels, pad_mode="pad",
-                               padding=padding)
-        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, pad_mode="valid")
+        self.conv1 = mint.nn.Conv2d(
+            in_channels, in_channels, kernel_size, stride, groups=in_channels, padding=padding, bias=False
+        )
+        self.pointwise = mint.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)

     def construct(self, x):
         x = self.conv1(x)
@@ -69,37 +69,37 @@ def __init__(
         super().__init__()

         if out_filters != in_filters or strides != 1:
-            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, pad_mode="valid", has_bias=False)
-            self.skipbn = nn.BatchNorm2d(out_filters)
+            self.skip = mint.nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
+            self.skipbn = mint.nn.BatchNorm2d(out_filters)
         else:
             self.skip = None

-        self.relu = nn.ReLU()
+        self.relu = mint.nn.ReLU()
         rep = []

         filters = in_filters
         if grow_first:
-            rep.append(nn.ReLU())
+            rep.append(mint.nn.ReLU())
             rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
-            rep.append(nn.BatchNorm2d(out_filters))
+            rep.append(mint.nn.BatchNorm2d(out_filters))
             filters = out_filters

         for _ in range(reps - 1):
-            rep.append(nn.ReLU())
+            rep.append(mint.nn.ReLU())
             rep.append(SeparableConv2d(filters, filters, kernel_size=3, stride=1, padding=1))
-            rep.append(nn.BatchNorm2d(filters))
+            rep.append(mint.nn.BatchNorm2d(filters))

         if not grow_first:
-            rep.append(nn.ReLU())
+            rep.append(mint.nn.ReLU())
             rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
-            rep.append(nn.BatchNorm2d(out_filters))
+            rep.append(mint.nn.BatchNorm2d(out_filters))

         if not start_with_relu:
             rep = rep[1:]
         else:
-            rep[0] = nn.ReLU()
+            rep[0] = mint.nn.ReLU()

         if strides != 1:
-            rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
+            rep.append(mint.nn.MaxPool2d(3, strides, 1))
         self.rep = nn.SequentialCell(*rep)

     def construct(self, inp):
@@ -110,7 +110,7 @@ def construct(self, inp):
             skip = self.skipbn(skip)
         else:
             skip = inp
-        x = ops.add(x, skip)
+        x = mint.add(x, skip)
         return x

@@ -131,11 +131,11 @@ def __init__(
         super().__init__()
         self.num_classes = num_classes
         blocks = []
-        self.conv1 = nn.Conv2d(in_channels, 32, 3, 2, pad_mode="valid")
-        self.bn1 = nn.BatchNorm2d(32)
-        self.relu = nn.ReLU()
-        self.conv2 = nn.Conv2d(32, 64, 3, pad_mode="valid")
-        self.bn2 = nn.BatchNorm2d(64)
+        self.conv1 = mint.nn.Conv2d(in_channels, 32, 3, 2, bias=False)
+        self.bn1 = mint.nn.BatchNorm2d(32)
+        self.relu = mint.nn.ReLU()
+        self.conv2 = mint.nn.Conv2d(32, 64, 3, bias=False)
+        self.bn2 = mint.nn.BatchNorm2d(64)

         # Entry flow
         blocks.append(Block(64, 128, 2, 2, start_with_relu=False, grow_first=True))
@@ -152,13 +152,13 @@ def __init__(
         self.blocks = nn.SequentialCell(blocks)

         self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
-        self.bn3 = nn.BatchNorm2d(1536)
+        self.bn3 = mint.nn.BatchNorm2d(1536)

         self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
-        self.bn4 = nn.BatchNorm2d(2048)
+        self.bn4 = mint.nn.BatchNorm2d(2048)
         self.pool = GlobalAvgPooling()
-        self.dropout = Dropout(p=0.5)
-        self.classifier = nn.Dense(2048, num_classes)
+        self.dropout = mint.nn.Dropout(p=0.5)
+        self.classifier = mint.nn.Linear(2048, num_classes)

         self._initialize_weights()

@@ -193,12 +193,13 @@ def construct(self, x: Tensor) -> Tensor:
     def _initialize_weights(self) -> None:
         """Initialize weights for cells."""
         for _, cell in self.cells_and_names():
-            if isinstance(cell, nn.Conv2d):
+            if isinstance(cell, mint.nn.Conv2d):
                 cell.weight.set_data(init.initializer(init.XavierUniform(), cell.weight.shape, cell.weight.dtype))
-            elif isinstance(cell, nn.Dense):
-                cell.weight.set_data(init.initializer(init.Normal(0.01, 0), cell.weight.shape, cell.weight.dtype))
+            elif isinstance(cell, mint.nn.Linear):
+                cell.weight.set_data(
+                    init.initializer(init.Normal(0.01, 0), cell.weight.shape, cell.weight.dtype))
                 if cell.bias is not None:
-                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.weight.dtype))
+                    cell.bias.set_data(init.initializer(init.Constant(0), cell.bias.shape, cell.bias.dtype))

 @register_model
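In xception.py the downsampling pool moves from pad_mode="same" to an explicit padding of 1. For the even feature maps Xception produces (kernel 3, stride 2) the output sizes match, though border values can differ slightly because "same" pads asymmetrically. Shape check (sketch):

import numpy as np
import mindspore as ms
from mindspore import Tensor, mint, nn

x = Tensor(np.random.randn(1, 128, 74, 74), ms.float32)
old_pool = nn.MaxPool2d(3, 2, pad_mode="same")
new_pool = mint.nn.MaxPool2d(3, 2, 1)
assert old_pool(x).shape == new_pool(x).shape == (1, 128, 37, 37)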
diff --git a/mindcv/models/xcit.py b/mindcv/models/xcit.py
index 2c3b6966e..d36e7860b 100644
--- a/mindcv/models/xcit.py
+++ b/mindcv/models/xcit.py
@@ -10,11 +10,11 @@
 import mindspore.common.initializer as weight_init
 from mindspore import Parameter, Tensor
 from mindspore import dtype as mstype
-from mindspore import nn, numpy, ops
+from mindspore import mint, nn, numpy, ops

 from .helpers import _ntuple, load_pretrained
-from .layers.compatibility import Dropout
 from .layers.drop_path import DropPath
+from .layers.extend_bmm import ExtendBatchMatMul
 from .layers.mlp import Mlp
 from .registry import register_model

@@ -55,8 +55,8 @@ def __init__(self,
                  temperature=10000
                  ) -> None:
         super().__init__()
-        self.token_projection = nn.Conv2d(
-            hidden_dim * 2, dim, kernel_size=1, has_bias=True)
+        self.token_projection = mint.nn.Conv2d(
+            hidden_dim * 2, dim, kernel_size=1, bias=True)
         self.scale = 2 * np.pi
         self.temperature = temperature
         self.hidden_dim = hidden_dim
@@ -76,15 +76,15 @@ def construct(self, B, H, W) -> Tensor:
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
-        pos_x = ops.stack((ops.sin(pos_x[:, :, :, 0::2]),
-                           ops.cos(pos_x[:, :, :, 1::2])), 4)
+        pos_x = mint.stack((mint.sin(pos_x[:, :, :, 0::2]),
+                            mint.cos(pos_x[:, :, :, 1::2])), 4)
         x1, x2, x3, x4, x5 = pos_x.shape
-        pos_x = ops.reshape(pos_x, (x1, x2, x3, x4 * x5))
-        pos_y = ops.stack((ops.sin(pos_y[:, :, :, 0::2]),
-                           ops.cos(pos_y[:, :, :, 1::2])), 4)
+        pos_x = mint.reshape(pos_x, (x1, x2, x3, x4 * x5))
+        pos_y = mint.stack((mint.sin(pos_y[:, :, :, 0::2]),
+                            mint.cos(pos_y[:, :, :, 1::2])), 4)
         y1, y2, y3, y4, y5 = pos_y.shape
-        pos_y = ops.reshape(pos_y, (y1, y2, y3, y4 * y5))
-        pos = ops.transpose(ops.concat((pos_y, pos_x), 3), (0, 3, 1, 2))
+        pos_y = mint.reshape(pos_y, (y1, y2, y3, y4 * y5))
+        pos = mint.permute(mint.concat((pos_y, pos_x), 3), (0, 3, 1, 2))
         pos = self.token_projection(pos)
         return pos

@@ -92,10 +92,10 @@ def construct(self, B, H, W) -> Tensor:
 def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
     return nn.SequentialCell([
-        nn.Conv2d(
-            in_planes, out_planes, kernel_size=3, stride=stride, padding=1, pad_mode='pad', has_bias=False
-        ),
-        nn.BatchNorm2d(out_planes)
+        mint.nn.Conv2d(
+            in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+        ),
+        mint.nn.BatchNorm2d(out_planes)
     ])

@@ -121,19 +121,19 @@ def __init__(self,
         if patch_size[0] == 16:
             self.proj = nn.SequentialCell([
                 conv3x3(3, embed_dim // 8, 2),
-                nn.GELU(),
+                mint.nn.GELU(),
                 conv3x3(embed_dim // 8, embed_dim // 4, 2),
-                nn.GELU(),
+                mint.nn.GELU(),
                 conv3x3(embed_dim // 4, embed_dim // 2, 2),
-                nn.GELU(),
+                mint.nn.GELU(),
                 conv3x3(embed_dim // 2, embed_dim, 2),
             ])
         elif patch_size[0] == 8:
             self.proj = nn.SequentialCell([
                 conv3x3(3, embed_dim // 4, 2),
-                nn.GELU(),
+                mint.nn.GELU(),
                 conv3x3(embed_dim // 4, embed_dim // 2, 2),
-                nn.GELU(),
+                mint.nn.GELU(),
                 conv3x3(embed_dim // 2, embed_dim, 2),
             ])
         else:
@@ -143,8 +143,8 @@ def __init__(self,
     def construct(self, x, padding_size=None) -> Tensor:
         x = self.proj(x)
         B, C, Hp, Wp = x.shape
-        x = ops.reshape(x, (B, C, Hp * Wp))
-        x = x.transpose(0, 2, 1)
+        x = mint.reshape(x, (B, C, Hp * Wp))
+        x = mint.permute(x, (0, 2, 1))
         return x, (Hp, Wp)

@@ -156,28 +156,30 @@ class LPI(nn.Cell):
     Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d
     """

-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=mint.nn.GELU,
                  drop=0., kernel_size=3) -> None:
         super().__init__()
         out_features = out_features or in_features

         padding = kernel_size // 2

-        self.conv1 = nn.Conv2d(in_features, out_features, kernel_size=kernel_size,
-                               padding=padding, pad_mode='pad', group=out_features, has_bias=True)
+        self.conv1 = mint.nn.Conv2d(
+            in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features, bias=True
+        )
         self.act = act_layer()
-        self.bn = nn.BatchNorm2d(in_features)
-        self.conv2 = nn.Conv2d(in_features, out_features, kernel_size=kernel_size,
-                               padding=padding, pad_mode='pad', group=out_features, has_bias=True)
+        self.bn = mint.nn.BatchNorm2d(in_features)
+        self.conv2 = mint.nn.Conv2d(
+            in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features, bias=True
+        )

     def construct(self, x, H, W) -> Tensor:
         B, N, C = x.shape
-        x = ops.reshape(ops.transpose(x, (0, 2, 1)), (B, C, H, W))
+        x = mint.reshape(mint.permute(x, (0, 2, 1)), (B, C, H, W))
         x = self.conv1(x)
         x = self.act(x)
         x = self.bn(x)
         x = self.conv2(x)
-        x = ops.transpose(ops.reshape(x, (B, C, N)), (0, 2, 1))
+        x = mint.permute(mint.reshape(x, (B, C, N)), (0, 2, 1))
         return x

@@ -192,33 +194,33 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
         head_dim = dim // num_heads
         self.scale = qk_scale or head_dim ** -0.5

-        self.qkv = nn.Dense(
-            in_channels=dim, out_channels=dim * 3, has_bias=qkv_bias)
-        self.attn_drop = Dropout(p=attn_drop)
-        self.proj = nn.Dense(in_channels=dim, out_channels=dim)
-        self.proj_drop = Dropout(p=proj_drop)
-        self.softmax = nn.Softmax(axis=-1)
+        self.qkv = mint.nn.Linear(
+            in_features=dim, out_features=dim * 3, bias=qkv_bias)
+        self.attn_drop = mint.nn.Dropout(p=attn_drop)
+        self.proj = mint.nn.Linear(in_features=dim, out_features=dim)
+        self.proj_drop = mint.nn.Dropout(p=proj_drop)
+        self.softmax = mint.nn.Softmax(dim=-1)

-        self.attn_matmul_v = ops.BatchMatMul()
+        self.attn_matmul_v = ExtendBatchMatMul()

     def construct(self, x: Tensor) -> Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)
-        qkv = ops.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads))
-        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
+        qkv = mint.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads))
+        qkv = mint.permute(qkv, (2, 0, 3, 1, 4))
         q, k, v = ops.unstack(qkv, axis=0)
         qc = q[:, :, 0:1]
         attn_cls = (qc * k).sum(-1) * self.scale
         attn_cls = self.softmax(attn_cls)
         attn_cls = self.attn_drop(attn_cls)

-        attn_cls = ops.expand_dims(attn_cls, 2)
+        attn_cls = mint.unsqueeze(attn_cls, 2)
         cls_tkn = self.attn_matmul_v(attn_cls, v)
-        cls_tkn = ops.transpose(cls_tkn, (0, 2, 1, 3))
-        cls_tkn = ops.reshape(cls_tkn, (B, 1, C))
+        cls_tkn = mint.permute(cls_tkn, (0, 2, 1, 3))
+        cls_tkn = mint.reshape(cls_tkn, (B, 1, C))
         cls_tkn = self.proj(cls_tkn)
-        x = ops.concat((self.proj_drop(cls_tkn), x[:, 1:]), axis=1)
+        x = mint.concat((self.proj_drop(cls_tkn), x[:, 1:]), dim=1)
         return x

@@ -227,7 +229,7 @@ class ClassAttentionBlock(nn.Cell):
     """

     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
-                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=None,
+                 attn_drop=0., drop_path=0., act_layer=mint.nn.GELU, norm_layer=mint.nn.LayerNorm, eta=None,
                  tokens_norm=False):
         super().__init__()
         self.norm1 = norm_layer([dim])
@@ -237,7 +239,7 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
         )

         self.drop_path = DropPath(
-            drop_path) if drop_path > 0. else ops.Identity()
+            drop_path) if drop_path > 0. else mint.nn.Identity()
         self.norm2 = norm_layer([dim])
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
@@ -246,9 +248,9 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
         # LayerScale Initialization (no layerscale when None)
         if eta is not None:
             self.gamma1 = Parameter(
-                eta * ops.Ones()((dim), mstype.float32), requires_grad=True)
+                eta * mint.ones((dim), dtype=mstype.float32), requires_grad=True)
             self.gamma2 = Parameter(
-                eta * ops.Ones()((dim), mstype.float32), requires_grad=True)
+                eta * mint.ones((dim), dtype=mstype.float32), requires_grad=True)
         else:
             self.gamma1, self.gamma2 = 1.0, 1.0

@@ -266,7 +268,7 @@ def construct(self, x, H, W, mask=None):
         x_res = x
         cls_token = x[:, 0:1]
         cls_token = self.gamma2 * self.mlp(cls_token)
-        x = ops.concat((cls_token, x[:, 1:]), axis=1)
+        x = mint.concat((cls_token.to(x.dtype), x[:, 1:]), dim=1)
         x = x_res + self.drop_path(x)
         return x

@@ -282,34 +284,34 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
         super().__init__()
         self.num_heads = num_heads
         self.temperature = Parameter(
-            ops.Ones()((num_heads, 1, 1), mstype.float32))
-        self.qkv = nn.Dense(
-            in_channels=dim, out_channels=dim * 3, has_bias=qkv_bias)
-        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
-        self.softmax = nn.Softmax(axis=-1)
-        self.attn_drop = Dropout(p=attn_drop)
-        self.attn_matmul_v = ops.BatchMatMul()
-        self.proj = nn.Dense(in_channels=dim, out_channels=dim)
-        self.proj_drop = Dropout(p=proj_drop)
+            mint.ones((num_heads, 1, 1), dtype=mstype.float32))
+        self.qkv = mint.nn.Linear(
+            in_features=dim, out_features=dim * 3, bias=qkv_bias)
+        self.q_matmul_k = ExtendBatchMatMul(transpose_b=True)
+        self.softmax = mint.nn.Softmax(dim=-1)
+        self.attn_drop = mint.nn.Dropout(p=attn_drop)
+        self.attn_matmul_v = ExtendBatchMatMul()
+        self.proj = mint.nn.Linear(in_features=dim, out_features=dim)
+        self.proj_drop = mint.nn.Dropout(p=proj_drop)

     def construct(self, x):
         B, N, C = x.shape
-        qkv = ops.reshape(
+        qkv = mint.reshape(
             self.qkv(x), (B, N, 3, self.num_heads, C // self.num_heads))
-        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
+        qkv = mint.permute(qkv, (2, 0, 3, 1, 4))
         q, k, v = ops.unstack(qkv, axis=0)

-        q = ops.transpose(q, (0, 1, 3, 2))
-        k = ops.transpose(k, (0, 1, 3, 2))
-        v = ops.transpose(v, (0, 1, 3, 2))
+        q = mint.permute(q, (0, 1, 3, 2))
+        k = mint.permute(k, (0, 1, 3, 2))
+        v = mint.permute(v, (0, 1, 3, 2))

         attn = self.q_matmul_k(q, k) * self.temperature
         attn = self.softmax(attn)
         attn = self.attn_drop(attn)

         x = self.attn_matmul_v(attn, v)
-        x = ops.transpose(x, (0, 3, 1, 2))
-        x = ops.reshape(x, (B, N, C))
+        x = mint.permute(x, (0, 3, 1, 2))
+        x = mint.reshape(x, (B, N, C))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x

@@ -317,7 +319,7 @@ def construct(self, x):
 class XCABlock(nn.Cell):
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
-                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 attn_drop=0., drop_path=0., act_layer=mint.nn.GELU, norm_layer=mint.nn.LayerNorm,
                  num_tokens=196, eta=None):
         super().__init__()
         self.norm1 = norm_layer([dim])
@@ -326,7 +328,7 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
             proj_drop=drop
         )
         self.drop_path = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
+            drop_path) if drop_path > 0. else mint.nn.Identity()
         self.norm2 = norm_layer([dim])

         mlp_hidden_dim = int(dim * mlp_ratio)
@@ -337,11 +339,11 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
         self.local_mp = LPI(in_features=dim, act_layer=act_layer)

         self.gamma1 = Parameter(
-            eta * ops.ones(dim, mstype.float32), requires_grad=True)
+            eta * mint.ones(dim, dtype=mstype.float32), requires_grad=True)
         self.gamma2 = Parameter(
-            eta * ops.ones(dim, mstype.float32), requires_grad=True)
+            eta * mint.ones(dim, dtype=mstype.float32), requires_grad=True)
         self.gamma3 = Parameter(
-            eta * ops.ones(dim, mstype.float32), requires_grad=True)
+            eta * mint.ones(dim, dtype=mstype.float32), requires_grad=True)

     def construct(self, x, H, W):
         x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
@@ -399,7 +401,7 @@ def __init__(self,
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim

-        norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
+        norm_layer = norm_layer or partial(mint.nn.LayerNorm, eps=1e-6)

         self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim, patch_size=patch_size)

@@ -407,8 +409,8 @@ def __init__(self,
         num_patches = self.patch_embed.num_patches

         self.cls_token = Parameter(
-            ops.zeros((1, 1, embed_dim), mstype.float32))
-        self.pos_drop = Dropout(p=drop_rate)
+            mint.zeros((1, 1, embed_dim), dtype=mstype.float32))
+        self.pos_drop = mint.nn.Dropout(p=drop_rate)

         dpr = [drop_path_rate for i in range(depth)]
         self.blocks = nn.CellList([
@@ -425,8 +427,8 @@ def __init__(self,
                 eta=eta, tokens_norm=tokens_norm)
             for i in range(cls_attn_layers)])
         self.norm = norm_layer([embed_dim])
-        self.head = nn.Dense(
-            in_channels=embed_dim, out_channels=num_classes) if num_classes > 0 else ops.Identity()
+        self.head = mint.nn.Linear(
+            in_features=embed_dim, out_features=num_classes) if num_classes > 0 else mint.nn.Identity()

         self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
         self.use_pos = use_pos
@@ -439,31 +441,30 @@ def __init__(self,
     def _init_weights(self) -> None:
         for name, m in self.cells_and_names():
-            if isinstance(m, nn.Dense):
+            if isinstance(m, mint.nn.Linear):
                 m.weight = weight_init.initializer(weight_init.TruncatedNormal(
                     sigma=0.02), m.weight.shape, mindspore.float32)
                 if m.bias is not None:
                     m.bias.set_data(weight_init.initializer(
                         weight_init.Constant(0), m.bias.shape))
-            elif isinstance(m, nn.LayerNorm):
-                m.beta.set_data(weight_init.initializer(
-                    weight_init.Constant(0), m.beta.shape))
-                m.gamma.set_data(weight_init.initializer(
-                    weight_init.Constant(1), m.gamma.shape))
+            elif isinstance(m, mint.nn.LayerNorm):
+                m.bias.set_data(weight_init.initializer(
+                    weight_init.Constant(0), m.bias.shape))
+                m.weight.set_data(weight_init.initializer(
+                    weight_init.Constant(1), m.weight.shape))

     def forward_features(self, x):
         B, C, H, W = x.shape

         x, (Hp, Wp) = self.patch_embed(x)

         if self.use_pos:
-            pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(
-                B, -1, x.shape[1]).transpose(0, 2, 1)
+            pos_encoding = mint.permute(mint.reshape(
+                self.pos_embeder(B, Hp, Wp), (B, -1, x.shape[1])), (0, 2, 1))
             x = x + pos_encoding

         x = self.pos_drop(x)

         for blk in self.blocks:
             x = blk(x, Hp, Wp)

-        cls_tokens = ops.broadcast_to(self.cls_token, (B, -1, -1))
-        cls_tokens = ops.cast(cls_tokens, x.dtype)
-        x = ops.concat((cls_tokens, x), 1)
+        cls_tokens = mint.broadcast_to(self.cls_token, (B, -1, -1))
+        x = mint.concat((cls_tokens.to(x.dtype), x), 1)

         for blk in self.cls_attn_blocks:
             x = blk(x, Hp, Wp)
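The _init_weights rewrites above rely on mint.nn.LayerNorm/BatchNorm2d exposing their affine parameters as weight/bias instead of gamma/beta (and taking eps instead of epsilon). Checkpoints saved before this migration therefore carry gamma/beta keys; a hypothetical remapping helper, for illustration only (a real one would have to skip non-norm parameters such as GRN.gamma):

def remap_norm_keys(param_dict):
    # Rename old-style norm parameter keys so pre-migration checkpoints still load.
    remapped = {}
    for name, value in param_dict.items():
        new_name = name.replace(".gamma", ".weight").replace(".beta", ".bias")
        remapped[new_name] = value
    return remapped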
@@ -483,7 +484,7 @@ def xcit_tiny_12_p16_224(pretrained: bool = False, num_classes: int = 1000, in_c
     default_cfg = default_cfgs['xcit_tiny_12_p16_224']
     model = XCiT(
         patch_size=16, num_classes=num_classes, embed_dim=192, depth=12, num_heads=4, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), eta=1.0, tokens_norm=True, **kwargs)
+        norm_layer=partial(mint.nn.LayerNorm, eps=1e-6), eta=1.0, tokens_norm=True, **kwargs)
     if pretrained:
         load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels)
diff --git a/mindcv/utils/amp.py b/mindcv/utils/amp.py
index 51677be62..ddbc2c293 100644
--- a/mindcv/utils/amp.py
+++ b/mindcv/utils/amp.py
@@ -1,6 +1,6 @@
 """ auto mixed precision related functions """
 from mindspore import dtype as mstype
-from mindspore import nn
+from mindspore import mint, nn
 from mindspore.ops import functional as F

 AMP_WHITE_LIST = (
@@ -11,6 +11,10 @@
     nn.Conv1dTranspose,
     nn.Conv2dTranspose,
     nn.Conv3dTranspose,
+    mint.nn.Linear,
+    mint.nn.Conv2d,
+    mint.nn.Conv3d,
+    mint.nn.ConvTranspose2d,
 )

 AMP_BLACK_LIST = (
@@ -18,6 +22,10 @@
     nn.BatchNorm2d,
     nn.BatchNorm3d,
     nn.LayerNorm,
+    mint.nn.BatchNorm1d,
+    mint.nn.BatchNorm2d,
+    mint.nn.BatchNorm3d,
+    mint.nn.LayerNorm,
 )
diff --git a/mindcv/utils/trainer_factory.py b/mindcv/utils/trainer_factory.py
index db47a48e6..22565906c 100644
--- a/mindcv/utils/trainer_factory.py
+++ b/mindcv/utils/trainer_factory.py
@@ -2,13 +2,13 @@
 from typing import Optional, Union

 import mindspore as ms
-from mindspore import Tensor, context
+from mindspore import Tensor, amp, context
 from mindspore import dtype as mstype
 from mindspore import nn
 from mindspore.ops import functional as F
 from mindspore.train import DynamicLossScaleManager, FixedLossScaleManager, Model

-from .amp import auto_mixed_precision
+from .amp import AMP_BLACK_LIST, auto_mixed_precision
 from .train_step import TrainStep

 __all__ = [
@@ -121,12 +121,13 @@ def create_trainer(
         raise ValueError("`gradient_accumulation_steps` must be >= 1!")

     if not require_customized_train_step(ema, clip_grad, gradient_accumulation_steps, amp_cast_list):
+        network = amp.custom_mixed_precision(network, black_list=list(AMP_BLACK_LIST))
         mindspore_kwargs = dict(
             network=network,
             loss_fn=loss,
             optimizer=optimizer,
             metrics=metrics,
-            amp_level=amp_level,
+            amp_level="O0",
         )
         if loss_scale_type.lower() == "fixed":
             mindspore_kwargs["loss_scale_manager"] = FixedLossScaleManager(
diff --git a/train.py b/train.py
index 21efe4205..87a1d4484 100644
--- a/train.py
+++ b/train.py
@@ -29,8 +29,8 @@ def main():
     args = parse_args()

     ms.set_context(mode=args.mode)
-    if args.mode == ms.GRAPH_MODE:
-        ms.set_context(jit_config={"jit_level": "O2"})
+    # if args.mode == ms.GRAPH_MODE:
+    #     ms.set_context(jit_config={"jit_level": "O2"})
     if args.distribute:
         init()
         rank_id, device_num = get_rank(), get_group_size()
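create_trainer now converts the network itself with the extended black list and hands Model an already-cast network, which is why amp_level is pinned to "O0" on that path. Minimal usage sketch (toy network; the loss and optimizer are placeholders):

import mindspore as ms
from mindspore import amp, mint, nn
from mindspore.train import Model
from mindcv.utils.amp import AMP_BLACK_LIST

net = nn.SequentialCell([mint.nn.Linear(32, 10)])
net = amp.custom_mixed_precision(net, black_list=list(AMP_BLACK_LIST))  # black-listed cells stay in fp32
model = Model(network=net,
              loss_fn=nn.CrossEntropyLoss(),
              optimizer=nn.Adam(net.trainable_params()),
              amp_level="O0")  # no second AMP pass inside Model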