From a9bc4f4c705c348dc3b5254820c9e00cae4a7f43 Mon Sep 17 00:00:00 2001 From: alien_0119 Date: Fri, 10 Oct 2025 09:34:18 +0800 Subject: [PATCH] add hgnet_v2 --- mindone/transformers/__init__.py | 1 + mindone/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../transformers/models/auto/modeling_auto.py | 3 + .../transformers/models/hgnet_v2/__init__.py | 17 + .../models/hgnet_v2/modeling_hgnet_v2.py | 480 ++++++++++++++++++ .../models/hgnet_v2/__init__.py | 0 .../models/hgnet_v2/test_modeing_hgnet_v2.py | 223 ++++++++ 8 files changed, 727 insertions(+) create mode 100644 mindone/transformers/models/hgnet_v2/__init__.py create mode 100644 mindone/transformers/models/hgnet_v2/modeling_hgnet_v2.py create mode 100644 tests/transformers_tests/models/hgnet_v2/__init__.py create mode 100644 tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 2cda1be9ad..38051f3f73 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -666,6 +666,7 @@ HeliumModel, HeliumPreTrainedModel, ) +from .models.hgnet_v2 import HGNetV2Backbone, HGNetV2ForImageClassification, HGNetV2PreTrainedModel from .models.hiera import ( HieraBackbone, HieraForImageClassification, diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 28e7ced270..9bc966d368 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -101,6 +101,7 @@ granitemoe, granitemoeshared, groupvit, + hgnet_v2, hiera, hubert, idefics, diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index b1d25744a3..4171b2cc83 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -124,6 +124,7 @@ ("granitemoeshared", "GraniteMoeSharedConfig"), ("groupvit", "GroupViTConfig"), ("helium", "HeliumConfig"), + ("hgnet_v2", "HGNetV2Config"), ("hiera", "HieraConfig"), ("hubert", "HubertConfig"), ("ibert", "IBertConfig"), @@ -392,6 +393,7 @@ ("granitemoeshared", "GraniteMoeSharedMoe"), ("groupvit", "GroupViT"), ("helium", "Helium"), + ("hgnet_v2", "HGNet-V2"), ("hiera", "Hiera"), ("hubert", "Hubert"), ("ibert", "I-BERT"), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index c89e50bc47..0a7d4602c9 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -118,6 +118,7 @@ ("groupvit", "GroupViTModel"), ("grounding-dino", "GroundingDinoModel"), ("helium", "HeliumModel"), + ("hgnet_v2", "HGNetV2Backbone"), ("hiera", "HieraModel"), ("hubert", "HubertModel"), ("ibert", "IBertModel"), @@ -583,6 +584,7 @@ ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"), ("efficientnet", "EfficientNetForImageClassification"), ("focalnet", "FocalNetForImageClassification"), + ("hgnet_v2", "HGNetV2ForImageClassification"), ("hiera", "HieraForImageClassification"), ("ijepa", "IJepaForImageClassification"), ("imagegpt", "ImageGPTForImageClassification"), @@ -1212,6 +1214,7 @@ ("dinov2", "Dinov2Backbone"), ("dinov2_with_registers", "Dinov2WithRegistersBackbone"), ("focalnet", "FocalNetBackbone"), + ("hgnet_v2", "HGNetV2Backbone"), ("hiera", "HieraBackbone"), ("maskformer-swin", "MaskFormerSwinBackbone"), ("pvt_v2", "PvtV2Backbone"), diff --git 
a/mindone/transformers/models/hgnet_v2/__init__.py b/mindone/transformers/models/hgnet_v2/__init__.py new file mode 100644 index 0000000000..57a9b2b52c --- /dev/null +++ b/mindone/transformers/models/hgnet_v2/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_hgnet_v2 import * diff --git a/mindone/transformers/models/hgnet_v2/modeling_hgnet_v2.py b/mindone/transformers/models/hgnet_v2/modeling_hgnet_v2.py new file mode 100644 index 0000000000..32bdc3819e --- /dev/null +++ b/mindone/transformers/models/hgnet_v2/modeling_hgnet_v2.py @@ -0,0 +1,480 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/hgnet_v2/modular_hgnet_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_hgnet_v2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional + +from transformers import HGNetV2Config + +import mindspore +from mindspore import Parameter, Tensor, mint, nn + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention +from ...modeling_utils import PreTrainedModel +from ...utils.backbone_utils import BackboneMixin + + +class HGNetV2PreTrainedModel(PreTrainedModel): + config: HGNetV2Config + base_model_prefix = "hgnetv2" + main_input_name = "pixel_values" + _no_split_modules = ["HGNetV2BasicLayer"] + + +class HGNetV2LearnableAffineBlock(nn.Cell): + def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0): + super().__init__() + self.scale = Parameter(Tensor([scale_value]), requires_grad=True) + self.bias = Parameter(Tensor([bias_value]), requires_grad=True) + + def construct(self, hidden_state: Tensor) -> Tensor: + hidden_state = self.scale * hidden_state + self.bias + return hidden_state + + +class HGNetV2ConvLayer(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + activation: str = "relu", + use_learnable_affine_block: bool = False, + ): + super().__init__() + self.convolution = mint.nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + groups=groups, + padding=(kernel_size - 1) // 2, + bias=False, + ) + self.normalization = mint.nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else mint.nn.Identity() + if activation and use_learnable_affine_block: + self.lab = HGNetV2LearnableAffineBlock() + else: + self.lab = mint.nn.Identity() + + def construct(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.lab(hidden_state) + return hidden_state + + +class HGNetV2ConvLayerLight(nn.Cell): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, use_learnable_affine_block: bool = False): + super().__init__() + self.conv1 = HGNetV2ConvLayer( + in_channels, + out_channels, + kernel_size=1, + activation=None, + use_learnable_affine_block=use_learnable_affine_block, + ) + self.conv2 = HGNetV2ConvLayer( + out_channels, + out_channels, + kernel_size=kernel_size, + groups=out_channels, + use_learnable_affine_block=use_learnable_affine_block, + ) + + def construct(self, hidden_state: Tensor) -> Tensor: + hidden_state = self.conv1(hidden_state) + hidden_state = self.conv2(hidden_state) + return hidden_state + + +class HGNetV2Embeddings(nn.Cell): + def __init__(self, config: HGNetV2Config): + super().__init__() + + self.stem1 = HGNetV2ConvLayer( + config.stem_channels[0], + config.stem_channels[1], + kernel_size=3, + stride=2, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem2a = HGNetV2ConvLayer( + config.stem_channels[1], + config.stem_channels[1] // 2, + kernel_size=2, + stride=1, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem2b = HGNetV2ConvLayer( + config.stem_channels[1] // 2, + config.stem_channels[1], + kernel_size=2, + stride=1, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem3 = HGNetV2ConvLayer( + config.stem_channels[1] * 2, + config.stem_channels[1], + kernel_size=3, + stride=2, + 
activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem4 = HGNetV2ConvLayer( + config.stem_channels[1], + config.stem_channels[2], + kernel_size=1, + stride=1, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + + self.pool = mint.nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True) + self.num_channels = config.num_channels + + def construct(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.stem1(pixel_values) + embedding = mint.nn.functional.pad(embedding, (0, 1, 0, 1)) + emb_stem_2a = self.stem2a(embedding) + emb_stem_2a = mint.nn.functional.pad(emb_stem_2a, (0, 1, 0, 1)) + emb_stem_2a = self.stem2b(emb_stem_2a) + pooled_emb = self.pool(embedding) + embedding = mint.cat([pooled_emb, emb_stem_2a], dim=1) + embedding = self.stem3(embedding) + embedding = self.stem4(embedding) + return embedding + + +class HGNetV2BasicLayer(nn.Cell): + def __init__( + self, + in_channels: int, + middle_channels: int, + out_channels: int, + layer_num: int, + kernel_size: int = 3, + residual: bool = False, + light_block: bool = False, + drop_path: float = 0.0, + use_learnable_affine_block: bool = False, + ): + super().__init__() + self.residual = residual + + self.layers = nn.CellList() + for i in range(layer_num): + temp_in_channels = in_channels if i == 0 else middle_channels + if light_block: + block = HGNetV2ConvLayerLight( + in_channels=temp_in_channels, + out_channels=middle_channels, + kernel_size=kernel_size, + use_learnable_affine_block=use_learnable_affine_block, + ) + else: + block = HGNetV2ConvLayer( + in_channels=temp_in_channels, + out_channels=middle_channels, + kernel_size=kernel_size, + use_learnable_affine_block=use_learnable_affine_block, + stride=1, + ) + self.layers.append(block) + + # feature aggregation + total_channels = in_channels + layer_num * middle_channels + aggregation_squeeze_conv = HGNetV2ConvLayer( + total_channels, + out_channels // 2, + kernel_size=1, + stride=1, + use_learnable_affine_block=use_learnable_affine_block, + ) + aggregation_excitation_conv = HGNetV2ConvLayer( + out_channels // 2, + out_channels, + kernel_size=1, + stride=1, + use_learnable_affine_block=use_learnable_affine_block, + ) + self.aggregation = nn.SequentialCell( + aggregation_squeeze_conv, + aggregation_excitation_conv, + ) + self.drop_path = mint.nn.Dropout(drop_path) if drop_path else mint.nn.Identity() + + def construct(self, hidden_state: Tensor) -> Tensor: + identity = hidden_state + output = [hidden_state] + for layer in self.layers: + hidden_state = layer(hidden_state) + output.append(hidden_state) + hidden_state = mint.cat(output, dim=1) + hidden_state = self.aggregation(hidden_state) + if self.residual: + hidden_state = self.drop_path(hidden_state) + identity + return hidden_state + + +class HGNetV2Stage(nn.Cell): + def __init__(self, config: HGNetV2Config, stage_index: int, drop_path: float = 0.0): + super().__init__() + in_channels = config.stage_in_channels[stage_index] + mid_channels = config.stage_mid_channels[stage_index] + out_channels = config.stage_out_channels[stage_index] + num_blocks = config.stage_num_blocks[stage_index] + num_layers = config.stage_numb_of_layers[stage_index] + downsample = config.stage_downsample[stage_index] + light_block = 
config.stage_light_block[stage_index] + kernel_size = config.stage_kernel_size[stage_index] + use_learnable_affine_block = config.use_learnable_affine_block + + if downsample: + self.downsample = HGNetV2ConvLayer( + in_channels, in_channels, kernel_size=3, stride=2, groups=in_channels, activation=None + ) + else: + self.downsample = mint.nn.Identity() + + blocks_list = [] + for i in range(num_blocks): + blocks_list.append( + HGNetV2BasicLayer( + in_channels if i == 0 else out_channels, + mid_channels, + out_channels, + num_layers, + residual=False if i == 0 else True, + kernel_size=kernel_size, + light_block=light_block, + drop_path=drop_path, + use_learnable_affine_block=use_learnable_affine_block, + ) + ) + self.blocks = nn.CellList(blocks_list) + + def construct(self, hidden_state: Tensor) -> Tensor: + hidden_state = self.downsample(hidden_state) + for block in self.blocks: + hidden_state = block(hidden_state) + return hidden_state + + +class HGNetV2Encoder(nn.Cell): + def __init__(self, config: HGNetV2Config): + super().__init__() + self.stages = nn.CellList([]) + for stage_index in range(len(config.stage_in_channels)): + resnet_stage = HGNetV2Stage(config, stage_index) + self.stages.append(resnet_stage) + + def construct( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class HGNetV2Backbone(HGNetV2PreTrainedModel, BackboneMixin): + def __init__(self, config: HGNetV2Config): + super().__init__(config) + super()._init_backbone(config) + self.depths = config.depths + self.num_features = [config.embedding_size] + config.hidden_sizes + self.embedder = HGNetV2Embeddings(config) + self.encoder = HGNetV2Encoder(config) + + # initialize weights and apply final processing + self.post_init() + + def construct( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import HGNetV2Config + >>> from mindone.transformers import HGNetV2Backbone + >>> import mindspore as ms + + >>> config = HGNetV2Config() + >>> model = HGNetV2Backbone(config) + + >>> pixel_values = ms.mint.randn(1, 3, 224, 224) + + >>> outputs = model(pixel_values) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + embedding_output = self.embedder(pixel_values) + + outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + 
return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +class HGNetV2ForImageClassification(HGNetV2PreTrainedModel): + def __init__(self, config: HGNetV2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.embedder = HGNetV2Embeddings(config) + self.encoder = HGNetV2Encoder(config) + self.avg_pool = mint.nn.AdaptiveAvgPool2d((1, 1)) + self.flatten = nn.Flatten() + self.fc = ( + mint.nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else mint.nn.Identity() + ) + + # classification head + self.classifier = nn.CellList([self.avg_pool, self.flatten]) + + # initialize weights and apply final processing + self.post_init() + + def construct( + self, + pixel_values: Optional[Tensor] = None, + labels: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Examples: + ```python + >>> import mindspore as ms + >>> import requests + >>> from mindone.transformers import HGNetV2ForImageClassification, AutoImageProcessor + >>> from PIL import Image + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> model = HGNetV2ForImageClassification.from_pretrained("ustc-community/hgnet-v2") + >>> processor = AutoImageProcessor.from_pretrained("ustc-community/hgnet-v2") + + >>> inputs = processor(images=image, return_tensors="np") + >>> inputs = {k: ms.tensor(v) for k, v in inputs.items()} + >>> outputs = model(**inputs) + >>> outputs.logits.shape + (1, 2) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + embedding_output = self.embedder(pixel_values) + outputs = self.encoder(embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict) + last_hidden_state = outputs[0] + for layer in self.classifier: + last_hidden_state = layer(last_hidden_state) + logits = self.fc(last_hidden_state) + loss = None + + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == mindspore.long or labels.dtype == mindspore.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = mint.nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = mint.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = mint.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output 
if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +__all__ = ["HGNetV2Backbone", "HGNetV2PreTrainedModel", "HGNetV2ForImageClassification"] diff --git a/tests/transformers_tests/models/hgnet_v2/__init__.py b/tests/transformers_tests/models/hgnet_v2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py b/tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py new file mode 100644 index 0000000000..a38858391b --- /dev/null +++ b/tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py @@ -0,0 +1,223 @@ +"""Adapted from https://github.com/huggingface/transformers/tree/main/tests/models/hgnet_v2/test_modeling_hgnet_v2.py.""" + +# This module contains test cases that are defined in the `.test_cases.py` file, structured as lists or tuples like +# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map]. +# +# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective +# initialization parameters and inputs for the forward. The testing framework adopted here is designed to generically +# parse these parameters to assess and compare the precision of forward outcomes between the two frameworks. +# +# In cases where models have unique initialization procedures or require testing with specialized output formats, +# it is necessary to develop distinct, dedicated test cases. +import inspect + +import numpy as np +import pytest +import torch +from transformers import HGNetV2Config + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy + +# ms.nn.MaxPool2d does not support bf16 inputs +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3} +MODES = [1] + + +class HGNetV2ModelTester: + def __init__( + self, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[64, 128, 256, 512], + stage_in_channels=[16, 64, 128, 256], + stage_mid_channels=[16, 32, 64, 128], + stage_out_channels=[64, 128, 256, 512], + stage_num_blocks=[1, 1, 2, 1], + stage_downsample=[False, True, True, True], + stage_light_block=[False, False, True, True], + stage_kernel_size=[3, 3, 5, 5], + stage_numb_of_layers=[3, 3, 3, 3], + stem_channels=[3, 16, 16], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + ): + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.stage_in_channels = stage_in_channels + self.stage_mid_channels = stage_mid_channels + self.stage_out_channels = stage_out_channels + self.stage_num_blocks = stage_num_blocks + self.stage_downsample = stage_downsample + self.stage_light_block = stage_light_block + self.stage_kernel_size = stage_kernel_size + self.stage_numb_of_layers = stage_numb_of_layers + self.stem_channels = stem_channels + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + 
self.out_features = out_features + self.out_indices = out_indices + + def prepare_config_and_inputs(self): + pixel_values = floats_numpy([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_numpy([self.batch_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return HGNetV2Config( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + stage_in_channels=self.stage_in_channels, + stage_mid_channels=self.stage_mid_channels, + stage_out_channels=self.stage_out_channels, + stage_num_blocks=self.stage_num_blocks, + stage_downsample=self.stage_downsample, + stage_light_block=self.stage_light_block, + stage_kernel_size=self.stage_kernel_size, + stage_numb_of_layers=self.stage_numb_of_layers, + stem_channels=self.stem_channels, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + out_features=self.out_features, + out_indices=self.out_indices, + ) + + +model_tester = HGNetV2ModelTester() +config, pixel_values, labels = model_tester.prepare_config_and_inputs() + + +TEST_CASES = [ + [ + "HGNetV2Backbone", + "transformers.HGNetV2Backbone", + "mindone.transformers.HGNetV2Backbone", + (config,), + {}, + (pixel_values, None), + {}, + { + "feature_maps": "feature_maps", + }, + ], + [ + "HGNetV2ForImageClassification", + "transformers.HGNetV2ForImageClassification", + "mindone.transformers.HGNetV2ForImageClassification", + (config,), + {}, + (pixel_values, labels), + {}, + { + "loss": "loss", + "logits": "logits", + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in TEST_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert 
(np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + )
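
For reviewers who want to smoke-test the new classes outside the pytest harness, below is a minimal sketch in the spirit of the docstring examples already included in `modeling_hgnet_v2.py`. It is not part of the patch; it assumes this patch is installed into `mindone` and that the installed `transformers` package is recent enough to provide `HGNetV2Config`, which the modeling file imports directly. Default config values (e.g. `num_labels=2`, last hidden size 2048) come from the upstream config, not from this patch.

```python
# Hedged sketch: exercise HGNetV2Backbone and HGNetV2ForImageClassification
# with a default HGNetV2Config, mirroring the in-file docstring examples.
import mindspore as ms
from transformers import HGNetV2Config  # config is reused from upstream transformers

from mindone.transformers import HGNetV2Backbone, HGNetV2ForImageClassification

config = HGNetV2Config()

# Backbone: returns the feature maps selected by config.out_features / out_indices
# (defaults to the last stage for a default config).
backbone = HGNetV2Backbone(config)
backbone.set_train(False)  # inference mode so BatchNorm uses running statistics

pixel_values = ms.mint.randn(1, 3, 224, 224)
outputs = backbone(pixel_values)
print([tuple(fm.shape) for fm in outputs.feature_maps])  # last map ~ (1, 2048, 7, 7)

# Classification head: pools the final feature map and projects to config.num_labels.
classifier = HGNetV2ForImageClassification(config)
classifier.set_train(False)
logits = classifier(pixel_values).logits
print(logits.shape)  # (1, config.num_labels)
```

The PyTorch/MindSpore parity test added by this patch can then be run with `pytest tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py`; it compares forward outputs against the PyTorch reference under the fp32/fp16 tolerances defined in `DTYPE_AND_THRESHOLDS`.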