From 9ffb362c4b6dbc23174b81c68d87d0a0facae4bb Mon Sep 17 00:00:00 2001
From: Pig
Date: Wed, 1 May 2024 23:45:16 +0800
Subject: [PATCH 01/13] Add implementation of FlashInternImage

---
 timm/models/flash_intern_image.py | 1733 +++++++++++++++++++++++++++++
 1 file changed, 1733 insertions(+)
 create mode 100644 timm/models/flash_intern_image.py

diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py
new file mode 100644
index 0000000000..a3f6d980e3
--- /dev/null
+++ b/timm/models/flash_intern_image.py
@@ -0,0 +1,1733 @@
+"""Flash Intern Image
+A PyTorch implementation of FlashInternImage as described in:
+
+`InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions`
+    - https://arxiv.org/pdf/2211.05778
+
+`Efficient Deformable ConvNets: Rethinking Dynamic and Sparse Operator for Vision Applications` (DCNv4)
+    - https://arxiv.org/pdf/2401.06197
+
+Code/weights from https://github.com/OpenGVLab/DCNv4, original copyright/license info below
+"""
+# --------------------------------------------------------
+# Flash Intern Image
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+import torch
+import torch.nn as nn
+from torch.nn.init import xavier_uniform_, constant_
+from collections import OrderedDict
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import trunc_normal_, DropPath
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from ._registry import register_model, generate_default_cfgs
+from ._builder import build_model_with_cfg
+import torch.nn.functional as F
+from ._manipulate import checkpoint_seq
+from typing import Dict, Any
+import warnings
+
+__all__ = ['FlashInternImage']
+
+dcn_version = 'DCNv4'
+try:
+    import DCNv4
+except ImportError:
+    dcn_version = 'DCNv3'
+    warnings.warn('FlashInternImage requires DCNv4, but not found in current enviroment.\
+    By default using DCNv3 pure pytorch implementation instead, which will affect the performance.\
+    Suggesting install DCNv4 by `pip install DCNv4`')
+
+
+class to_channels_first(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.permute(0, 3, 1, 2)
+
+
+class to_channels_last(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.permute(0, 2, 3, 1)
+
+
+def build_norm_layer(dim,
+                     norm_layer,
+                     in_format='channels_last',
+                     out_format='channels_last',
+                     eps=1e-6):
+    layers = []
+    if norm_layer == 'BN':
+        if in_format == 'channels_last':
+            layers.append(to_channels_first())
+        layers.append(nn.BatchNorm2d(dim))
+        if out_format == 'channels_last':
+            layers.append(to_channels_last())
+    elif norm_layer == 'LN':
+        if in_format == 'channels_first':
+            layers.append(to_channels_last())
+        layers.append(nn.LayerNorm(dim, eps=eps))
+        if out_format == 'channels_first':
+            layers.append(to_channels_first())
+    else:
+        raise NotImplementedError(
+            f'build_norm_layer does not support {norm_layer}')
+    return nn.Sequential(*layers)
+
+
+def build_act_layer(act_layer):
+    if act_layer == 'ReLU':
+        return nn.ReLU(inplace=True)
+    elif act_layer == 'SiLU':
+        return nn.SiLU(inplace=True)
+    elif act_layer == 'GELU':
+        return nn.GELU()
+
+    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
+
+
+def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
+    _, H_, W_, _ = spatial_shapes
+    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
+    W_out = (W_ - (dilation_w * 
(kernel_w - 1) + 1)) // stride_w + 1 + + ref_y, ref_x = torch.meshgrid( + torch.linspace( + # pad_h + 0.5, + # H_ - pad_h - 0.5, + (dilation_h * (kernel_h - 1)) // 2 + 0.5, + (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, + H_out, + dtype=torch.float32, + device=device), + torch.linspace( + # pad_w + 0.5, + # W_ - pad_w - 0.5, + (dilation_w * (kernel_w - 1)) // 2 + 0.5, + (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, + W_out, + dtype=torch.float32, + device=device)) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + + ref = torch.stack((ref_x, ref_y), -1).reshape( + 1, H_out, W_out, 1, 2) + + return ref + + +def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): + _, H_, W_, _ = spatial_shapes + points_list = [] + x, y = torch.meshgrid( + torch.linspace( + -((dilation_w * (kernel_w - 1)) // 2), + -((dilation_w * (kernel_w - 1)) // 2) + + (kernel_w - 1) * dilation_w, kernel_w, + dtype=torch.float32, + device=device), + torch.linspace( + -((dilation_h * (kernel_h - 1)) // 2), + -((dilation_h * (kernel_h - 1)) // 2) + + (kernel_h - 1) * dilation_h, kernel_h, + dtype=torch.float32, + device=device)) + + points_list.extend([x / W_, y / H_]) + grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ + repeat(1, group, 1).permute(1, 0, 2) + grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) + + return grid + + +def dcnv3_core_pytorch( + input, offset, mask, kernel_h, + kernel_w, stride_h, stride_w, pad_h, + pad_w, dilation_h, dilation_w, group, + group_channels, offset_scale): + # for debug and test only, + # need to use cuda version instead + input = F.pad( + input, + [0, 0, pad_h, pad_h, pad_w, pad_w]) + N_, H_in, W_in, _ = input.shape + _, H_out, W_out, _ = offset.shape + + ref = _get_reference_points( + input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w) + grid = _generate_dilation_grids( + input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device) + spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ + repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device) + + sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \ + offset * offset_scale / spatial_norm + + P_ = kernel_h * kernel_w + sampling_grids = 2 * sampling_locations - 1 + # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in + input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ + reshape(N_*group, group_channels, H_in, W_in) + # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2 + sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\ + flatten(0, 1) + # N_*group, group_channels, H_out*W_out, P_ + sampling_input_ = F.grid_sample( + input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False) + + # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_) + mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\ + reshape(N_*group, 1, H_out*W_out, P_) + output = (sampling_input_ * mask).sum(-1).view(N_, + group*group_channels, H_out*W_out) + + return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous() + + +def _is_power_of_2(n): + if (not isinstance(n, 
int)) or (n < 0): + raise ValueError( + "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + + return (n & (n - 1) == 0) and n != 0 + + +class CenterFeatureScaleModule(nn.Module): + def forward(self, + query, + center_feature_scale_proj_weight, + center_feature_scale_proj_bias): + center_feature_scale = F.linear(query, + weight=center_feature_scale_proj_weight, + bias=center_feature_scale_proj_bias).sigmoid() + return center_feature_scale + + +class DCNv3_pytorch(nn.Module): + def __init__( + self, + channels=64, + kernel_size=3, + dw_kernel_size=None, + stride=1, + pad=1, + dilation=1, + group=4, + offset_scale=1.0, + act_layer='GELU', + norm_layer='LN', + center_feature_scale=False): + """ + DCNv3 Module + :param channels + :param kernel_size + :param stride + :param pad + :param dilation + :param group + :param offset_scale + :param act_layer + :param norm_layer + """ + super().__init__() + if channels % group != 0: + raise ValueError( + f'channels must be divisible by group, but got {channels} and {group}') + _d_per_group = channels // group + dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size + # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_group): + warnings.warn( + "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.offset_scale = offset_scale + self.channels = channels + self.kernel_size = kernel_size + self.dw_kernel_size = dw_kernel_size + self.stride = stride + self.dilation = dilation + self.pad = pad + self.group = group + self.group_channels = channels // group + self.offset_scale = offset_scale + self.center_feature_scale = center_feature_scale + + self.dw_conv = nn.Sequential( + nn.Conv2d( + channels, + channels, + kernel_size=dw_kernel_size, + stride=1, + padding=(dw_kernel_size - 1) // 2, + groups=channels), + build_norm_layer( + channels, + norm_layer, + 'channels_first', + 'channels_last'), + build_act_layer(act_layer)) + self.offset = nn.Linear( + channels, + group * kernel_size * kernel_size * 2) + self.mask = nn.Linear( + channels, + group * kernel_size * kernel_size) + self.input_proj = nn.Linear(channels, channels) + self.output_proj = nn.Linear(channels, channels) + self._reset_parameters() + + if center_feature_scale: + self.center_feature_scale_proj_weight = nn.Parameter( + torch.zeros((group, channels), dtype=torch.float)) + self.center_feature_scale_proj_bias = nn.Parameter( + torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) + self.center_feature_scale_module = CenterFeatureScaleModule() + + def _reset_parameters(self): + constant_(self.offset.weight.data, 0.) + constant_(self.offset.bias.data, 0.) + constant_(self.mask.weight.data, 0.) + constant_(self.mask.bias.data, 0.) + xavier_uniform_(self.input_proj.weight.data) + constant_(self.input_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) 
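+
+    # Usage sketch (shapes are illustrative, assuming the defaults above); the
+    # module is channels-last end to end, e.g.:
+    #   m = DCNv3_pytorch(channels=64, group=4)
+    #   y = m(torch.randn(2, 14, 14, 64))  # -> (2, 14, 14, 64)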
+
+    def forward(self, input, shape=None, level_idx=0):
+        """
+        :param input (N, H, W, C)
+        :return output (N, H, W, C)
+
+        `shape` and `level_idx` are unused here; they are accepted so this
+        module stays call-compatible with the DCNv4 operator used by
+        InternImageLayer.
+        """
+        N, H, W, _ = input.shape
+
+        x = self.input_proj(input)
+        x_proj = x
+
+        x1 = input.permute(0, 3, 1, 2)
+        x1 = self.dw_conv(x1)
+        offset = self.offset(x1)
+        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
+        mask = F.softmax(mask, -1).reshape(N, H, W, -1)
+
+        x = dcnv3_core_pytorch(
+            x, offset, mask,
+            self.kernel_size, self.kernel_size,
+            self.stride, self.stride,
+            self.pad, self.pad,
+            self.dilation, self.dilation,
+            self.group, self.group_channels,
+            self.offset_scale)
+        if self.center_feature_scale:
+            center_feature_scale = self.center_feature_scale_module(
+                x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias)
+            # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels
+            center_feature_scale = center_feature_scale[..., None].repeat(
+                1, 1, 1, 1, self.channels // self.group).flatten(-2)
+            x = x * (1 - center_feature_scale) + x_proj * center_feature_scale
+        x = self.output_proj(x)
+
+        return x
+
+# --- DCNv3 pure pytorch implementation finished --- #
+# --- FlashInternImage implementation start --- #
+class CrossAttention(nn.Module):
+    r""" Cross Attention Module
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads. Default: 8
+        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
+            Default: False.
+        qk_scale (float | None, optional): Override default qk scale of
+            head_dim ** -0.5 if set. Default: None.
+        attn_drop (float, optional): Dropout ratio of attention weight.
+            Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+        attn_head_dim (int, optional): Dimension of attention head.
+        out_dim (int, optional): Dimension of output. 
+ """ + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + attn_head_dim=None, + out_dim=None): + super().__init__() + if out_dim is None: + out_dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + assert all_head_dim == dim + + self.q = nn.Linear(dim, all_head_dim, bias=False) + self.k = nn.Linear(dim, all_head_dim, bias=False) + self.v = nn.Linear(dim, all_head_dim, bias=False) + + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.k_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, out_dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, k=None, v=None): + B, N, C = x.shape + N_k = k.shape[1] + N_v = v.shape[1] + + q_bias, k_bias, v_bias = None, None, None + if self.q_bias is not None: + q_bias = self.q_bias + k_bias = self.k_bias + v_bias = self.v_bias + + q = F.linear(input=x, weight=self.q.weight, bias=q_bias) + q = q.reshape(B, N, 1, self.num_heads, + -1).permute(2, 0, 3, 1, + 4).squeeze(0) # (B, N_head, N_q, dim) + + k = F.linear(input=k, weight=self.k.weight, bias=k_bias) + k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, + 4).squeeze(0) + + v = F.linear(input=v, weight=self.v.weight, bias=v_bias) + v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, + 4).squeeze(0) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class AttentiveBlock(nn.Module): + r"""Attentive Block + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 8 + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop (float, optional): Dropout rate. Default: 0.0. + attn_drop (float, optional): Attention dropout rate. Default: 0.0. + drop_path (float | tuple[float], optional): Stochastic depth rate. + Default: 0.0. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm. + attn_head_dim (int, optional): Dimension of attention head. Default: None. + out_dim (int, optional): Dimension of output. Default: None. + """ + + def __init__(self, + dim, + num_heads, + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer="LN", + attn_head_dim=None, + out_dim=None): + super().__init__() + + self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6) + self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6) + self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6) + self.cross_dcn = CrossAttention(dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + attn_head_dim=attn_head_dim, + out_dim=out_dim) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, + x_q, + x_kv, + pos_q, + pos_k, + bool_masked_pos, + rel_pos_bias=None): + x_q = self.norm1_q(x_q + pos_q) + x_k = self.norm1_k(x_kv + pos_k) + x_v = self.norm1_v(x_kv) + + x = self.cross_dcn(x_q, k=x_k, v=x_v) + + return x + + +class AttentionPoolingBlock(AttentiveBlock): + + def forward(self, x): + x_q = x.mean(1, keepdim=True) + x_kv = x + pos_q, pos_k = 0, 0 + x = super().forward(x_q, x_kv, pos_q, pos_k, + bool_masked_pos=None, + rel_pos_bias=None) + x = x.squeeze(1) + return x + + +class StemLayer(nn.Module): + r""" Stem layer of InternImage + Args: + in_chans (int): number of input channels + out_chans (int): number of output channels + act_layer (str): activation layer + norm_layer (str): normalization layer + """ + + def __init__(self, + in_chans=3, + out_chans=96, + act_layer='GELU', + norm_layer='BN'): + super().__init__() + self.conv1 = nn.Conv2d(in_chans, + out_chans // 2, + kernel_size=3, + stride=2, + padding=1) + self.norm1 = build_norm_layer(out_chans // 2, norm_layer, + 'channels_first', 'channels_first') + self.act = build_act_layer(act_layer) + self.conv2 = nn.Conv2d(out_chans // 2, + out_chans, + kernel_size=3, + stride=2, + padding=1) + self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first', + 'channels_last') + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.act(x) + x = self.conv2(x) + x = self.norm2(x) + return x + + +class DownsampleLayer(nn.Module): + r""" Downsample layer of InternImage + Args: + channels (int): number of input channels + norm_layer (str): normalization layer + """ + + def __init__(self, channels, norm_layer='LN'): + super().__init__() + self.conv = nn.Conv2d(channels, + 2 * channels, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.norm = build_norm_layer(2 * channels, norm_layer, + 'channels_first', 'channels_first') + + + def forward(self, x, shape=None): + H, W = shape + N, HW, C = x.shape + x = x.view(N, H, W, C) + x = self.conv(x.permute(0, 3, 1, 2)) + x = self.norm(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).permute(0, 2, 1) + + return x, (H, W) + + +class MLPLayer(nn.Module): + r""" MLP layer of InternImage + Args: + in_features (int): number of input features + hidden_features (int): number of hidden features + out_features (int): number of output features + act_layer (str): activation layer + drop (float): dropout rate + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer='GELU', + mlp_fc2_bias=False, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + self.act = build_act_layer(act_layer) + self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias) + self.drop = nn.Dropout(drop) + + + def forward(self, x, shape, level_idx=0): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class InternImageLayer(nn.Module): + r""" Basic layer of InternImage + Args: + core_op (str): core operation of InternImage + channels (int): number of input channels + groups (int): Groups of each block. + mlp_ratio (float): ratio of mlp hidden features to input channels, Default: 4. + drop (float): dropout rate, Default: 0. + drop_path (float): drop path rate, Default: 0. + act_layer (str): activation layer, Default: 'GELU'. + norm_layer (str): normalization layer, Default: 'LN'. 
+ post_norm (bool): whether to use post normalization, Default: False. + layer_scale (float): layer scale, Default: None. + offset_scale (float): offset scale, Default: 1.0. + with_cp (bool): whether to use checkpoint, Default: False. + dcn_output_bias (bool): whether to use dcn output bias, Default: False. + mlp_fc2_bias (bool): whether to use mlp fc2 bias, Default: False. + dw_kernel_size (int): Size of the dwconv, Default: None. + res_post_norm (bool): whether to use res post normalization, Default: False. + center_feature_scale (bool): whether to use center feature scale, Default: False. + """ + + def __init__(self, + core_op, + channels, + groups, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer='GELU', + norm_layer='LN', + post_norm=False, + layer_scale=None, + offset_scale=1.0, + with_cp=False, + dcn_output_bias=False, + mlp_fc2_bias=False, + dw_kernel_size=None, # for InternImage-H/G + res_post_norm=False, # for InternImage-H/G + center_feature_scale=False): # for InternImage-H/G + super().__init__() + self.channels = channels + self.groups = groups + self.mlp_ratio = mlp_ratio + self.with_cp = with_cp + + self.norm1 = build_norm_layer(channels, 'LN') + self.post_norm = post_norm + if dcn_version == 'DCNv4' and core_op == 'DCNv4': + self.dcn = DCNv4.DCNv4( + channels=channels, + group=groups, + offset_scale=offset_scale, + dw_kernel_size=dw_kernel_size, + output_bias=dcn_output_bias, + ) + else: + self.dcn = DCNv3_pytorch( + channels=channels, + group=groups, + offset_scale=offset_scale, + dw_kernel_size=dw_kernel_size, + center_feature_scale=center_feature_scale + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. \ + else nn.Identity() + self.norm2 = build_norm_layer(channels, 'LN') + self.mlp = MLPLayer(in_features=channels, + hidden_features=int(channels * mlp_ratio), + act_layer=act_layer, + drop=drop, + mlp_fc2_bias=mlp_fc2_bias + ) + self.layer_scale = layer_scale is not None + if self.layer_scale: + self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels), + requires_grad=True) + self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels), + requires_grad=True) + self.res_post_norm = res_post_norm + if res_post_norm: + self.res_post_norm1 = build_norm_layer(channels, 'LN') + self.res_post_norm2 = build_norm_layer(channels, 'LN') + def forward(self, x, shape, level_idx=0): + + def _inner_forward(x, shape, level_idx): + if not self.layer_scale: + if self.post_norm: + x = x + self.drop_path(self.norm1(self.dcn(x, shape, level_idx))) + x = x + self.drop_path(self.norm2(self.mlp(x, shape, level_idx))) + elif self.res_post_norm: # for InternImage-H/G + x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x), shape, level_idx))) + x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x), shape, level_idx))) + + else: + x = x + self.drop_path(self.dcn(self.norm1(x), shape, level_idx)) + x = x + self.drop_path(self.mlp(self.norm2(x), shape, level_idx)) + return x + if self.post_norm: + x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x, shape))) + x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x, shape, level_idx))) + else: + x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x), shape)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x), shape, level_idx)) + return x + + if self.with_cp and x.requires_grad: + x = checkpoint.checkpoint(_inner_forward, x, shape, level_idx) + else: + x = _inner_forward(x, shape, level_idx) + + return x + + +class InternImageBlock(nn.Module): + r""" Block of InternImage + 
Args:
+        core_op (str): core operation of InternImage
+        channels (int): number of input channels
+        depth (int): Depth of each block.
+        groups (int): Groups of each block.
+        downsample (bool): Whether to use downsample, Default: True.
+        downsample_layer (nn.Module): Downsample layer, Default: DownsampleLayer.
+        mlp_ratio (float): ratio of mlp hidden features to input channels, Default: 4.
+        drop (float): dropout rate, Default: 0.
+        drop_path (float): drop path rate, Default: 0.
+        act_layer (str): activation layer, Default: 'GELU'.
+        norm_layer (str): normalization layer, Default: 'LN'.
+        post_norm (bool): whether to use post normalization, Default: False.
+        offset_scale (float): offset scale, Default: 0.5.
+        layer_scale (float): layer scale, Default: None.
+        with_cp (bool): whether to use checkpoint, Default: False.
+        dcn_output_bias (bool): whether to use dcn output bias, Default: False.
+        mlp_fc2_bias (bool): whether to use mlp fc2 bias, Default: False.
+        dw_kernel_size (int): Size of the dwconv, Default: None.
+        post_norm_block_ids (list): block ids for post normalization, Default: None.
+        res_post_norm (bool): whether to use res post normalization, Default: False.
+        center_feature_scale (bool): whether to use center feature scale, Default: False.
+    """
+
+    def __init__(self,
+                 core_op,
+                 channels,
+                 depth,
+                 groups,
+                 downsample=True,
+                 downsample_layer=DownsampleLayer,
+                 mlp_ratio=4.,
+                 drop=0.,
+                 drop_path=0.,
+                 act_layer='GELU',
+                 norm_layer='LN',
+                 post_norm=False,
+                 offset_scale=0.5,
+                 layer_scale=None,
+                 with_cp=False,
+                 dcn_output_bias=False,
+                 mlp_fc2_bias=False,
+                 dw_kernel_size=None,  # for InternImage-H/G
+                 post_norm_block_ids=None,  # for InternImage-H/G
+                 res_post_norm=False,  # for InternImage-H/G
+                 center_feature_scale=False):  # for InternImage-H/G
+        super().__init__()
+        self.channels = channels
+        self.depth = depth
+        self.post_norm = post_norm
+        self.center_feature_scale = center_feature_scale
+        self.grad_checkpointing = False
+
+        self.blocks = nn.ModuleList([
+            InternImageLayer(
+                core_op=core_op,
+                channels=channels,
+                groups=groups,
+                mlp_ratio=mlp_ratio,
+                drop=drop,
+                drop_path=drop_path[i] if isinstance(
+                    drop_path, list) else drop_path,
+                act_layer=act_layer,
+                norm_layer=norm_layer,
+                post_norm=post_norm,
+                layer_scale=layer_scale,
+                offset_scale=offset_scale,
+                with_cp=with_cp,
+                dcn_output_bias=dcn_output_bias,
+                mlp_fc2_bias=mlp_fc2_bias,
+                dw_kernel_size=dw_kernel_size,  # for InternImage-H/G
+                res_post_norm=res_post_norm,  # for InternImage-H/G
+                center_feature_scale=center_feature_scale  # for InternImage-H/G
+            ) for i in range(depth)
+        ])
+        if not self.post_norm or center_feature_scale:
+            self.norm = build_norm_layer(channels, 'LN')
+        self.post_norm_block_ids = post_norm_block_ids
+        if post_norm_block_ids is not None:  # for InternImage-H/G
+            self.post_norms = nn.ModuleList(
+                [build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
+            )
+        self.downsample = downsample_layer(
+            channels=channels, norm_layer=norm_layer) if downsample else None
+
+    def forward(self, x, return_wo_downsample=False, shape=None, level_idx=0):
+        for i, blk in enumerate(self.blocks):
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                # the blocks need shape/level_idx, so checkpoint each block
+                # directly rather than treating them as a plain sequence
+                x = checkpoint.checkpoint(blk, x, shape, level_idx)
+            else:
+                x = blk(x, shape=shape, level_idx=level_idx)
+            if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
+                index = self.post_norm_block_ids.index(i)
+                x = self.post_norms[index](x)  # for InternImage-H/G
+        if not self.post_norm or self.center_feature_scale:
+            x = self.norm(x)
+        if return_wo_downsample:
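+            # keep a pre-downsample copy of this stage's tokens so dense
+            # prediction heads can read per-stage feature maps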
+            x_ = x.clone()
+        if self.downsample is not None:
+            x, shape = self.downsample(x, shape=shape)
+
+        if return_wo_downsample:
+            return x, x_, shape
+        return x, shape
+
+
+class FlashInternImage(nn.Module):
+    r""" FlashInternImage
+        A PyTorch impl based on:
+            `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
+            https://arxiv.org/pdf/2211.05778
+            `DCNv4` - https://arxiv.org/pdf/2401.06197
+    Args:
+        core_op (str): Core operator. Default: 'DCNv4'
+        channels (int): Number of channels in the first stage. Default: 64
+        depths (list): Depth of each block. Default: [3, 4, 18, 5]
+        groups (list): Groups of each block. Default: [3, 6, 12, 24]
+        num_classes (int): Number of classes. Default: 1000
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        drop_rate (float): Probability of an element to be zeroed. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2
+        drop_path_type (str): Drop path type. Default: 'linear'
+        act_layer (str): Activation layer. Default: 'GELU'
+        norm_layer (str): Normalization layer. Default: 'LN'
+        layer_scale (float): Layer scale. Default: None
+        offset_scale (float): Offset scale. Default: 0.5
+        post_norm (bool): Whether to use post norm. Default: False
+        cls_scale (float): Width multiplier for the classifier head. Default: 1.5
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False
+        mlp_fc2_bias (bool): Whether to use mlp fc2 bias. Default: False
+        dcn_output_bias (bool): Whether to use dcn output bias. Default: False
+        dw_kernel_size (int): Size of the dwconv. Default: None
+        use_clip_projector (bool): Whether to use clip projector. Default: False
+        level2_post_norm (bool): Whether to use level2 post norm. Default: False
+        level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
+        res_post_norm (bool): Whether to use res post norm. Default: False
+        center_feature_scale (bool): Whether to use center feature scale. Default: False
+        out_indices (tuple): Output from which stages. 
Default: (0, 1, 2, 3) + """ + + def __init__(self, + core_op='DCNv4', + channels=64, + depths=[3, 4, 18, 5], + groups=[3, 6, 12, 24], + num_classes=1000, + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.2, + drop_path_type='linear', + act_layer='GELU', + norm_layer='LN', + layer_scale=None, + offset_scale=0.5, + post_norm=False, + cls_scale=1.5, + with_cp=False, + mlp_fc2_bias=False, + dcn_output_bias=False, + dw_kernel_size=None, + use_clip_projector=False, # for InternImage-H/G + level2_post_norm=False, # for InternImage-H/G + level2_post_norm_block_ids=None, # for InternImage-H/G + res_post_norm=False, # for InternImage-H/G + center_feature_scale=False, # for InternImage-H/G + out_indices=(0, 1, 2, 3), + **kwargs): + super().__init__() + self.core_op = core_op + self.num_classes = num_classes + self.num_levels = len(depths) + self.depths = depths + self.channels = channels + self.num_features = int(channels * 2**(self.num_levels - 1)) + self.post_norm = post_norm + self.mlp_ratio = mlp_ratio + self.use_clip_projector = use_clip_projector + self.level2_post_norm_block_ids = level2_post_norm_block_ids + self.out_indices = out_indices + print(f'using core type: {core_op}') + print(f'using activation layer: {act_layer}') + print(f'using main norm layer: {norm_layer}') + print(f'using dpr: {drop_path_type}, {drop_path_rate}') + print(f"level2_post_norm: {level2_post_norm}") + print(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}") + print(f"res_post_norm: {res_post_norm}") + + in_chans = 3 + self.patch_embed = StemLayer(in_chans=in_chans, + out_chans=channels, + act_layer=act_layer, + norm_layer=norm_layer) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] + if drop_path_type == 'uniform': + for i in range(len(dpr)): + dpr[i] = drop_path_rate + + self.levels = nn.ModuleList() + for i in range(self.num_levels): + post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and ( + i == 2) else None # for InternImage-H/G + + level = InternImageBlock( + core_op=core_op, + channels=int(channels * 2**i), + depth=depths[i], + groups=groups[i], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], + act_layer=act_layer, + norm_layer=norm_layer, + post_norm=post_norm, + downsample=(i < self.num_levels - 1), + downsample_layer = DownsampleLayer, + layer_scale=layer_scale, + offset_scale=offset_scale, + with_cp=with_cp, + mlp_fc2_bias=mlp_fc2_bias, + dcn_output_bias=dcn_output_bias, + dw_kernel_size=dw_kernel_size, # for InternImage-H/G + post_norm_block_ids=post_norm_block_ids, # for InternImage-H/G + res_post_norm=res_post_norm, # for InternImage-H/G + center_feature_scale=center_feature_scale # for InternImage-H/G + ) + self.levels.append(level) + + if not use_clip_projector: # for InternImage-T/S/B/L/XL + self.conv_head = nn.Sequential( + nn.Conv2d(self.num_features, + int(self.num_features * cls_scale), + kernel_size=1, + bias=False), + build_norm_layer(int(self.num_features * cls_scale), 'BN', + 'channels_first', 'channels_first'), + build_act_layer(act_layer)) + self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \ + if num_classes > 0 else nn.Identity() + else: # for InternImage-H/G + pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768 + self.dcnv3_head_x4 = nn.Sequential( + nn.Conv2d(in_channels=self.num_features, + out_channels=pretrain_embed_dim * (_stride ** 2), + kernel_size=1), 
nn.PixelShuffle(_stride))
+            self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
+                                           out_channels=pretrain_embed_dim,
+                                           kernel_size=1)
+            self.clip_projector = AttentionPoolingBlock(
+                dim=pretrain_embed_dim,
+                num_heads=attnpool_num_heads,
+                qkv_bias=True,
+                qk_scale=None,
+                drop=0.,
+                attn_drop=0.,
+                norm_layer=norm_layer,
+                out_dim=clip_embed_dim)
+            self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6)
+            self.head = nn.Linear(
+                clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.num_layers = len(depths)
+        self.apply(self._init_weights)
+        self.apply(self._init_deform_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def _init_deform_weights(self, m):
+        if isinstance(m, getattr(DCNv4, self.core_op)):
+            m._reset_parameters()
+        elif isinstance(m, DCNv3_pytorch):
+            m._reset_parameters()
+
+    def init_weights(self):
+        self.apply(self._init_weights)
+        self.apply(self._init_deform_weights)
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        # the head consumes the conv_head output (num_features * cls_scale) or
+        # clip_embed_dim, so reuse the existing head's input width
+        in_features = self.head.in_features if isinstance(self.head, nn.Linear) \
+            else self.num_features
+        self.head = nn.Linear(in_features, num_classes) \
+            if num_classes > 0 else nn.Identity()
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        return dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=[(r'^levels\.(\d+)', None)]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for l in self.levels:
+            l.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def lr_decay_keywards(self, decay_ratio=0.87):
+        lr_ratios = {}
+
+        # blocks
+        idx = 0
+        for i in range(4):
+            layer_num = 3 - i  # 3 2 1 0
+            for j in range(self.depths[layer_num]):
+                block_num = self.depths[layer_num] - j - 1
+                tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
+                decay = 1.0 * (decay_ratio**idx)
+                lr_ratios[tag] = decay
+                idx += 1
+        # patch_embed (before stage-1)
+        lr_ratios["patch_embed"] = lr_ratios['levels.0.blocks.0.']
+        # levels.0.downsample (between stage-1 and stage-2)
+        lr_ratios["levels.0.downsample"] = lr_ratios['levels.1.blocks.0.']
+        lr_ratios["levels.0.norm"] = lr_ratios['levels.1.blocks.0.']
+        # levels.1.downsample (between stage-2 and stage-3)
+        lr_ratios["levels.1.downsample"] = lr_ratios['levels.2.blocks.0.']
+        lr_ratios["levels.1.norm"] = lr_ratios['levels.2.blocks.0.']
+        # levels.2.downsample (between stage-3 and stage-4)
+        lr_ratios["levels.2.downsample"] = lr_ratios['levels.3.blocks.0.']
+        lr_ratios["levels.2.norm"] = lr_ratios['levels.3.blocks.0.']
+        return lr_ratios
+
+    def forward_features_no_clip_projector(self, x):
+        x = self.patch_embed(x)
+        N, H, W, C = x.shape
+        x = x.view(N, H*W, C)
+
+        shape = (H, W)
+        for level_idx, level in enumerate(self.levels):
+            x, shape = level(x, shape=shape)
+        h, w = shape
+        x = x.view(N, h, w, -1)
+        x = self.conv_head(x.permute(0, 3, 1, 2))
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        return x
+
+    def forward_features_seq_out(self, x):  # for detection or segmentation
+        x = self.patch_embed(x)
+        N, H, W, C = x.shape
+        x = x.view(N, H*W, C)
+        shape = (H, W)
+        seq_out = []
+        for level_idx, level in enumerate(self.levels):
+            old_shape = shape
+            x, x_, shape = level(x,
+                                 return_wo_downsample=True, shape=shape, level_idx=level_idx)
+            h, w = old_shape
+            # keep per-stage features channels-last (NHWC) here;
+            # forward_clip_projector permutes to NCHW where needed
+            seq_out.append(x_.reshape(N, h, w, -1))
+        return seq_out
+
+    def forward_clip_projector(self, x):  # for InternImage-H/G
+        xs = self.forward_features_seq_out(x)
+        x1, x2, x3, x4 = xs
+
+        x1 = x1.permute(0, 3, 1, 2)  # NHWC -> NCHW
+        x2 = x2.permute(0, 3, 1, 2)  # NHWC -> NCHW
+        x3 = x3.permute(0, 3, 1, 2)  # NHWC -> NCHW
+        x4 = x4.permute(0, 3, 1, 2)  # NHWC -> NCHW
+
+        x4 = self.dcnv3_head_x4(x4)
+        x = x4
+        x3 = self.dcnv3_head_x3(x3)
+        x = x + x3
+
+        x = x.flatten(-2).transpose(1, 2).contiguous()
+        x = self.clip_projector(x)
+        x = self.fc_norm(x)
+
+        return x
+
+    def forward_features(self, x):
+        if self.use_clip_projector:  # for InternImage-H/G
+            x = self.forward_clip_projector(x)
+        else:  # for InternImage-T/S/B/L/XL
+            x = self.forward_features_no_clip_projector(x)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """ process different state_dict formats from different pretrained models """
+    if 'model' in state_dict:
+        _state_dict = state_dict['model']
+    elif 'state_dict' in state_dict:
+        _state_dict = state_dict['state_dict']
+    else:
+        raise ValueError('Unrecognized state_dict format')
+
+    state_dict = OrderedDict()
+    for k, v in _state_dict.items():
+        if k.startswith('backbone.'):
+            k = k[9:]
+        state_dict[k] = v
+
+    if list(state_dict.keys())[0].startswith('module.'):
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+    return state_dict
+
+
+def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
+    return {
+        'url': url,
+        'num_classes': 1000,
+        'input_size': (3, 224, 224),
+        'pool_size': None,
+        'crop_pct': 0.9,
+        'interpolation': 'bicubic',
+        'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN,
+        'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.conv1',
+        'classifier': 'head',
+        **kwargs,
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    'flash_intern_image_tiny.224_in1k': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_t_1k_224.pth',
+        hf_hub_id='timm/',
+    ),
+    'flash_intern_image_small.224_in1k': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_s_1k_224.pth',
+        hf_hub_id='timm/',
+    ),
+    'flash_intern_image_base.224_in1k': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_b_1k_224.pth',
+        hf_hub_id='timm/',
+    ),
+    'flash_intern_image_large.384_in22k_ft_1k': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_l_22kto1k_384.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384),
+    ),
+    'flash_intern_image_large.384_in22k': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_l_22k_384.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384),
+        num_classes=21841,
+    ),
+    'cascade_flash_intern_image_large.fpn_1x_coco': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/cascade_flash_internimage_l_fpn_1x_coco.pth',
+        hf_hub_id='timm/',
+    ),
+    'cascade_flash_intern_image_large.fpn_3x_coco': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/cascade_flash_internimage_l_fpn_3x_coco.pth',
+        hf_hub_id='timm/',
+    ),
+    'dino_4scale_flash_intern_image_tiny.1x_coco': _cfg(
+        url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_t_1x_coco.pth',
+        hf_hub_id='timm/',
+    ),
+    'dino_4scale_flash_intern_image_small.1x_coco': _cfg(
url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_s_1x_coco.pth', + hf_hub_id='timm/', + ), + 'dino_4scale_flash_intern_image_base.1x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_b_1x_coco.pth', + hf_hub_id='timm/', + ), + 'dino_4scale_flash_intern_image_large.1x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_l_1x_coco.pth', + hf_hub_id='timm/', + img_size=(3, 384, 384), + num_classes=21841, + ), + 'mask_rcnn_flash_intern_image_tiny.fpn_1x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_t_fpn_1x_coco.pth', + hf_hub_id='timm/', + ), + 'mask_rcnn_flash_intern_image_tiny.fpn_3x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_s_fpn_3x_coco.pth', + hf_hub_id='timm/', + ), + 'mask_rcnn_flash_intern_image_small.fpn_1x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_s_fpn_1x_coco.pth', + hf_hub_id='timm/', + ), + 'mask_rcnn_flash_intern_image_small.fpn_3x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_s_fpn_3x_coco.pth', + hf_hub_id='timm/', + ), + 'mask_rcnn_flash_intern_image_base.fpn_1x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_b_fpn_1x_coco.pth', + hf_hub_id='timm/', + ), + 'mask_rcnn_flash_intern_image_base.fpn_3x_coco': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_b_fpn_3x_coco.pth', + hf_hub_id='timm/', + ), + 'mask2former_flash_intern_image_tiny.512_160k_ade20k_ss': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_t_512_160k_ade20k_ss.pth', + hf_hub_id='timm/', + img_size=(3, 512, 512), + ), + 'mask2former_flash_intern_image_small.640_160k_ade20k_ss': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_s_640_160k_ade20k_ss.pth', + hf_hub_id='timm/', + img_size=(3, 640, 640), + ), + 'mask2former_flash_intern_image_base.640_160k_ade20k_ss': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_b_640_160k_ade20k_ss.pth', + hf_hub_id='timm/', + img_size=(3, 640, 640), + ), + 'mask2former_flash_intern_image_large.640_160k_ade20k_ss': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_l_640_160k_ade20k_ss.pth', + hf_hub_id='timm/', + img_size=(3, 640, 640), + ), + 'upernet_flash_intern_image_tiny.512_160k_ade20k': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_t_512_160k_ade20k.pth', + hf_hub_id='timm/', + img_size=(3, 512, 512), + ), + 'upernet_flash_intern_image_small.512_160k_ade20k': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_s_512_160k_ade20k.pth', + hf_hub_id='timm/', + img_size=(3, 512, 512), + ), + 'upernet_flash_intern_image_base.512_160k_ade20k': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_b_512_160k_ade20k.pth', + hf_hub_id='timm/', + img_size=(3, 512, 512), + ), + 'upernet_flash_intern_image_large.640_160k_ade20k': _cfg( + url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_l_640_160k_ade20k.pth', + hf_hub_id='timm/', + img_size=(3, 640, 640), + ), +}) + + +def _create_flash_intern_image(variant: str, pretrained: bool = False, **kwargs): + 
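# out_indices is consumed by the feature-extraction wrapper (feature_cfg), not
+    # by FlashInternImage itself; e.g. out_indices=(1, 2, 3) exposes the last
+    # three stages when the model is built with features_only=True
+    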
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 1, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + return build_model_with_cfg( + FlashInternImage, + variant, + pretrained=pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs + ) + +@register_model +def flash_intern_image_tiny(pretrained=False, **kwarg): + """ + FlashInternImage-T, trained on ImageNet-1k, for classification. + """ + pretrained = pretrained and dcn_version == 'DCNv4' + model_arg = dict( + core_op='DCNv4', + channels=64, + depths=[4, 4, 18, 4], + groups=[4, 8, 16, 32], + offset_scale=1., + mlp_ratio=4., + drop_path_rate=0.1, + ) + return _create_flash_intern_image('flash_intern_image_tiny', pretrained=pretrained, **dict(model_arg, **kwarg)) + + +@register_model +def flash_intern_image_small(pretrained=False, **kwarg): + """ + FlashInternImage-S, trained on ImageNet-1k, for classification. + """ + pretrained = pretrained and dcn_version == 'DCNv4' + model_arg = dict( + core_op='DCNv4', + channels=80, + depths=[4, 4, 21, 4], + groups=[5, 10, 20, 40], + layer_scale=1e-5, + offset_scale=1., + mlp_ratio=4., + drop_path_rate=0.4, + post_norm=True, + dw_kernel_size=3, + ) + return _create_flash_intern_image('flash_intern_image_small', pretrained=pretrained, **dict(model_arg, **kwarg)) + + +@register_model +def flash_intern_image_base(pretrained=False, **kwarg): + """ + FlashInternImage-B, trained on ImageNet-1k, for classification. + """ + pretrained = pretrained and dcn_version == 'DCNv4' + model_arg = dict( + core_op='DCNv4', + channels=112, + depths=[4, 4, 21, 4], + groups=[7, 14, 28, 56], + layer_scale=1e-5, + offset_scale=0.5, + mlp_ratio=4., + drop_path_rate=0.5, + post_norm=True, + dw_kernel_size=3, + ) + return _create_flash_intern_image('flash_intern_image_base', pretrained=pretrained, **dict(model_arg, **kwarg)) + + +@register_model +def flash_intern_image_large(pretrained=False, **kwarg): + """ + FlashInternImage-L, trained on ImageNet-1k, for classification. + """ + pretrained = pretrained and dcn_version == 'DCNv4' + model_arg = dict( + core_op='DCNv4', + channels=160, + depths=[5, 5, 22, 5], + groups=[10, 20, 40, 80], + layer_scale=1e-5, + offset_scale=2., + mlp_ratio=4., + drop_path_rate=0.1, + post_norm=True, + dw_kernel_size=3, + dcn_output_bias=True, + mlp_fc2_bias=True, + ) + return _create_flash_intern_image('flash_intern_image_large', pretrained=pretrained, **dict(model_arg, **kwarg)) + + +@register_model +def cascade_flash_intern_image_large(pretrained=False, **kwargs): + """ + CascadeFlashInternImage-L, trained on COCO, used as backbone for detection. + """ + model_arg = dict( + core_op='DCNv4', + channels=160, + depths=[5, 5, 22, 5], + groups=[10, 20, 40, 80], + layer_scale=1., + offset_scale=2., + mlp_ratio=4., + drop_path_rate=0.4, + post_norm=True, + dw_kernel_size=3, + dcn_output_bias=True, + mlp_fc2_bias=True, + out_indices=(0, 1, 2, 3), + ) + return _create_flash_intern_image('cascade_flash_intern_image_large', pretrained=pretrained, **dict(model_arg, **kwargs)) + + +@register_model +def dino_4scale_flash_intern_image_tiny(pretrained=False, **kwargs): + """ + FlashInternImage-T, trained on ImageNet-1K, used as backbone for detection. 
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=64,
+        depths=[4, 4, 18, 4],
+        groups=[4, 8, 16, 32],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.2,
+        post_norm=False,
+        with_cp=True,
+        out_indices=(1, 2, 3),
+    )
+    return _create_flash_intern_image('dino_4scale_flash_intern_image_tiny', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def dino_4scale_flash_intern_image_small(pretrained=False, **kwargs):
+    """
+    FlashInternImage-S, trained on ImageNet-1K, used as backbone for detection.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=80,
+        depths=[4, 4, 21, 4],
+        groups=[5, 10, 20, 40],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.3,
+        post_norm=True,
+        with_cp=True,
+        dw_kernel_size=3,
+        out_indices=(1, 2, 3),
+    )
+    return _create_flash_intern_image('dino_4scale_flash_intern_image_small', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def dino_4scale_flash_intern_image_base(pretrained=False, **kwargs):
+    """
+    FlashInternImage-B, trained on ImageNet-1K, used as backbone for detection.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=112,
+        depths=[4, 4, 21, 4],
+        groups=[7, 14, 28, 56],
+        offset_scale=0.5,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.3,
+        post_norm=True,
+        with_cp=True,
+        dw_kernel_size=3,
+        out_indices=(1, 2, 3),
+    )
+    return _create_flash_intern_image('dino_4scale_flash_intern_image_base', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def dino_4scale_flash_intern_image_large(pretrained=False, **kwargs):
+    """
+    FlashInternImage-L, trained on ImageNet-22K, used as backbone for detection.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=160,
+        depths=[5, 5, 22, 5],
+        groups=[10, 20, 40, 80],
+        offset_scale=2.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.4,
+        post_norm=True,
+        with_cp=True,
+        dw_kernel_size=3,
+        dcn_output_bias=True,
+        mlp_fc2_bias=True,
+        out_indices=(1, 2, 3),
+    )
+    return _create_flash_intern_image('dino_4scale_flash_intern_image_large', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def mask_rcnn_flash_intern_image_tiny(pretrained=False, **kwargs):
+    """
+    FlashInternImage-T, trained on COCO, used as backbone for detection.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=64,
+        depths=[4, 4, 18, 4],
+        groups=[4, 8, 16, 32],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.2,
+        post_norm=False,
+        with_cp=True,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('mask_rcnn_flash_intern_image_tiny', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def mask_rcnn_flash_intern_image_small(pretrained=False, **kwargs):
+    """
+    FlashInternImage-S, trained on COCO, used as backbone for detection.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=80,
+        depths=[4, 4, 21, 4],
+        groups=[5, 10, 20, 40],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.3,
+        post_norm=True,
+        with_cp=True,
+        dw_kernel_size=3,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('mask_rcnn_flash_intern_image_small', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def mask_rcnn_flash_intern_image_base(pretrained=False, **kwargs):
+    """
+    FlashInternImage-B, trained on COCO, used as backbone for detection.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=112,
+        depths=[4, 4, 21, 4],
+        groups=[7, 14, 28, 56],
+        offset_scale=0.5,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.3,
+        post_norm=True,
+        with_cp=True,
+        dw_kernel_size=3,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('mask_rcnn_flash_intern_image_base', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def mask2former_flash_intern_image_tiny(pretrained=False, **kwargs):
+    """
+    FlashInternImage-T, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=64,
+        depths=[4, 4, 18, 4],
+        groups=[4, 8, 16, 32],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.2,
+        post_norm=False,
+        with_cp=False,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('mask2former_flash_intern_image_tiny', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def mask2former_flash_intern_image_small(pretrained=False, **kwargs):
+    """
+    FlashInternImage-S, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=80,
+        depths=[4, 4, 21, 4],
+        groups=[5, 10, 20, 40],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.3,
+        post_norm=True,
+        with_cp=False,
+        dw_kernel_size=3,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('mask2former_flash_intern_image_small', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def mask2former_flash_intern_image_base(pretrained=False, **kwargs):
+    """
+    FlashInternImage-B, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=112,
+        depths=[4, 4, 21, 4],
+        groups=[7, 14, 28, 56],
+        offset_scale=0.5,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.4,
+        post_norm=True,
+        with_cp=False,
+        dw_kernel_size=3,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('mask2former_flash_intern_image_base', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def mask2former_flash_intern_image_large(pretrained=False, **kwargs):
+    """
+    FlashInternImage-L, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=160,
+        depths=[5, 5, 22, 5],
+        groups=[10, 20, 40, 80],
+        offset_scale=2.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.5,
+        post_norm=True,
+        with_cp=True,
+        dw_kernel_size=3,
+        dcn_output_bias=True,
+        mlp_fc2_bias=True,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('mask2former_flash_intern_image_large', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def upernet_flash_intern_image_tiny(pretrained=False, **kwargs):
+    """
+    FlashInternImage-T, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=64,
+        depths=[4, 4, 18, 4],
+        groups=[4, 8, 16, 32],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.2,
+        post_norm=False,
+        with_cp=True,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('upernet_flash_intern_image_tiny', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def upernet_flash_intern_image_small(pretrained=False, **kwargs):
+    """
+    FlashInternImage-S, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=80,
+        depths=[4, 4, 21, 4],
+        groups=[5, 10, 20, 40],
+        offset_scale=1.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.3,
+        post_norm=True,
+        with_cp=True,
+        dw_kernel_size=3,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('upernet_flash_intern_image_small', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def upernet_flash_intern_image_base(pretrained=False, **kwargs):
+    """
+    FlashInternImage-B, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=112,
+        depths=[4, 4, 21, 4],
+        groups=[7, 14, 28, 56],
+        offset_scale=0.5,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.3,
+        post_norm=True,
+        with_cp=False,
+        dw_kernel_size=3,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('upernet_flash_intern_image_base', pretrained=pretrained, **dict(model_arg, **kwargs))
+
+
+@register_model
+def upernet_flash_intern_image_large(pretrained=False, **kwargs):
+    """
+    FlashInternImage-L, trained on ADE20K, used as backbone for segmentation.
+    """
+    model_arg = dict(
+        core_op='DCNv4',
+        channels=160,
+        depths=[5, 5, 22, 5],
+        groups=[10, 20, 40, 80],
+        offset_scale=2.,
+        layer_scale=1.,
+        mlp_ratio=4.,
+        drop_path_rate=0.4,
+        post_norm=True,
+        with_cp=False,
+        dw_kernel_size=3,
+        dcn_output_bias=True,
+        mlp_fc2_bias=True,
+        out_indices=(0, 1, 2, 3),
+    )
+    return _create_flash_intern_image('upernet_flash_intern_image_large', pretrained=pretrained, **dict(model_arg, **kwargs))

From 5dc0fece92dc253b3e07e56fe329b5cd18b20446 Mon Sep 17 00:00:00 2001
From: Pig
Date: Thu, 2 May 2024 02:26:26 +0800
Subject: [PATCH 02/13] Fix bugs of implementation of FlashInternImage

---
 timm/models/__init__.py           |  1 +
 timm/models/flash_intern_image.py | 79 ++++++++++++++++++++++---------
 2 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 9d09efac92..84e1c72561 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -21,6 +21,7 @@
 from .efficientvit_msra import *
 from .eva import *
 from .fastvit import *
+from .flash_intern_image import *
 from .focalnet import *
 from .gcvit import *
 from .ghostnet import *
diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py
index a3f6d980e3..a5b299ba67 100644
--- a/timm/models/flash_intern_image.py
+++ b/timm/models/flash_intern_image.py
@@ -27,16 +27,19 @@
 from ._manipulate import checkpoint_seq
 from typing import Dict, Any
 import warnings
+import logging
 
 __all__ = ['FlashInternImage']
 
+_logger = logging.getLogger(__name__)
+
 dcn_version = 'DCNv4'
 try:
     import DCNv4
 except ImportError:
     dcn_version = 'DCNv3'
-    warnings.warn('FlashInternImage requires DCNv4, but not found in current enviroment.\
-    By default using DCNv3 pure pytorch implementation instead, which will affect the performance.\
-    Suggesting install DCNv4 by `pip install DCNv4`')
+    warnings.warn('FlashInternImage requires DCNv4, but not found in current enviroment.\n\
+    By default using DCNv3 pure pytorch implementation instead, which will affect the performance.\n\
+    Suggesting install DCNv4 by `pip install DCNv4`')
 
 
 class to_channels_first(nn.Module):
@@ -897,6 +900,11 @@ def __init__(self,
                  out_indices=(0, 1, 2, 3),
                  **kwargs):
         super().__init__()
+        if dcn_version == 'DCNv4' and core_op == 'DCNv4':
+            core_op = 'DCNv4'
+        else:
+            warnings.warn('DCNv4 is not installed, use DCNv3 instead')
+            core_op = 'DCNv3'
         self.core_op = core_op
         self.num_classes = num_classes
         self.num_levels = len(depths)
@@ -908,13 +916,13 @@ def 
__init__(self, self.use_clip_projector = use_clip_projector self.level2_post_norm_block_ids = level2_post_norm_block_ids self.out_indices = out_indices - print(f'using core type: {core_op}') - print(f'using activation layer: {act_layer}') - print(f'using main norm layer: {norm_layer}') - print(f'using dpr: {drop_path_type}, {drop_path_rate}') - print(f"level2_post_norm: {level2_post_norm}") - print(f"level2_post_norm_block_ids: {level2_post_norm_block_ids}") - print(f"res_post_norm: {res_post_norm}") + _logger.info(f'use core type: {core_op}') + _logger.info(f'using activation layer: {act_layer}') + _logger.info(f'using main norm layer: {norm_layer}') + _logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}') + _logger.info(f'level2_post_norm: {level2_post_norm}') + _logger.info(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}') + _logger.info(f'res_post_norm: {res_post_norm}') in_chans = 3 self.patch_embed = StemLayer(in_chans=in_chans, @@ -1008,7 +1016,7 @@ def _init_weights(self, m): nn.init.constant_(m.weight, 1.0) def _init_deform_weights(self, m): - if isinstance(m, getattr(DCNv4, self.core_op)): + if dcn_version == 'DCNv4' and isinstance(m, getattr(DCNv4, self.core_op)): m._reset_parameters() elif isinstance(m, DCNv3_pytorch): m._reset_parameters() @@ -1213,7 +1221,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'dino_4scale_flash_intern_image_large.1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_l_1x_coco.pth', hf_hub_id='timm/', - img_size=(3, 384, 384), + input_size=(3, 384, 384), num_classes=21841, ), 'mask_rcnn_flash_intern_image_tiny.fpn_1x_coco': _cfg( @@ -1243,42 +1251,42 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'mask2former_flash_intern_image_tiny.512_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_t_512_160k_ade20k_ss.pth', hf_hub_id='timm/', - img_size=(3, 512, 512), + input_size=(3, 512, 512), ), 'mask2former_flash_intern_image_small.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_s_640_160k_ade20k_ss.pth', hf_hub_id='timm/', - img_size=(3, 640, 640), + input_size=(3, 640, 640), ), 'mask2former_flash_intern_image_base.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_b_640_160k_ade20k_ss.pth', hf_hub_id='timm/', - img_size=(3, 640, 640), + input_size=(3, 640, 640), ), 'mask2former_flash_intern_image_large.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_l_640_160k_ade20k_ss.pth', hf_hub_id='timm/', - img_size=(3, 640, 640), + input_size=(3, 640, 640), ), 'upernet_flash_intern_image_tiny.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_t_512_160k_ade20k.pth', hf_hub_id='timm/', - img_size=(3, 512, 512), + input_size=(3, 512, 512), ), 'upernet_flash_intern_image_small.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_s_512_160k_ade20k.pth', hf_hub_id='timm/', - img_size=(3, 512, 512), + input_size=(3, 512, 512), ), 'upernet_flash_intern_image_base.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_b_512_160k_ade20k.pth', hf_hub_id='timm/', - img_size=(3, 512, 512), + input_size=(3, 512, 512), ), 'upernet_flash_intern_image_large.640_160k_ade20k': _cfg( 
url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_l_640_160k_ade20k.pth', hf_hub_id='timm/', - img_size=(3, 640, 640), + input_size=(3, 640, 640), ), }) @@ -1295,12 +1303,21 @@ def _create_flash_intern_image(variant: str, pretrained: bool = False, **kwargs) **kwargs ) + +def _check_pretrained_available(pretrained: bool): + if dcn_version == 'DCNv4': + return pretrained + + warnings.warn('DCNv4 is not installed, cannot load pretrained weights') + return False + + @register_model def flash_intern_image_tiny(pretrained=False, **kwarg): """ FlashInternImage-T, trained on ImageNet-1k, for classification. """ - pretrained = pretrained and dcn_version == 'DCNv4' + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=64, @@ -1318,7 +1335,7 @@ def flash_intern_image_small(pretrained=False, **kwarg): """ FlashInternImage-S, trained on ImageNet-1k, for classification. """ - pretrained = pretrained and dcn_version == 'DCNv4' + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=80, @@ -1339,7 +1356,7 @@ def flash_intern_image_base(pretrained=False, **kwarg): """ FlashInternImage-B, trained on ImageNet-1k, for classification. """ - pretrained = pretrained and dcn_version == 'DCNv4' + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=112, @@ -1360,7 +1377,7 @@ def flash_intern_image_large(pretrained=False, **kwarg): """ FlashInternImage-L, trained on ImageNet-1k, for classification. """ - pretrained = pretrained and dcn_version == 'DCNv4' + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=160, @@ -1383,6 +1400,7 @@ def cascade_flash_intern_image_large(pretrained=False, **kwargs): """ CascadeFlashInternImage-L, trained on COCO, used as backbone for detection. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=160, @@ -1406,6 +1424,7 @@ def dino_4scale_flash_intern_image_tiny(pretrained=False, **kwargs): """ FlashInternImage-T, trained on ImageNet-1K, used as backbone for detection. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=64, @@ -1427,6 +1446,7 @@ def dino_4scale_flash_intern_image_small(pretrained=False, **kwargs): """ FlashInternImage-S, trained on ImageNet-1K, used as backbone for detection. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=80, @@ -1449,6 +1469,7 @@ def dino_4scale_flash_intern_image_base(pretrained=False, **kwargs): """ FlashInternImage-B, trained on ImageNet-1K, used as backbone for detection. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=112, @@ -1471,6 +1492,7 @@ def dino_4scale_flash_intern_image_large(pretrained=False, **kwargs): """ FlashInternImage-L, trained on ImageNet-22K, used as backbone for detection. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=160, @@ -1495,6 +1517,7 @@ def mask_rcnn_flash_intern_image_tiny(pretrained=False, **kwargs): """ FlashInternImage-T, trained on COCO, used as backbone for detection. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=64, @@ -1516,6 +1539,7 @@ def mask_rcnn_flash_intern_image_small(pretrained=False, **kwargs): """ FlashInternImage-S, trained on COCO, used as backbone for detection. 
""" + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=80, @@ -1538,6 +1562,7 @@ def mask_rcnn_flash_intern_image_base(pretrained=False, **kwargs): """ FlashInternImage-B, trained on COCO, used as backbone for detection. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=112, @@ -1560,6 +1585,7 @@ def mask2former_flash_intern_image_tiny(pretrained=False, **kwargs): """ FlashInternImage-T, trained on ADE20K, used as backbone for segmentation. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=64, @@ -1581,6 +1607,7 @@ def mask2former_flash_intern_image_small(pretrained=False, **kwargs): """ FlashInternImage-S, trained on ADE20K, used as backbone for segmentation. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=80, @@ -1603,6 +1630,7 @@ def mask2former_flash_intern_image_base(pretrained=False, **kwargs): """ FlashInternImage-B, trained on ADE20K, used as backbone for segmentation. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=112, @@ -1625,6 +1653,7 @@ def mask2former_flash_intern_image_large(pretrained=False, **kwargs): """ FlashInternImage-L, trained on ADE20K, used as backbone for segmentation. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=160, @@ -1649,6 +1678,7 @@ def upernet_flash_intern_image_tiny(pretrained=False, **kwargs): """ FlashInternImage-T, trained on ADE20K, used as backbone for segmentation. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=64, @@ -1670,6 +1700,7 @@ def upernet_flash_intern_image_small(pretrained=False, **kwargs): """ FlashInternImage-S, trained on ADE20K, used as backbone for segmentation. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=80, @@ -1692,6 +1723,7 @@ def upernet_flash_intern_image_base(pretrained=False, **kwargs): """ FlashInternImage-B, trained on ADE20K, used as backbone for segmentation. """ + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=112, @@ -1714,6 +1746,7 @@ def upernet_flash_intern_image_large(pretrained=False, **kwargs): """ FlashInternImage-L, trained on ADE20K, used as backbone for segmentation. 
""" + pretrained = _check_pretrained_available(pretrained) model_arg = dict( core_op='DCNv4', channels=160, From 5c2bdb54fb416c81c90b7dc8177586be1fe64099 Mon Sep 17 00:00:00 2001 From: Pig Date: Thu, 2 May 2024 16:10:17 +0800 Subject: [PATCH 03/13] Fix bugs of pretrained config and weight loaded of FlashInternImage --- timm/models/flash_intern_image.py | 84 ++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 29 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index a5b299ba67..b3baba5079 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -38,9 +38,6 @@ import DCNv4 except ImportError: dcn_version = 'DCNv3' - warnings.warn('FlashInternImage requires DCNv4, but not found in current enviroment.\n\ - By default using DCNv3 pure pytorch implementation instead, which will affect the performance.\n\ - Suggesting install DCNv4 by `pip install DCNv4`') class to_channels_first(nn.Module): @@ -782,6 +779,7 @@ def __init__(self, self.depth = depth self.post_norm = post_norm self.center_feature_scale = center_feature_scale + self.grad_checkpoint = False self.blocks = nn.ModuleList([ InternImageLayer( @@ -903,7 +901,9 @@ def __init__(self, if dcn_version == 'DCNv4' and core_op == 'DCNv4': core_op = 'DCNv4' else: - warnings.warn('DCNv4 is not installed, use DCNv3 instead') + warnings.warn('FlashInternImage requires DCNv4, but not found in current enviroment.\n\ + By default using DCNv3 pure pytorch implementation instead, which will affect the performance.\n\ + Suggesting install DCNv4 by `pip install DCNv4`') core_op = 'DCNv3' self.core_op = core_op self.num_classes = num_classes @@ -1177,115 +1177,140 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: default_cfgs = generate_default_cfgs({ 'flash_intern_image_tiny.224_in1k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_t_1k_224.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='flash_intern_image_t_1k_224.pth' ), 'flash_intern_image_small.224_in1k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_s_1k_224.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='flash_intern_image_s_1k_224.pth' ), 'flash_intern_image_base.224_in1k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_b_1k_224.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='flash_intern_image_b_1k_224.pth' ), 'flash_intern_image_large.384_in22k_ft_1k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_l_22kto1k_384.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='flash_intern_image_l_22kto1k_384.pth', input_size=(3, 384, 384), ), 'flash_intern_image_large.384_in22k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_l_22k_384.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='flash_intern_image_l_22k_384.pth', input_size=(3, 384, 384), num_classes=21841, ), 'cascade_flash_intern_image_large.fpn_1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/cascade_flash_internimage_l_fpn_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='cascade_flash_internimage_l_fpn_1x_coco.pth', ), 'cascade_flash_intern_image_large.fpn_3x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/cascade_flash_internimage_l_fpn_3x_coco.pth', - hf_hub_id='timm/', + 
hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='cascade_flash_internimage_l_fpn_3x_coco.pth', ), 'dino_4scale_flash_intern_image_tiny.1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_t_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='dino_4scale_flash_internimage_t_1x_coco.pth', ), 'dino_4scale_flash_intern_image_small.1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_s_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='dino_4scale_flash_internimage_s_1x_coco.pth', ), 'dino_4scale_flash_intern_image_base.1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_b_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='dino_4scale_flash_internimage_b_1x_coco.pth', ), 'dino_4scale_flash_intern_image_large.1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/dino_4scale_flash_internimage_l_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='dino_4scale_flash_internimage_l_1x_coco.pth', input_size=(3, 384, 384), num_classes=21841, ), 'mask_rcnn_flash_intern_image_tiny.fpn_1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_t_fpn_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask_rcnn_flash_internimage_t_fpn_1x_coco.pth', ), 'mask_rcnn_flash_intern_image_tiny.fpn_3x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_s_fpn_3x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask_rcnn_flash_internimage_s_fpn_3x_coco.pth', ), 'mask_rcnn_flash_intern_image_small.fpn_1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_s_fpn_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask_rcnn_flash_internimage_s_fpn_1x_coco.pth', ), 'mask_rcnn_flash_intern_image_small.fpn_3x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_s_fpn_3x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask_rcnn_flash_internimage_s_fpn_3x_coco.pth', ), 'mask_rcnn_flash_intern_image_base.fpn_1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_b_fpn_1x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask_rcnn_flash_internimage_b_fpn_1x_coco.pth', ), 'mask_rcnn_flash_intern_image_base.fpn_3x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask_rcnn_flash_internimage_b_fpn_3x_coco.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask_rcnn_flash_internimage_b_fpn_3x_coco.pth', ), 'mask2former_flash_intern_image_tiny.512_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_t_512_160k_ade20k_ss.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask2former_flash_internimage_t_512_160k_ade20k_ss.pth', input_size=(3, 512, 512), ), 'mask2former_flash_intern_image_small.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_s_640_160k_ade20k_ss.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask2former_flash_internimage_s_640_160k_ade20k_ss.pth', 
input_size=(3, 640, 640), ), 'mask2former_flash_intern_image_base.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_b_640_160k_ade20k_ss.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask2former_flash_internimage_b_640_160k_ade20k_ss.pth', input_size=(3, 640, 640), ), 'mask2former_flash_intern_image_large.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_l_640_160k_ade20k_ss.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='mask2former_flash_internimage_l_640_160k_ade20k_ss.pth', input_size=(3, 640, 640), ), 'upernet_flash_intern_image_tiny.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_t_512_160k_ade20k.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='upernet_flash_internimage_t_512_160k_ade20k.pth', input_size=(3, 512, 512), ), 'upernet_flash_intern_image_small.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_s_512_160k_ade20k.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='upernet_flash_internimage_s_512_160k_ade20k.pth', input_size=(3, 512, 512), ), 'upernet_flash_intern_image_base.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_b_512_160k_ade20k.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='upernet_flash_internimage_b_512_160k_ade20k.pth', input_size=(3, 512, 512), ), 'upernet_flash_intern_image_large.640_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_l_640_160k_ade20k.pth', - hf_hub_id='timm/', + hf_hub_id='OpenGVLab/DCNv4', + hf_hub_filename='upernet_flash_internimage_l_640_160k_ade20k.pth', input_size=(3, 640, 640), ), }) @@ -1299,6 +1324,7 @@ def _create_flash_intern_image(variant: str, pretrained: bool = False, **kwargs) variant, pretrained=pretrained, pretrained_filter_fn=checkpoint_filter_fn, + pretrained_strict=False, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs ) From 0590b1e15589d3d8115b267ba0c42b58afc9ca16 Mon Sep 17 00:00:00 2001 From: Pig Date: Fri, 3 May 2024 00:15:28 +0800 Subject: [PATCH 04/13] Fix bugs of forward and backward tests of FlashInternImage --- timm/models/flash_intern_image.py | 39 ++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index b3baba5079..5efb9929e9 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -303,17 +303,23 @@ def _reset_parameters(self): xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
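+    # NOTE (editorial sketch, not upstream code): with the change below, the pure
+    # PyTorch DCNv3 path accepts flattened tokens plus an explicit spatial shape:
+    #
+    #   dcn = DCNv3_pytorch(channels=64, group=4)   # constructor args as defined above
+    #   tokens = torch.randn(2, 56 * 56, 64)        # (N, L, C) with L = H * W
+    #   out = dcn(tokens, shape=(56, 56))           # -> (N, L, C)
+    #
+    # If `shape` is omitted, a square map (H = W = sqrt(L)) is assumed.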
- def forward(self, input): + def forward(self, input, shape=None): """ :param query (N, H, W, C) :return output (N, H, W, C) """ - N, H, W, _ = input.shape + # N, H, W, _ = input.shape + N, L, C = input.shape + if shape is not None: + H, W = shape + else: + H, W = int(L**0.5), int(L**0.5) - x = self.input_proj(input) + x = input.reshape(N, H, W, -1) + x = self.input_proj(x) x_proj = x - x1 = input.permute(0, 3, 1, 2) + x1 = input.reshape(N, H, W, -1).permute(0, 3, 1, 2) x1 = self.dw_conv(x1) offset = self.offset(x1) mask = self.mask(x1).reshape(N, H, W, self.group, -1) @@ -335,7 +341,7 @@ def forward(self, input): 1, 1, 1, 1, self.channels // self.group).flatten(-2) x = x * (1 - center_feature_scale) + x_proj * center_feature_scale x = self.output_proj(x) - + x = x.reshape(N, L, -1) return x # --- DCNv3 pure pytorch implementation finished --- # @@ -700,14 +706,14 @@ def forward(self, x, shape, level_idx=0): def _inner_forward(x, shape, level_idx): if not self.layer_scale: if self.post_norm: - x = x + self.drop_path(self.norm1(self.dcn(x, shape, level_idx))) + x = x + self.drop_path(self.norm1(self.dcn(x, shape))) x = x + self.drop_path(self.norm2(self.mlp(x, shape, level_idx))) elif self.res_post_norm: # for InternImage-H/G - x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x), shape, level_idx))) + x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x), shape))) x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x), shape, level_idx))) else: - x = x + self.drop_path(self.dcn(self.norm1(x), shape, level_idx)) + x = x + self.drop_path(self.dcn(self.norm1(x), shape)) x = x + self.drop_path(self.mlp(self.norm2(x), shape, level_idx)) return x if self.post_norm: @@ -1086,8 +1092,8 @@ def forward_features_no_clip_projector(self, x): h, w = shape x = x.view(N, h, w, -1) x = self.conv_head(x.permute(0, 3, 1, 2)) - x = self.avgpool(x) - x = torch.flatten(x, 1) + # x = self.avgpool(x) + # x = torch.flatten(x, 1) return x def forward_features_seq_out(self, x): # for detection or segmentation @@ -1117,9 +1123,9 @@ def forward_clip_projector(self, x): # for InternImage-H/G x3 = self.dcnv3_head_x3(x3) x = x + x3 - x = x.flatten(-2).transpose(1, 2).contiguous() - x = self.clip_projector(x) - x = self.fc_norm(x) + # x = x.flatten(-2).transpose(1, 2).contiguous() + # x = self.clip_projector(x) + # x = self.fc_norm(x) return x @@ -1132,6 +1138,13 @@ def forward_features(self, x): def forward(self, x): x = self.forward_features(x) + if self.use_clip_projector: + x = x.flatten(-2).transpose(1, 2).contiguous() + x = self.clip_projector(x) + x = self.fc_norm(x) + else: + x = self.avgpool(x) + x = torch.flatten(x, 1) x = self.head(x) return x From 0f26e031ad04f44b33e2f35ba8f05535c97a3905 Mon Sep 17 00:00:00 2001 From: Pig Date: Fri, 3 May 2024 01:46:46 +0800 Subject: [PATCH 05/13] Fix bugs of default_cfgs and forward_features tests of FlashInternImage --- timm/models/flash_intern_image.py | 311 ++++++++++++++++++++++-------- 1 file changed, 235 insertions(+), 76 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index 5efb9929e9..c76da1b2fb 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -20,12 +20,13 @@ from collections import OrderedDict import torch.utils.checkpoint as checkpoint from timm.models.layers import trunc_normal_, DropPath +from timm.layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from ._registry import register_model, 
generate_default_cfgs from ._builder import build_model_with_cfg import torch.nn.functional as F from ._manipulate import checkpoint_seq -from typing import Dict, Any +from typing import Dict, Any, Tuple, Optional, List import warnings import logging @@ -58,11 +59,13 @@ def forward(self, x): return x.permute(0, 2, 3, 1) -def build_norm_layer(dim, - norm_layer, - in_format='channels_last', - out_format='channels_last', - eps=1e-6): +def build_norm_layer( + dim, + norm_layer, + in_format='channels_last', + out_format='channels_last', + eps=1e-6 + ): layers = [] if norm_layer == 'BN': if in_format == 'channels_last': @@ -93,7 +96,18 @@ def build_act_layer(act_layer): raise NotImplementedError(f'build_act_layer does not support {act_layer}') -def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): +def _get_reference_points( + spatial_shapes: List[int], + device: Optional[torch.device], + kernel_h: int, + kernel_w: int, + dilation_h: int, + dilation_w: int, + pad_h: int=0, + pad_w: int=0, + stride_h: int=1, + stride_w: int=1 + ): _, H_, W_, _ = spatial_shapes H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 @@ -124,7 +138,15 @@ def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h return ref -def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): +def _generate_dilation_grids( + spatial_shapes: List[int], + kernel_h: int, + kernel_w: int, + dilation_h: int, + dilation_w: int, + group: int, + device: Optional[torch.device], + ): _, H_, W_, _ = spatial_shapes points_list = [] x, y = torch.meshgrid( @@ -150,10 +172,20 @@ def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dil def dcnv3_core_pytorch( - input, offset, mask, kernel_h, - kernel_w, stride_h, stride_w, pad_h, - pad_w, dilation_h, dilation_w, group, - group_channels, offset_scale): + input, + offset, + mask, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + group: int, + group_channels: int, + offset_scale: float): # for debug and test only, # need to use cuda version instead input = F.pad( @@ -285,6 +317,9 @@ def __init__( self.input_proj = nn.Linear(channels, channels) self.output_proj = nn.Linear(channels, channels) self._reset_parameters() + self.center_feature_scale_module = CenterFeatureScaleModule() + self.center_feature_scale_proj_weight = torch.zeros((group, channels), dtype=torch.float) + self.center_feature_scale_proj_bias = torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ) if center_feature_scale: self.center_feature_scale_proj_weight = nn.Parameter( @@ -303,17 +338,21 @@ def _reset_parameters(self): xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
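+    # NOTE (editorial): the Optional/Tuple annotations introduced in this patch are
+    # what make these modules scriptable; TorchScript types a bare `shape=None`
+    # default as NoneType, so a call with a tuple would fail without, e.g.:
+    #
+    #   def forward(self, x, shape: Optional[Tuple[int, int]] = None): ...
+    #
+    # For the same reason, attributes such as the center-feature-scale tensors
+    # above are now created on every path, so a scripted module always sees them.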
- def forward(self, input, shape=None): + def forward(self, input, shape: Optional[Tuple[int, int]]=None): """ :param query (N, H, W, C) :return output (N, H, W, C) """ # N, H, W, _ = input.shape - N, L, C = input.shape - if shape is not None: - H, W = shape + if len(input.shape) == 3: + N, L, C = input.shape + if shape is not None: + H, W = shape + else: + H, W = int(L**0.5), int(L**0.5) else: - H, W = int(L**0.5), int(L**0.5) + N, H, W, C = input.shape + L = H * W x = input.reshape(N, H, W, -1) x = self.input_proj(x) @@ -341,7 +380,8 @@ def forward(self, input, shape=None): 1, 1, 1, 1, self.channels // self.group).flatten(-2) x = x * (1 - center_feature_scale) + x_proj * center_feature_scale x = self.output_proj(x) - x = x.reshape(N, L, -1) + if len(input.shape) == 3: + x = x.reshape(N, L, -1) return x # --- DCNv3 pure pytorch implementation finished --- # @@ -570,14 +610,25 @@ def __init__(self, channels, norm_layer='LN'): 'channels_first', 'channels_first') - def forward(self, x, shape=None): - H, W = shape - N, HW, C = x.shape + def forward(self, x, shape: Optional[Tuple[int, int]]=None): + input_shape = len(x.shape) + if input_shape == 3: + N, HW, C = x.shape + if shape is not None: + H, W = shape + else: + H, W = int(HW**0.5), int(HW**0.5) + else: + N, H, W, C = x.shape + HW = H * W x = x.view(N, H, W, C) x = self.conv(x.permute(0, 3, 1, 2)) x = self.norm(x) # B C H W - H, W = x.size(2), x.size(3) - x = x.flatten(2).permute(0, 2, 1) + if input_shape == 3: + H, W = x.size(2), x.size(3) + x = x.flatten(2).permute(0, 2, 1) + else: + x = x.permute(0, 2, 3, 1) return x, (H, W) @@ -608,7 +659,7 @@ def __init__(self, self.drop = nn.Dropout(drop) - def forward(self, x, shape, level_idx=0): + def forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0): x = self.fc1(x) x = self.act(x) x = self.drop(x) @@ -692,42 +743,52 @@ def __init__(self, mlp_fc2_bias=mlp_fc2_bias ) self.layer_scale = layer_scale is not None + self.gamma1 = torch.ones(channels) + self.gamma2 = torch.ones(channels) if self.layer_scale: self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels), requires_grad=True) self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels), requires_grad=True) + self.res_post_norm = res_post_norm + self.res_post_norm1 = nn.Sequential() + self.res_post_norm2 = nn.Sequential() if res_post_norm: self.res_post_norm1 = build_norm_layer(channels, 'LN') self.res_post_norm2 = build_norm_layer(channels, 'LN') - def forward(self, x, shape, level_idx=0): - - def _inner_forward(x, shape, level_idx): - if not self.layer_scale: - if self.post_norm: - x = x + self.drop_path(self.norm1(self.dcn(x, shape))) - x = x + self.drop_path(self.norm2(self.mlp(x, shape, level_idx))) - elif self.res_post_norm: # for InternImage-H/G - x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x), shape))) - x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x), shape, level_idx))) - - else: - x = x + self.drop_path(self.dcn(self.norm1(x), shape)) - x = x + self.drop_path(self.mlp(self.norm2(x), shape, level_idx)) - return x + + def _inner_forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int): + if not self.layer_scale: if self.post_norm: - x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x, shape))) - x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x, shape, level_idx))) + x = x + self.drop_path(self.norm1(self.dcn(x, shape))) + x = x + self.drop_path(self.norm2(self.mlp(x, shape, level_idx))) + elif self.res_post_norm: # for InternImage-H/G + x = x + 
self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x), shape))) + x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x), shape, level_idx))) + else: - x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x), shape)) - x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x), shape, level_idx)) + x = x + self.drop_path(self.dcn(self.norm1(x), shape)) + x = x + self.drop_path(self.mlp(self.norm2(x), shape, level_idx)) return x + if self.post_norm: + x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x, shape))) + x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x, shape, level_idx))) + else: + x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x), shape)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x), shape, level_idx)) + return x + + @torch.jit.ignore + def forward_checkpoint(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0): + x = checkpoint.checkpoint(self._inner_forward, x, shape, level_idx) + return x + def forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0): if self.with_cp and x.requires_grad: - x = checkpoint.checkpoint(_inner_forward, x, shape, level_idx) + x = self.forward_checkpoint(x, shape, level_idx) else: - x = _inner_forward(x, shape, level_idx) + x = self._inner_forward(x, shape, level_idx) return x @@ -777,7 +838,7 @@ def __init__(self, dcn_output_bias=False, mlp_fc2_bias=False, dw_kernel_size=None, # for InternImage-H/G - post_norm_block_ids=None, # for InternImage-H/G + post_norm_block_ids: Optional[List[int]]=None, # for InternImage-H/G res_post_norm=False, # for InternImage-H/G center_feature_scale=False): # for InternImage-H/G super().__init__() @@ -811,35 +872,74 @@ def __init__(self, ]) if not self.post_norm or center_feature_scale: self.norm = build_norm_layer(channels, 'LN') - self.post_norm_block_ids = post_norm_block_ids + + self.if_post_norm = post_norm_block_ids is not None + self.post_norm_block_ids: List[int] = [0] + self.post_norms = nn.ModuleList() if post_norm_block_ids is not None: # for InternImage-H/G + self.post_norm_block_ids = post_norm_block_ids self.post_norms = nn.ModuleList( [build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids] ) self.downsample = downsample_layer( channels=channels, norm_layer=norm_layer) if downsample else None + @torch.jit.ignore + def _forward_post_norm(self, x, i: int): + index = self.post_norm_block_ids.index(i) + x = self.post_norms[index](x) # for InternImage-H/G + return x + + @torch.jit.ignore + def forward_return_wo_downsample(self, x, shape: Optional[Tuple[int, int]]=None, level_idx: int=0): + for i, blk in enumerate(self.blocks): + if self.grad_checkpoint and not torch.jit.is_scripting(): + x = checkpoint_seq(blk, x) + else: + x = blk(x, shape=shape, level_idx=level_idx) + if self.if_post_norm and (i in self.post_norm_block_ids): + self._forward_post_norm(x, i) + if not self.post_norm or self.center_feature_scale: + x = self.norm(x) + + x_ = x.clone() + + if self.downsample is not None: + x, shape = self.downsample(x, shape=shape) - def forward(self, x, return_wo_downsample=False, shape=None, level_idx=0 - ): + return x, x_, shape + + def forward_shape(self, x, shape: Tuple[int, int], level_idx: int=0): for i, blk in enumerate(self.blocks): if self.grad_checkpoint and not torch.jit.is_scripting(): x = checkpoint_seq(blk, x) else: x = blk(x, shape=shape, level_idx=level_idx) - if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids): - index = self.post_norm_block_ids.index(i) - x = 
self.post_norms[index](x) # for InternImage-H/G + if self.if_post_norm and (i in self.post_norm_block_ids): + self._forward_post_norm(x, i) if not self.post_norm or self.center_feature_scale: x = self.norm(x) - if return_wo_downsample: - x_ = x.clone() + if self.downsample is not None: x, shape = self.downsample(x, shape=shape) - if return_wo_downsample: - return x, x_, shape return x, shape + + def forward(self, x, shape: Optional[Tuple[int, int]]=None, level_idx: int=0): + for i, blk in enumerate(self.blocks): + if self.grad_checkpoint and not torch.jit.is_scripting(): + x = checkpoint_seq(blk, x) + else: + x = blk(x, shape=shape, level_idx=level_idx) + if self.if_post_norm and (i in self.post_norm_block_ids): + self._forward_post_norm(x, i) + if not self.post_norm or self.center_feature_scale: + x = self.norm(x) + + if self.downsample is not None: + x, shape = self.downsample(x, shape=shape) + + return x class FlashInternImage(nn.Module): @@ -868,6 +968,7 @@ class FlashInternImage(nn.Module): mlp_fc2_bias (bool): Whether to use mlp fc2 bias. Default: False dcn_output_bias (bool): Whether to use dcn output bias. Default: False dw_kernel_size (int): Size of the dwconv. Default: None + global_pool (str): Global pooling type. Default: 'avg' use_clip_projector (bool): Whether to use clip projector. Default: False level2_post_norm (bool): Whether to use level2 post norm. Default: False level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None @@ -896,6 +997,7 @@ def __init__(self, mlp_fc2_bias=False, dcn_output_bias=False, dw_kernel_size=None, + global_pool='avg', use_clip_projector=False, # for InternImage-H/G level2_post_norm=False, # for InternImage-H/G level2_post_norm_block_ids=None, # for InternImage-H/G @@ -904,7 +1006,7 @@ def __init__(self, out_indices=(0, 1, 2, 3), **kwargs): super().__init__() - if dcn_version == 'DCNv4' and core_op == 'DCNv4': + if dcn_version == 'DCNv4': core_op = 'DCNv4' else: warnings.warn('FlashInternImage requires DCNv4, but not found in current enviroment.\n\ @@ -919,9 +1021,12 @@ def __init__(self, self.num_features = int(channels * 2**(self.num_levels - 1)) self.post_norm = post_norm self.mlp_ratio = mlp_ratio + self.act_layer = act_layer self.use_clip_projector = use_clip_projector self.level2_post_norm_block_ids = level2_post_norm_block_ids self.out_indices = out_indices + self.output_fmt = 'NHWC' + self.feature_info = [] _logger.info(f'use core type: {core_op}') _logger.info(f'using activation layer: {act_layer}') _logger.info(f'using main norm layer: {norm_layer}') @@ -936,6 +1041,7 @@ def __init__(self, act_layer=act_layer, norm_layer=norm_layer) self.pos_drop = nn.Dropout(p=drop_rate) + self.feature_info.append(dict(num_chs=channels, reduction=2, module='patch_embed')) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) @@ -944,7 +1050,7 @@ def __init__(self, for i in range(len(dpr)): dpr[i] = drop_path_rate - self.levels = nn.ModuleList() + self.levels = nn.Sequential() for i in range(self.num_levels): post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and ( i == 2) else None # for InternImage-H/G @@ -972,7 +1078,13 @@ def __init__(self, res_post_norm=res_post_norm, # for InternImage-H/G center_feature_scale=center_feature_scale # for InternImage-H/G ) - self.levels.append(level) + self.levels.add_module(str(i), level) + if i < self.num_levels - 1: + self.feature_info.append( + dict(num_chs=int(channels * 2 ** (i + 1)), reduction=2 ** (i + 2), module=f'levels.{i}')) + else: + 
self.feature_info.append( + dict(num_chs=int(channels * 2 ** i), reduction=2 ** (i + 1), module=f'levels.{i}')) if not use_clip_projector: # for InternImage-T/S/B/L/XL self.conv_head = nn.Sequential( @@ -1007,7 +1119,10 @@ def __init__(self, self.head = nn.Linear( clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.pool_type = global_pool + self.global_pool = SelectAdaptivePool2d(output_size=(1, 1), pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool != '' else nn.Identity() self.num_layers = len(depths) self.apply(self._init_weights) self.apply(self._init_deform_weights) @@ -1027,6 +1142,7 @@ def _init_deform_weights(self, m): elif isinstance(m, DCNv3_pytorch): m._reset_parameters() + @torch.jit.ignore def init_weights(self): self.apply(self._init_weights) self.apply(self._init_deform_weights) @@ -1035,8 +1151,21 @@ def init_weights(self): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=None): + def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes + if num_classes == 0: + self.conv_head = nn.Sequential( + nn.Conv2d(self.num_features, + int(self.num_features), + kernel_size=1, + bias=False), + build_norm_layer(int(self.num_features), 'BN', + 'channels_first', 'channels_first'), + build_act_layer(self.act_layer)) + + self.global_pool = SelectAdaptivePool2d(output_size=(1, 1), pool_type=global_pool) + self.pool_type = global_pool + self.flatten = nn.Flatten(1) if global_pool != '' else nn.Identity() self.head = nn.Linear(self.num_features, num_classes) \ if num_classes > 0 else nn.Identity() @@ -1085,17 +1214,17 @@ def forward_features_no_clip_projector(self, x): x = x.view(N, H*W, C) shape=(H, W) - seq_out = [] for level_idx, level in enumerate(self.levels): - old_shape = shape - x, shape = level(x, shape=shape) + # old_shape = shape + x, shape = level.forward_shape(x, shape=shape) h, w = shape x = x.view(N, h, w, -1) - x = self.conv_head(x.permute(0, 3, 1, 2)) + # x = self.conv_head(x) # x = self.avgpool(x) # x = torch.flatten(x, 1) return x + @torch.jit.ignore def forward_features_seq_out(self, x): # for detection or segmentation x = self.patch_embed(x) N, H, W, C = x.shape @@ -1104,11 +1233,12 @@ def forward_features_seq_out(self, x): # for detection or segmentation seq_out = [] for level_idx, level in enumerate(self.levels): old_shape = shape - x, x_ , shape = level(x, return_wo_downsample=True, shape=shape, level_idx=level_idx) + x, x_ , shape = level.forward_return_wo_downsample(x, shape=shape, level_idx=level_idx) h, w= old_shape seq_out.append(x_.reshape(N, h, w, -1).permute(0, 3, 1, 2)) return seq_out + @torch.jit.ignore def forward_clip_projector(self, x): # for InternImage-H/G xs = self.forward_features_seq_out(x) x1, x2, x3, x4 = xs @@ -1128,7 +1258,7 @@ def forward_clip_projector(self, x): # for InternImage-H/G # x = self.fc_norm(x) return x - + def forward_features(self, x): if self.use_clip_projector: # for InternImage-H/G x = self.forward_clip_projector(x) @@ -1136,16 +1266,33 @@ def forward_features(self, x): x = self.forward_features_no_clip_projector(x) return x - def forward(self, x): - x = self.forward_features(x) + def forward_head_no_clip_projector(self, x): + x = self.conv_head(x.permute(0, 3, 1, 2)) + x = self.global_pool(x) + x = self.flatten(x) + if self.pool_type == '': + x = x.permute(0, 2, 3, 1) + x = self.head(x) + return x + + 
@torch.jit.ignore + def forward_head_clip_projector(self, x): + x = x.flatten(-2).transpose(1, 2).contiguous() + x = self.clip_projector(x) + x = self.fc_norm(x) + x = self.head(x) + return x + + def forward_head(self, x): if self.use_clip_projector: - x = x.flatten(-2).transpose(1, 2).contiguous() - x = self.clip_projector(x) - x = self.fc_norm(x) + x = self.forward_head_clip_projector(x) else: - x = self.avgpool(x) - x = torch.flatten(x, 1) - x = self.head(x) + x = self.forward_head_no_clip_projector(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x @@ -1175,7 +1322,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), - 'pool_size': None, + 'pool_size': (7, 7), 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True, @@ -1208,18 +1355,21 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='flash_intern_image_l_22kto1k_384.pth', input_size=(3, 384, 384), + pool_size=(12, 12), ), 'flash_intern_image_large.384_in22k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_l_22k_384.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='flash_intern_image_l_22k_384.pth', input_size=(3, 384, 384), + pool_size=(12, 12), num_classes=21841, ), 'cascade_flash_intern_image_large.fpn_1x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/cascade_flash_internimage_l_fpn_1x_coco.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='cascade_flash_internimage_l_fpn_1x_coco.pth', + ), 'cascade_flash_intern_image_large.fpn_3x_coco': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/cascade_flash_internimage_l_fpn_3x_coco.pth', @@ -1246,6 +1396,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='dino_4scale_flash_internimage_l_1x_coco.pth', input_size=(3, 384, 384), + pool_size=(12, 12), num_classes=21841, ), 'mask_rcnn_flash_intern_image_tiny.fpn_1x_coco': _cfg( @@ -1283,48 +1434,56 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='mask2former_flash_internimage_t_512_160k_ade20k_ss.pth', input_size=(3, 512, 512), + pool_size=(16, 16), ), 'mask2former_flash_intern_image_small.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_s_640_160k_ade20k_ss.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='mask2former_flash_internimage_s_640_160k_ade20k_ss.pth', input_size=(3, 640, 640), + pool_size=(20, 20), ), 'mask2former_flash_intern_image_base.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_b_640_160k_ade20k_ss.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='mask2former_flash_internimage_b_640_160k_ade20k_ss.pth', input_size=(3, 640, 640), + pool_size=(20, 20), ), 'mask2former_flash_intern_image_large.640_160k_ade20k_ss': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/mask2former_flash_internimage_l_640_160k_ade20k_ss.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='mask2former_flash_internimage_l_640_160k_ade20k_ss.pth', input_size=(3, 640, 640), + pool_size=(20, 20), ), 'upernet_flash_intern_image_tiny.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_t_512_160k_ade20k.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='upernet_flash_internimage_t_512_160k_ade20k.pth', input_size=(3, 
512, 512), + pool_size=(16, 16), ), 'upernet_flash_intern_image_small.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_s_512_160k_ade20k.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='upernet_flash_internimage_s_512_160k_ade20k.pth', input_size=(3, 512, 512), + pool_size=(16, 16), ), 'upernet_flash_intern_image_base.512_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_b_512_160k_ade20k.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='upernet_flash_internimage_b_512_160k_ade20k.pth', input_size=(3, 512, 512), + pool_size=(16, 16), ), 'upernet_flash_intern_image_large.640_160k_ade20k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/upernet_flash_internimage_l_640_160k_ade20k.pth', hf_hub_id='OpenGVLab/DCNv4', hf_hub_filename='upernet_flash_internimage_l_640_160k_ade20k.pth', input_size=(3, 640, 640), + pool_size=(20, 20), ), }) From 7c3cb3e2f017b39daf3041704cb19c151cc16b79 Mon Sep 17 00:00:00 2001 From: Pig Date: Fri, 3 May 2024 23:45:59 +0800 Subject: [PATCH 06/13] Fix bugs of torchscript test of FlashInternImage --- timm/models/flash_intern_image.py | 50 +++++++++++++------------------ 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index c76da1b2fb..8511a9b42e 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -33,6 +33,7 @@ __all__ = ['FlashInternImage'] _logger = logging.getLogger(__name__) +torch.fx.wrap('len') dcn_version = 'DCNv4' try: @@ -114,16 +115,12 @@ def _get_reference_points( ref_y, ref_x = torch.meshgrid( torch.linspace( - # pad_h + 0.5, - # H_ - pad_h - 0.5, (dilation_h * (kernel_h - 1)) // 2 + 0.5, (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, H_out, dtype=torch.float32, device=device), torch.linspace( - # pad_w + 0.5, - # W_ - pad_w - 0.5, (dilation_w * (kernel_w - 1)) // 2 + 0.5, (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, W_out, @@ -152,14 +149,14 @@ def _generate_dilation_grids( x, y = torch.meshgrid( torch.linspace( -((dilation_w * (kernel_w - 1)) // 2), - -((dilation_w * (kernel_w - 1)) // 2) + - (kernel_w - 1) * dilation_w, kernel_w, + -((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w, + kernel_w, dtype=torch.float32, device=device), torch.linspace( -((dilation_h * (kernel_h - 1)) // 2), - -((dilation_h * (kernel_h - 1)) // 2) + - (kernel_h - 1) * dilation_h, kernel_h, + -((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h, + kernel_h, dtype=torch.float32, device=device)) @@ -344,21 +341,13 @@ def forward(self, input, shape: Optional[Tuple[int, int]]=None): :return output (N, H, W, C) """ # N, H, W, _ = input.shape - if len(input.shape) == 3: - N, L, C = input.shape - if shape is not None: - H, W = shape - else: - H, W = int(L**0.5), int(L**0.5) - else: - N, H, W, C = input.shape - L = H * W + N, H, W, C = input.shape + L = H * W - x = input.reshape(N, H, W, -1) - x = self.input_proj(x) + x = self.input_proj(input) x_proj = x - x1 = input.reshape(N, H, W, -1).permute(0, 3, 1, 2) + x1 = input.permute(0, 3, 1, 2) x1 = self.dw_conv(x1) offset = self.offset(x1) mask = self.mask(x1).reshape(N, H, W, self.group, -1) @@ -380,8 +369,6 @@ def forward(self, input, shape: Optional[Tuple[int, int]]=None): 1, 1, 1, 1, self.channels // self.group).flatten(-2) x = x * (1 - center_feature_scale) + x_proj * center_feature_scale x = self.output_proj(x) - if 
len(input.shape) == 3:
-            x = x.reshape(N, L, -1)
         return x
 
 # --- DCNv3 pure pytorch implementation finished --- #
@@ -608,11 +595,11 @@ def __init__(self, channels, norm_layer='LN'):
                               bias=False)
         self.norm = build_norm_layer(2 * channels, norm_layer, 'channels_first',
                                      'channels_first')
+        self.dcn_version = dcn_version
 
     def forward(self, x, shape: Optional[Tuple[int, int]]=None):
-        input_shape = len(x.shape)
-        if input_shape == 3:
+        if self.dcn_version == 'DCNv4':
             N, HW, C = x.shape
             if shape is not None:
                 H, W = shape
@@ -624,11 +611,12 @@ def forward(self, x, shape: Optional[Tuple[int, int]]=None):
         x = x.view(N, H, W, C)
         x = self.conv(x.permute(0, 3, 1, 2))
         x = self.norm(x)  # B C H W
-        if input_shape == 3:
+        if self.dcn_version == 'DCNv4':
             H, W = x.size(2), x.size(3)
             x = x.flatten(2).permute(0, 2, 1)
         else:
             x = x.permute(0, 2, 3, 1)
+            H, W = x.size(1), x.size(2)
 
         return x, (H, W)
 
@@ -780,12 +768,12 @@ def _inner_forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int):
         return x
 
     @torch.jit.ignore
-    def forward_checkpoint(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0):
+    def forward_checkpoint(self, x, shape: Optional[Tuple[int, int]], level_idx: int = 0):
         x = checkpoint.checkpoint(self._inner_forward, x, shape, level_idx)
         return x
 
     def forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0):
-        if self.with_cp and x.requires_grad:
+        if self.with_cp:
            x = self.forward_checkpoint(x, shape, level_idx)
         else:
             x = self._inner_forward(x, shape, level_idx)
@@ -1006,6 +994,7 @@ def __init__(self,
                  out_indices=(0, 1, 2, 3),
                  **kwargs):
         super().__init__()
+        self.dcn_version = dcn_version
         if dcn_version == 'DCNv4':
             core_op = 'DCNv4'
         else:
@@ -1211,7 +1200,9 @@ def lr_decay_keywards(self, decay_ratio=0.87):
     def forward_features_no_clip_projector(self, x):
         x = self.patch_embed(x)
         N, H, W, C = x.shape
-        x = x.view(N, H*W, C)
+
+        if self.dcn_version == 'DCNv4':
+            x = x.view(N, H * W, C)
         shape=(H, W)
 
         for level_idx, level in enumerate(self.levels):
@@ -1228,7 +1219,8 @@ def forward_features_no_clip_projector(self, x):
     def forward_features_seq_out(self, x):  # for detection or segmentation
         x = self.patch_embed(x)
         N, H, W, C = x.shape
-        x = x.view(N, H*W, C)
+        if self.dcn_version == 'DCNv4':
+            x = x.view(N, H * W, C)
         shape=(H, W)
         seq_out = []
         for level_idx, level in enumerate(self.levels):

From b000daa463c66d095b4441e5fa4fdbe5327b60a7 Mon Sep 17 00:00:00 2001
From: Pig
Date: Sat, 4 May 2024 00:52:42 +0800
Subject: [PATCH 07/13] Pass tests except torch.fx related of FlashInternImage

---
 tests/test_models.py              | 2 +-
 timm/models/flash_intern_image.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 7f696dc128..791db73e69 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -484,7 +484,7 @@ def _create_fx_model(model, train=False):
     return fx_model
 
 
-EXCLUDE_FX_FILTERS = ['vit_gi*']
+EXCLUDE_FX_FILTERS = ['vit_gi*', '*flash_intern_image*']
 # not enough memory to run fx on more models than other tests
 if 'GITHUB_ACTIONS' in os.environ:
     EXCLUDE_FX_FILTERS += [
diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py
index 8511a9b42e..d8a027c4be 100644
--- a/timm/models/flash_intern_image.py
+++ b/timm/models/flash_intern_image.py
@@ -858,6 +858,8 @@ def __init__(self,
             center_feature_scale=center_feature_scale  # for InternImage-H/G
             ) for i in range(depth)
         ])
+
+        self.norm = nn.Sequential()
         if not self.post_norm or center_feature_scale:
             self.norm = build_norm_layer(channels, 'LN')
 
From 
35358631675dc9f930b4cd363e0d67bfc3479220 Mon Sep 17 00:00:00 2001 From: Pig Date: Mon, 13 May 2024 01:03:44 +0800 Subject: [PATCH 08/13] Optimized code implementation of FlashInternImage --- timm/models/flash_intern_image.py | 951 ++++++++++++++---------------- 1 file changed, 458 insertions(+), 493 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index d8a027c4be..3e1cc79a1e 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -2,7 +2,7 @@ A Pytorch Implementation of Flash Intern Image as decribed in: `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` - - https://arxiv.org/pdf/2103.14030 + - https://arxiv.org/pdf/2211.05778 `DCNv4` - https://arxiv.org/pdf/2401.06197 @@ -35,15 +35,14 @@ _logger = logging.getLogger(__name__) torch.fx.wrap('len') -dcn_version = 'DCNv4' +dcn_version = 'CUDA' try: import DCNv4 except ImportError: - dcn_version = 'DCNv3' + dcn_version = 'pytorch' class to_channels_first(nn.Module): - def __init__(self): super().__init__() @@ -52,7 +51,6 @@ def forward(self, x): class to_channels_last(nn.Module): - def __init__(self): super().__init__() @@ -83,6 +81,7 @@ def build_norm_layer( else: raise NotImplementedError( f'build_norm_layer does not support {norm_layer}') + return nn.Sequential(*layers) @@ -119,18 +118,20 @@ def _get_reference_points( (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, H_out, dtype=torch.float32, - device=device), + device=device + ), torch.linspace( (dilation_w * (kernel_w - 1)) // 2 + 0.5, (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, W_out, dtype=torch.float32, - device=device)) + device=device + ) + ) ref_y = ref_y.reshape(-1)[None] / H_ ref_x = ref_x.reshape(-1)[None] / W_ - ref = torch.stack((ref_x, ref_y), -1).reshape( - 1, H_out, W_out, 1, 2) + ref = torch.stack((ref_x, ref_y), -1).reshape(1, H_out, W_out, 1, 2) return ref @@ -152,13 +153,16 @@ def _generate_dilation_grids( -((dilation_w * (kernel_w - 1)) // 2) + (kernel_w - 1) * dilation_w, kernel_w, dtype=torch.float32, - device=device), + device=device + ), torch.linspace( -((dilation_h * (kernel_h - 1)) // 2), -((dilation_h * (kernel_h - 1)) // 2) + (kernel_h - 1) * dilation_h, kernel_h, dtype=torch.float32, - device=device)) + device=device + ) + ) points_list.extend([x / W_, y / H_]) grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ @@ -168,9 +172,9 @@ def _generate_dilation_grids( return grid -def dcnv3_core_pytorch( +def dcnv4_core_pytorch( input, - offset, + offset, mask, kernel_h: int, kernel_w: int, @@ -182,42 +186,76 @@ def dcnv3_core_pytorch( dilation_w: int, group: int, group_channels: int, - offset_scale: float): - # for debug and test only, - # need to use cuda version instead + offset_scale: float + ): input = F.pad( input, - [0, 0, pad_h, pad_h, pad_w, pad_w]) + [0, 0, pad_h, pad_h, pad_w, pad_w] + ) N_, H_in, W_in, _ = input.shape _, H_out, W_out, _ = offset.shape + # _, H_out, W_out, C_ = offset_mask.shape + # offset = offset_mask[:, :, :, : (C_ * 2) // 3] + # mask = offset_mask[:, :, :, (C_ * 2) // 3:] ref = _get_reference_points( - input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w) + input.shape, + input.device, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + pad_h, + pad_w, + stride_h, + stride_w + ) grid = _generate_dilation_grids( - input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device) + input.shape, + kernel_h, + 
kernel_w, + dilation_h, + dilation_w, + group, + input.device + ) spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ - repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device) + repeat(1, 1, 1, group * kernel_h * kernel_w).to(input.device) sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \ offset * offset_scale / spatial_norm P_ = kernel_h * kernel_w sampling_grids = 2 * sampling_locations - 1 - # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in - input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ - reshape(N_*group, group_channels, H_in, W_in) - # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2 - sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\ + # N_, H_in, W_in, group * group_channels + # -> N_, H_in * W_in, group * group_channels + # -> N_, group * group_channels, H_in * W_in + # -> N_ * group, group_channels, H_in, W_in + input_ = input.view(N_, H_in * W_in, group * group_channels).transpose(1, 2).\ + reshape(N_ * group, group_channels, H_in, W_in) + # N_, H_out, W_out, group * P_ * 2 + # -> N_, H_out * W_out, group, P_, 2 + # -> N_, group, H_out * W_out, P_, 2 + # -> N_ * group, H_out * W_out, P_, 2 + sampling_grid_ = sampling_grids.view(N_, H_out * W_out, group, P_, 2).transpose(1, 2).\ flatten(0, 1) - # N_*group, group_channels, H_out*W_out, P_ + # N_ * group, group_channels, H_out * W_out, P_ sampling_input_ = F.grid_sample( - input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False) + input_, + sampling_grid_, + mode='bilinear', + padding_mode='zeros', + align_corners=False + ) - # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_) - mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\ - reshape(N_*group, 1, H_out*W_out, P_) - output = (sampling_input_ * mask).sum(-1).view(N_, - group*group_channels, H_out*W_out) + # (N_, H_out, W_out, group * P_) + # -> (N_, H_out * W_out, group, P_) + # -> (N_, group, H_out * W_out, P_) + # -> (N_ * group, 1, H_out * W_out, P_) + mask = mask.view(N_, H_out * W_out, group, P_).transpose(1, 2).\ + reshape(N_ * group, 1, H_out * W_out, P_) + output = (sampling_input_ * mask).sum(-1).\ + view(N_, group * group_channels, H_out * W_out) return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous() @@ -231,32 +269,39 @@ def _is_power_of_2(n): class CenterFeatureScaleModule(nn.Module): - def forward(self, + def forward( + self, + query, + center_feature_scale_proj_weight, + center_feature_scale_proj_bias + ): + center_feature_scale = \ + F.linear( query, - center_feature_scale_proj_weight, - center_feature_scale_proj_bias): - center_feature_scale = F.linear(query, - weight=center_feature_scale_proj_weight, - bias=center_feature_scale_proj_bias).sigmoid() + weight=center_feature_scale_proj_weight, + bias=center_feature_scale_proj_bias + ).sigmoid() return center_feature_scale -class DCNv3_pytorch(nn.Module): +class DCNv4_pytorch(nn.Module): def __init__( self, channels=64, kernel_size=3, - dw_kernel_size=None, stride=1, pad=1, dilation=1, group=4, offset_scale=1.0, - act_layer='GELU', - norm_layer='LN', - center_feature_scale=False): + dw_kernel_size=None, + remove_center=False, + output_bias=True, + without_pointwise=False, + **kwargs + ): 
""" - DCNv3 Module + DCNv4 Module :param channels :param kernel_size :param stride @@ -264,114 +309,107 @@ def __init__( :param dilation :param group :param offset_scale - :param act_layer - :param norm_layer + :param dw_kernel_size + :param remove_center + :param output_bias + :param without_pointwise """ super().__init__() if channels % group != 0: raise ValueError( f'channels must be divisible by group, but got {channels} and {group}') _d_per_group = channels // group - dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size + # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation - if not _is_power_of_2(_d_per_group): - warnings.warn( - "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " - "which is more efficient in our CUDA implementation.") + assert _d_per_group % 16 == 0 self.offset_scale = offset_scale self.channels = channels self.kernel_size = kernel_size - self.dw_kernel_size = dw_kernel_size self.stride = stride self.dilation = dilation self.pad = pad self.group = group self.group_channels = channels // group self.offset_scale = offset_scale - self.center_feature_scale = center_feature_scale - - self.dw_conv = nn.Sequential( - nn.Conv2d( - channels, - channels, - kernel_size=dw_kernel_size, - stride=1, - padding=(dw_kernel_size - 1) // 2, - groups=channels), - build_norm_layer( - channels, - norm_layer, - 'channels_first', - 'channels_last'), - build_act_layer(act_layer)) - self.offset = nn.Linear( - channels, - group * kernel_size * kernel_size * 2) - self.mask = nn.Linear( - channels, - group * kernel_size * kernel_size) - self.input_proj = nn.Linear(channels, channels) - self.output_proj = nn.Linear(channels, channels) + self.dw_kernel_size = dw_kernel_size + self.remove_center = int(remove_center) + self.without_pointwise = without_pointwise + + self.K = group * (kernel_size * kernel_size - self.remove_center) + if dw_kernel_size is not None: + self.offset_mask_dw = \ + nn.Conv2d(channels, channels, dw_kernel_size, stride=1, padding=(dw_kernel_size - 1) // 2, groups=channels) + # self.offset_mask = nn.Linear(channels, int(math.ceil((self.K * 3)/8)*8)) + self.offset = nn.Linear(channels, self.K * 2) + self.mask = nn.Linear(channels, self.K) + if not without_pointwise: + self.value_proj = nn.Linear(channels, channels) + self.output_proj = nn.Linear(channels, channels, bias=output_bias) self._reset_parameters() - self.center_feature_scale_module = CenterFeatureScaleModule() - self.center_feature_scale_proj_weight = torch.zeros((group, channels), dtype=torch.float) - self.center_feature_scale_proj_bias = torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, ) - - if center_feature_scale: - self.center_feature_scale_proj_weight = nn.Parameter( - torch.zeros((group, channels), dtype=torch.float)) - self.center_feature_scale_proj_bias = nn.Parameter( - torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) - self.center_feature_scale_module = CenterFeatureScaleModule() def _reset_parameters(self): + # constant_(self.offset_mask.weight.data, 0.) + # constant_(self.offset_mask.bias.data, 0.) constant_(self.offset.weight.data, 0.) constant_(self.offset.bias.data, 0.) constant_(self.mask.weight.data, 0.) constant_(self.mask.bias.data, 0.) - xavier_uniform_(self.input_proj.weight.data) - constant_(self.input_proj.bias.data, 0.) - xavier_uniform_(self.output_proj.weight.data) - constant_(self.output_proj.bias.data, 0.) 
- - def forward(self, input, shape: Optional[Tuple[int, int]]=None): + if not self.without_pointwise: + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + if self.output_proj.bias is not None: + constant_(self.output_proj.bias.data, 0.) + + def forward(self, input, shape: Optional[Tuple[int, int]] = None): """ - :param query (N, H, W, C) - :return output (N, H, W, C) + :param input (N, L, C) + :param shape (H, W) or None + :return output (N, L, C) """ - # N, H, W, _ = input.shape - N, H, W, C = input.shape - L = H * W - - x = self.input_proj(input) - x_proj = x - - x1 = input.permute(0, 3, 1, 2) - x1 = self.dw_conv(x1) - offset = self.offset(x1) - mask = self.mask(x1).reshape(N, H, W, self.group, -1) - mask = F.softmax(mask, -1).reshape(N, H, W, -1) - - x = dcnv3_core_pytorch( - x, offset, mask, - self.kernel_size, self.kernel_size, - self.stride, self.stride, - self.pad, self.pad, - self.dilation, self.dilation, - self.group, self.group_channels, - self.offset_scale) - if self.center_feature_scale: - center_feature_scale = self.center_feature_scale_module( - x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) - # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels - center_feature_scale = center_feature_scale[..., None].repeat( - 1, 1, 1, 1, self.channels // self.group).flatten(-2) - x = x * (1 - center_feature_scale) + x_proj * center_feature_scale - x = self.output_proj(x) + N, L, C = input.shape + if shape is not None: + H, W = shape + else: + H, W = int(L ** 0.5), int(L ** 0.5) + + x = input + if not self.without_pointwise: + x = self.value_proj(x) + x = x.reshape(N, H, W, -1) + if self.dw_kernel_size is not None: + offset_mask_input = self.offset_mask_dw(input.view(N, H, W, C).permute(0, 3, 1, 2)) + offset_mask_input = offset_mask_input.permute(0, 2, 3, 1).view(N, L, C) + else: + offset_mask_input = input + # offset_mask = self.offset_mask(offset_mask_input).reshape(N, H, W, -1) + offset = self.offset(offset_mask_input).reshape(N, H, W, -1) + mask = self.mask(offset_mask_input).reshape(N, H, W, -1) + x = dcnv4_core_pytorch( + x, + offset, + mask, + self.kernel_size, + self.kernel_size, + self.stride, + self.stride, + self.pad, + self.pad, + self.dilation, + self.dilation, + self.group, + self.group_channels, + self.offset_scale + ) + + x = x.view(N, L, -1) + if not self.without_pointwise: + x = self.output_proj(x) return x - -# --- DCNv3 pure pytorch implementation finished --- # + + +# --- DCNv4 pure pytorch implementation finished --- # # --- FlashInternImage implementation start --- # class CrossAttention(nn.Module): r""" Cross Attention Module @@ -388,16 +426,17 @@ class CrossAttention(nn.Module): attn_head_dim (int, optional): Dimension of attention head. out_dim (int, optional): Dimension of output. 
""" - - def __init__(self, - dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0., - proj_drop=0., - attn_head_dim=None, - out_dim=None): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + attn_head_dim=None, + out_dim=None + ): super().__init__() if out_dim is None: out_dim = dim @@ -438,17 +477,19 @@ def forward(self, x, k=None, v=None): v_bias = self.v_bias q = F.linear(input=x, weight=self.q.weight, bias=q_bias) - q = q.reshape(B, N, 1, self.num_heads, - -1).permute(2, 0, 3, 1, - 4).squeeze(0) # (B, N_head, N_q, dim) + q = q.reshape(B, N, 1, self.num_heads, -1) \ + .permute(2, 0, 3, 1, 4) \ + .squeeze(0) # (B, N_head, N_q, dim) k = F.linear(input=k, weight=self.k.weight, bias=k_bias) - k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, - 4).squeeze(0) + k = k.reshape(B, N_k, 1, self.num_heads, -1)\ + .permute(2, 0, 3, 1, 4)\ + .squeeze(0) v = F.linear(input=v, weight=self.v.weight, bias=v_bias) - v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, - 4).squeeze(0) + v = v.reshape(B, N_v, 1, self.num_heads, -1)\ + .permute(2, 0, 3, 1, 4)\ + .squeeze(0) q = q * self.scale attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k) @@ -480,42 +521,47 @@ class AttentiveBlock(nn.Module): attn_head_dim (int, optional): Dimension of attention head. Default: None. out_dim (int, optional): Dimension of output. Default: None. """ - - def __init__(self, - dim, - num_heads, - qkv_bias=False, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer="LN", - attn_head_dim=None, - out_dim=None): + def __init__( + self, + dim, + num_heads, + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer="LN", + attn_head_dim=None, + out_dim=None + ): super().__init__() self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6) self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6) self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6) - self.cross_dcn = CrossAttention(dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - attn_head_dim=attn_head_dim, - out_dim=out_dim) - - self.drop_path = DropPath( - drop_path) if drop_path > 0. else nn.Identity() - - def forward(self, - x_q, - x_kv, - pos_q, - pos_k, - bool_masked_pos, - rel_pos_bias=None): + self.cross_dcn = \ + CrossAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + attn_head_dim=attn_head_dim, + out_dim=out_dim + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward( + self, + x_q, + x_kv, + pos_q, + pos_k, + bool_masked_pos, + rel_pos_bias=None + ): x_q = self.norm1_q(x_q + pos_q) x_k = self.norm1_k(x_kv + pos_k) x_v = self.norm1_v(x_kv) @@ -526,14 +572,18 @@ def forward(self, class AttentionPoolingBlock(AttentiveBlock): - def forward(self, x): x_q = x.mean(1, keepdim=True) x_kv = x pos_q, pos_k = 0, 0 - x = super().forward(x_q, x_kv, pos_q, pos_k, - bool_masked_pos=None, - rel_pos_bias=None) + x = super().forward( + x_q, + x_kv, + pos_q, + pos_k, + bool_masked_pos=None, + rel_pos_bias=None + ) x = x.squeeze(1) return x @@ -546,28 +596,41 @@ class StemLayer(nn.Module): act_layer (str): activation layer norm_layer (str): normalization layer """ - - def __init__(self, - in_chans=3, - out_chans=96, - act_layer='GELU', - norm_layer='BN'): + def __init__( + self, + in_chans=3, + out_chans=96, + act_layer='GELU', + norm_layer='BN' + ): super().__init__() - self.conv1 = nn.Conv2d(in_chans, - out_chans // 2, - kernel_size=3, - stride=2, - padding=1) - self.norm1 = build_norm_layer(out_chans // 2, norm_layer, - 'channels_first', 'channels_first') + self.conv1 = nn.Conv2d( + in_chans, + out_chans // 2, + kernel_size=3, + stride=2, + padding=1 + ) + self.norm1 = build_norm_layer( + out_chans // 2, + norm_layer, + 'channels_first', + 'channels_first' + ) self.act = build_act_layer(act_layer) - self.conv2 = nn.Conv2d(out_chans // 2, - out_chans, - kernel_size=3, - stride=2, - padding=1) - self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first', - 'channels_last') + self.conv2 = nn.Conv2d( + out_chans // 2, + out_chans, + kernel_size=3, + stride=2, + padding=1 + ) + self.norm2 = build_norm_layer( + out_chans, + norm_layer, + 'channels_first', + 'channels_last' + ) def forward(self, x): x = self.conv1(x) @@ -584,39 +647,34 @@ class DownsampleLayer(nn.Module): channels (int): number of input channels norm_layer (str): normalization layer """ - def __init__(self, channels, norm_layer='LN'): super().__init__() - self.conv = nn.Conv2d(channels, - 2 * channels, - kernel_size=3, - stride=2, - padding=1, - bias=False) - self.norm = build_norm_layer(2 * channels, norm_layer, - 'channels_first', 'channels_first') - self.dcn_version = dcn_version - + self.conv = nn.Conv2d( + channels, + 2 * channels, + kernel_size=3, + stride=2, + padding=1, + bias=False + ) + self.norm = build_norm_layer( + 2 * channels, + norm_layer, + 'channels_first', + 'channels_first' + ) def forward(self, x, shape: Optional[Tuple[int, int]]=None): - if self.dcn_version == 'DCNv4': - N, HW, C = x.shape - if shape is not None: - H, W = shape - else: - H, W = int(HW**0.5), int(HW**0.5) + N, HW, C = x.shape + if shape is not None: + H, W = shape else: - N, H, W, C = x.shape - HW = H * W + H, W = int(HW ** 0.5), int(HW ** 0.5) x = x.view(N, H, W, C) x = self.conv(x.permute(0, 3, 1, 2)) x = self.norm(x) # B C H W - if self.dcn_version == 'DCNv4': - H, W = x.size(2), x.size(3) - x = x.flatten(2).permute(0, 2, 1) - else: - x = x.permute(0, 2, 3, 1) - H, W = x.size(1), x.size(2) + H, W = x.size(2), x.size(3) + x = x.flatten(2).permute(0, 2, 1) return x, (H, W) @@ -628,16 +686,18 @@ class MLPLayer(nn.Module): hidden_features (int): number of hidden features out_features (int): number of output features act_layer (str): activation layer + mlp_fc2_bias (bool): whether to use mlp fc2 bias drop (float): dropout rate """ - - def __init__(self, - in_features, - hidden_features=None, - out_features=None, - act_layer='GELU', - mlp_fc2_bias=False, - drop=0.): + def 
__init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer='GELU', + mlp_fc2_bias=False, + drop=0. + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -646,7 +706,6 @@ def __init__(self, self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias) self.drop = nn.Dropout(drop) - def forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0): x = self.fc1(x) x = self.act(x) @@ -674,28 +733,25 @@ class InternImageLayer(nn.Module): dcn_output_bias (bool): whether to use dcn output bias, Default: False. mlp_fc2_bias (bool): whether to use mlp fc2 bias, Default: False. dw_kernel_size (int): Size of the dwconv, Default: None. - res_post_norm (bool): whether to use res post normalization, Default: False. - center_feature_scale (bool): whether to use center feature scale, Default: False. """ - - def __init__(self, - core_op, - channels, - groups, - mlp_ratio=4., - drop=0., - drop_path=0., - act_layer='GELU', - norm_layer='LN', - post_norm=False, - layer_scale=None, - offset_scale=1.0, - with_cp=False, - dcn_output_bias=False, - mlp_fc2_bias=False, - dw_kernel_size=None, # for InternImage-H/G - res_post_norm=False, # for InternImage-H/G - center_feature_scale=False): # for InternImage-H/G + def __init__( + self, + core_op, + channels, + groups, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer='GELU', + norm_layer='LN', + post_norm=False, + layer_scale=None, + offset_scale=1.0, + with_cp=False, + dcn_output_bias=False, + mlp_fc2_bias=False, + dw_kernel_size=None + ): super().__init__() self.channels = channels self.groups = groups @@ -704,7 +760,7 @@ def __init__(self, self.norm1 = build_norm_layer(channels, 'LN') self.post_norm = post_norm - if dcn_version == 'DCNv4' and core_op == 'DCNv4': + if dcn_version == 'CUDA' and core_op == 'DCNv4': self.dcn = DCNv4.DCNv4( channels=channels, group=groups, @@ -713,48 +769,41 @@ def __init__(self, output_bias=dcn_output_bias, ) else: - self.dcn = DCNv3_pytorch( + self.dcn = DCNv4_pytorch( channels=channels, group=groups, offset_scale=offset_scale, dw_kernel_size=dw_kernel_size, - center_feature_scale=center_feature_scale + output_bias=dcn_output_bias, ) - self.drop_path = DropPath(drop_path) if drop_path > 0. \ - else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = build_norm_layer(channels, 'LN') - self.mlp = MLPLayer(in_features=channels, - hidden_features=int(channels * mlp_ratio), - act_layer=act_layer, - drop=drop, - mlp_fc2_bias=mlp_fc2_bias - ) + self.mlp = MLPLayer( + in_features=channels, + hidden_features=int(channels * mlp_ratio), + act_layer=act_layer, + drop=drop, + mlp_fc2_bias=mlp_fc2_bias + ) self.layer_scale = layer_scale is not None self.gamma1 = torch.ones(channels) self.gamma2 = torch.ones(channels) if self.layer_scale: - self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels), - requires_grad=True) - self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels), - requires_grad=True) - - self.res_post_norm = res_post_norm - self.res_post_norm1 = nn.Sequential() - self.res_post_norm2 = nn.Sequential() - if res_post_norm: - self.res_post_norm1 = build_norm_layer(channels, 'LN') - self.res_post_norm2 = build_norm_layer(channels, 'LN') + self.gamma1 = nn.Parameter( + layer_scale * torch.ones(channels), + requires_grad=True + ) + self.gamma2 = nn.Parameter( + layer_scale * torch.ones(channels), + requires_grad=True + ) def _inner_forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int): if not self.layer_scale: if self.post_norm: x = x + self.drop_path(self.norm1(self.dcn(x, shape))) x = x + self.drop_path(self.norm2(self.mlp(x, shape, level_idx))) - elif self.res_post_norm: # for InternImage-H/G - x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x), shape))) - x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x), shape, level_idx))) - else: x = x + self.drop_path(self.dcn(self.norm1(x), shape)) x = x + self.drop_path(self.mlp(self.norm2(x), shape, level_idx)) @@ -803,37 +852,32 @@ class InternImageBlock(nn.Module): mlp_fc2_bias (bool): whether to use mlp fc2 bias, Default: False. dw_kernel_size (int): Size of the dwconv, Default: None. post_norm_block_ids (list): block ids for post normalization, Default: None. - res_post_norm (bool): whether to use res post normalization, Default: False. - center_feature_scale (bool): whether to use center feature scale, Default: False. 
""" - - def __init__(self, - core_op, - channels, - depth, - groups, - downsample=True, - downsample_layer=DownsampleLayer, - mlp_ratio=4., - drop=0., - drop_path=0., - act_layer='GELU', - norm_layer='LN', - post_norm=False, - offset_scale=0.5, - layer_scale=None, - with_cp=False, - dcn_output_bias=False, - mlp_fc2_bias=False, - dw_kernel_size=None, # for InternImage-H/G - post_norm_block_ids: Optional[List[int]]=None, # for InternImage-H/G - res_post_norm=False, # for InternImage-H/G - center_feature_scale=False): # for InternImage-H/G + def __init__( + self, + core_op, + channels, + depth, + groups, + downsample=True, + downsample_layer=DownsampleLayer, + mlp_ratio=4., + drop=0., + drop_path=0., + act_layer='GELU', + norm_layer='LN', + post_norm=False, + offset_scale=0.5, + layer_scale=None, + with_cp=False, + dcn_output_bias=False, + mlp_fc2_bias=False, + dw_kernel_size=None, + ): super().__init__() self.channels = channels self.depth = depth self.post_norm = post_norm - self.center_feature_scale = center_feature_scale self.grad_checkpoint = False self.blocks = nn.ModuleList([ @@ -843,8 +887,7 @@ def __init__(self, groups=groups, mlp_ratio=mlp_ratio, drop=drop, - drop_path=drop_path[i] if isinstance( - drop_path, list) else drop_path, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, act_layer=act_layer, norm_layer=norm_layer, post_norm=post_norm, @@ -853,32 +896,18 @@ def __init__(self, with_cp=with_cp, dcn_output_bias=dcn_output_bias, mlp_fc2_bias=mlp_fc2_bias, - dw_kernel_size=dw_kernel_size, # for InternImage-H/G - res_post_norm=res_post_norm, # for InternImage-H/G - center_feature_scale=center_feature_scale # for InternImage-H/G + dw_kernel_size=dw_kernel_size, ) for i in range(depth) ]) self.norm = nn.Sequential() - if not self.post_norm or center_feature_scale: + if not self.post_norm: self.norm = build_norm_layer(channels, 'LN') - - self.if_post_norm = post_norm_block_ids is not None - self.post_norm_block_ids: List[int] = [0] - self.post_norms = nn.ModuleList() - if post_norm_block_ids is not None: # for InternImage-H/G - self.post_norm_block_ids = post_norm_block_ids - self.post_norms = nn.ModuleList( - [build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids] - ) - self.downsample = downsample_layer( - channels=channels, norm_layer=norm_layer) if downsample else None - @torch.jit.ignore - def _forward_post_norm(self, x, i: int): - index = self.post_norm_block_ids.index(i) - x = self.post_norms[index](x) # for InternImage-H/G - return x + self.downsample = downsample_layer( + channels=channels, + norm_layer=norm_layer + ) if downsample else None @torch.jit.ignore def forward_return_wo_downsample(self, x, shape: Optional[Tuple[int, int]]=None, level_idx: int=0): @@ -887,9 +916,8 @@ def forward_return_wo_downsample(self, x, shape: Optional[Tuple[int, int]]=None, x = checkpoint_seq(blk, x) else: x = blk(x, shape=shape, level_idx=level_idx) - if self.if_post_norm and (i in self.post_norm_block_ids): - self._forward_post_norm(x, i) - if not self.post_norm or self.center_feature_scale: + + if not self.post_norm: x = self.norm(x) x_ = x.clone() @@ -905,9 +933,8 @@ def forward_shape(self, x, shape: Tuple[int, int], level_idx: int=0): x = checkpoint_seq(blk, x) else: x = blk(x, shape=shape, level_idx=level_idx) - if self.if_post_norm and (i in self.post_norm_block_ids): - self._forward_post_norm(x, i) - if not self.post_norm or self.center_feature_scale: + + if not self.post_norm: x = self.norm(x) if self.downsample is not None: @@ -916,19 +943,23 
@@ def forward_shape(self, x, shape: Tuple[int, int], level_idx: int=0):
         return x, shape

     def forward(self, x, shape: Optional[Tuple[int, int]]=None, level_idx: int=0):
+        N, H, W, C = x.shape
+        x = x.view(N, H * W, C)
         for i, blk in enumerate(self.blocks):
             if self.grad_checkpoint and not torch.jit.is_scripting():
                 x = checkpoint_seq(blk, x)
             else:
                 x = blk(x, shape=shape, level_idx=level_idx)
-            if self.if_post_norm and (i in self.post_norm_block_ids):
-                self._forward_post_norm(x, i)
-        if not self.post_norm or self.center_feature_scale:
+
+        if not self.post_norm:
             x = self.norm(x)

         if self.downsample is not None:
             x, shape = self.downsample(x, shape=shape)
+        if shape is not None:
+            H, W = shape
+            x = x.view(N, H, W, -1)

         return x

@@ -936,7 +967,7 @@ class FlashInternImage(nn.Module):
     r""" FlashInternImage
         A PyTorch impl based on :
             `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions`
-            - https://arxiv.org/pdf/2103.14030
+            https://arxiv.org/pdf/2211.05778
             `DCNv4`
             - https://arxiv.org/pdf/2401.06197
     Args:
         core_op (str): Core operator. Default: 'DCNv4'
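A minimal usage sketch of the resulting model (illustrative only, not part of the patch): it assumes a classification variant name such as `flash_intern_image_tiny` registered further down in this file and the default 1000-class head; without the DCNv4 CUDA package installed, the model falls back to the pure PyTorch path introduced above.

    import torch
    import timm  # assumes this patch series has been applied

    model = timm.create_model('flash_intern_image_tiny', pretrained=False)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])
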
@@ -959,51 +990,41 @@ class FlashInternImage(nn.Module):
         dcn_output_bias (bool): Whether to use dcn output bias. Default: False
         dw_kernel_size (int): Size of the dwconv. Default: None
         global_pool (str): Global pooling type. Default: 'avg'
-        use_clip_projector (bool): Whether to use clip projector. Default: False
-        level2_post_norm (bool): Whether to use level2 post norm. Default: False
-        level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
-        res_post_norm (bool): Whether to use res post norm. Default: False
-        center_feature_scale (bool): Whether to use center feature scale. Default: False
         out_indices (tuple): Output from which stages. Default: (0, 1, 2, 3)
     """
-
-    def __init__(self,
-                 core_op='DCNv4',
-                 channels=64,
-                 depths=[3, 4, 18, 5],
-                 groups=[3, 6, 12, 24],
-                 num_classes=1000,
-                 mlp_ratio=4.,
-                 drop_rate=0.,
-                 drop_path_rate=0.2,
-                 drop_path_type='linear',
-                 act_layer='GELU',
-                 norm_layer='LN',
-                 layer_scale=None,
-                 offset_scale=0.5,
-                 post_norm=False,
-                 cls_scale=1.5,
-                 with_cp=False,
-                 mlp_fc2_bias=False,
-                 dcn_output_bias=False,
-                 dw_kernel_size=None,
-                 global_pool='avg',
-                 use_clip_projector=False,  # for InternImage-H/G
-                 level2_post_norm=False,  # for InternImage-H/G
-                 level2_post_norm_block_ids=None,  # for InternImage-H/G
-                 res_post_norm=False,  # for InternImage-H/G
-                 center_feature_scale=False,  # for InternImage-H/G
-                 out_indices=(0, 1, 2, 3),
-                 **kwargs):
+    def __init__(
+        self,
+        core_op='DCNv4',
+        channels=64,
+        depths=[3, 4, 18, 5],
+        groups=[3, 6, 12, 24],
+        num_classes=1000,
+        mlp_ratio=4.,
+        drop_rate=0.,
+        drop_path_rate=0.2,
+        drop_path_type='linear',
+        act_layer='GELU',
+        norm_layer='LN',
+        layer_scale=None,
+        offset_scale=0.5,
+        post_norm=False,
+        cls_scale=1.5,
+        with_cp=False,
+        mlp_fc2_bias=False,
+        dcn_output_bias=False,
+        dw_kernel_size=None,
+        global_pool='avg',
+        out_indices=(0, 1, 2, 3),
+        **kwargs
+    ):
         super().__init__()
-        self.dcn_version = dcn_version
-        if dcn_version == 'DCNv4':
+        if dcn_version == 'CUDA':
             core_op = 'DCNv4'
         else:
-            warnings.warn('FlashInternImage requires DCNv4, but not found in current enviroment.\n\
-                By default using DCNv3 pure pytorch implementation instead, which will affect the performance.\n\
-                Suggesting install DCNv4 by `pip install DCNv4`')
-            core_op = 'DCNv3'
+            warnings.warn('FlashInternImage requires the CUDA version of DCNv4, which was not found in the current environment.\n\
+                Falling back to the pure PyTorch DCNv4 implementation, which will affect performance.\n\
+                Consider installing the CUDA version via `pip install DCNv4`')
+            core_op = 'DCNv4'
         self.core_op = core_op
         self.num_classes = num_classes
         self.num_levels = len(depths)
@@ -1013,8 +1034,6 @@ def __init__(self,
         self.post_norm = post_norm
         self.mlp_ratio = mlp_ratio
         self.act_layer = act_layer
-        self.use_clip_projector = use_clip_projector
-        self.level2_post_norm_block_ids = level2_post_norm_block_ids
         self.out_indices = out_indices
         self.output_fmt = 'NHWC'
         self.feature_info = []
@@ -1022,15 +1041,14 @@ def __init__(self,
         _logger.info(f'using activation layer: {act_layer}')
         _logger.info(f'using main norm layer: {norm_layer}')
         _logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
-        _logger.info(f'level2_post_norm: {level2_post_norm}')
-        _logger.info(f'level2_post_norm_block_ids: {level2_post_norm_block_ids}')
-        _logger.info(f'res_post_norm: {res_post_norm}')

         in_chans = 3
-        self.patch_embed = StemLayer(in_chans=in_chans,
-                                     out_chans=channels,
-                                     act_layer=act_layer,
-                                     norm_layer=norm_layer)
+        self.patch_embed = StemLayer(
+            in_chans=in_chans,
+            out_chans=channels,
+            act_layer=act_layer,
+            norm_layer=norm_layer
+        )
         self.pos_drop = nn.Dropout(p=drop_rate)
         self.feature_info.append(dict(num_chs=channels, reduction=2, module='patch_embed'))

@@ -1043,9 +1061,6 @@ def __init__(self,

         self.levels = nn.Sequential()
         for i in range(self.num_levels):
-            post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
-                i == 2) else None  # for InternImage-H/G
-
             level = InternImageBlock(
                 core_op=core_op,
                 channels=int(channels * 2**i),
@@ -1064,53 +1079,42 @@ def __init__(self,
                 with_cp=with_cp,
                 mlp_fc2_bias=mlp_fc2_bias,
                 dcn_output_bias=dcn_output_bias,
-                dw_kernel_size=dw_kernel_size,  # for InternImage-H/G
-                post_norm_block_ids=post_norm_block_ids,  # for InternImage-H/G
-                res_post_norm=res_post_norm,  # for InternImage-H/G
-                center_feature_scale=center_feature_scale  # for InternImage-H/G
+                dw_kernel_size=dw_kernel_size,
             )
             self.levels.add_module(str(i), level)
             if i < self.num_levels - 1:
                 self.feature_info.append(
-                    dict(num_chs=int(channels * 2 ** (i + 1)), reduction=2 ** (i + 2), module=f'levels.{i}'))
+                    dict(
+                        num_chs=int(channels * 2 ** (i + 1)),
+                        reduction=2 ** (i + 2),
+                        module=f'levels.{i}'
+                    ))
             else:
                 self.feature_info.append(
-                    dict(num_chs=int(channels * 2 ** i), reduction=2 ** (i + 1), module=f'levels.{i}'))
+                    dict(
+                        num_chs=int(channels * 2 ** i),
+                        reduction=2 ** (i + 1),
+                        module=f'levels.{i}'
+                    ))

-        if not use_clip_projector:  # for InternImage-T/S/B/L/XL
-            self.conv_head = nn.Sequential(
-                nn.Conv2d(self.num_features,
-                          int(self.num_features * cls_scale),
-                          kernel_size=1,
-                          bias=False),
-                build_norm_layer(int(self.num_features * cls_scale), 'BN',
-                                 'channels_first', 'channels_first'),
-                build_act_layer(act_layer))
-            self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
-                if num_classes > 0 else nn.Identity()
-        else:  # for InternImage-H/G
-            pretrain_embed_dim, _stride, attnpool_num_heads, clip_embed_dim = 1024, 2, 16, 768
-            self.dcnv3_head_x4 = nn.Sequential(
-                nn.Conv2d(in_channels=self.num_features,
-                          out_channels=pretrain_embed_dim * (_stride ** 2),
-                          kernel_size=1), nn.PixelShuffle(_stride))
-            self.dcnv3_head_x3 = nn.Conv2d(in_channels=self.num_features // 2,
-                                           out_channels=pretrain_embed_dim,
-                                           kernel_size=1)
-            self.clip_projector = AttentionPoolingBlock(
-                dim=pretrain_embed_dim,
-                num_heads=attnpool_num_heads,
-                qkv_bias=True,
-                qk_scale=None,
-                drop=0.,
-                attn_drop=0.,
-                norm_layer=norm_layer,
-                out_dim=clip_embed_dim)
-
self.fc_norm = build_norm_layer(clip_embed_dim, norm_layer, eps=1e-6) - self.head = nn.Linear( - clip_embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.conv_head = nn.Sequential( + nn.Conv2d( + self.num_features, + int(self.num_features * cls_scale), + kernel_size=1, + bias=False + ), + build_norm_layer( + int(self.num_features * cls_scale), + 'BN', + 'channels_first', + 'channels_first' + ), + build_act_layer(act_layer) + ) + self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \ + if num_classes > 0 else nn.Identity() - # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.pool_type = global_pool self.global_pool = SelectAdaptivePool2d(output_size=(1, 1), pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool != '' else nn.Identity() @@ -1128,9 +1132,9 @@ def _init_weights(self, m): nn.init.constant_(m.weight, 1.0) def _init_deform_weights(self, m): - if dcn_version == 'DCNv4' and isinstance(m, getattr(DCNv4, self.core_op)): + if dcn_version == 'CUDA' and isinstance(m, getattr(DCNv4, self.core_op)): m._reset_parameters() - elif isinstance(m, DCNv3_pytorch): + elif isinstance(m, DCNv4_pytorch): m._reset_parameters() @torch.jit.ignore @@ -1146,13 +1150,20 @@ def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes if num_classes == 0: self.conv_head = nn.Sequential( - nn.Conv2d(self.num_features, - int(self.num_features), - kernel_size=1, - bias=False), - build_norm_layer(int(self.num_features), 'BN', - 'channels_first', 'channels_first'), - build_act_layer(self.act_layer)) + nn.Conv2d( + self.num_features, + int(self.num_features), + kernel_size=1, + bias=False + ), + build_norm_layer( + int(self.num_features), + 'BN', + 'channels_first', + 'channels_first' + ), + build_act_layer(self.act_layer) + ) self.global_pool = SelectAdaptivePool2d(output_size=(1, 1), pool_type=global_pool) self.pool_type = global_pool @@ -1199,14 +1210,11 @@ def lr_decay_keywards(self, decay_ratio=0.87): lr_ratios["levels.2.norm"] = lr_ratios['levels.3.blocks.0.'] return lr_ratios - def forward_features_no_clip_projector(self, x): + def forward_features(self, x): x = self.patch_embed(x) N, H, W, C = x.shape - - if self.dcn_version == 'DCNv4': - x = x.view(N, H * W, C) - - shape=(H, W) + x = x.view(N, H * W, C) + shape = (H, W) for level_idx, level in enumerate(self.levels): # old_shape = shape x, shape = level.forward_shape(x, shape=shape) @@ -1219,48 +1227,19 @@ def forward_features_no_clip_projector(self, x): @torch.jit.ignore def forward_features_seq_out(self, x): # for detection or segmentation - x = self.patch_embed(x) + x = self.patch_embed(x) N, H, W, C = x.shape - if self.dcn_version == 'DCNv4': - x = x.view(N, H * W, C) - shape=(H, W) + x = x.view(N, H * W, C) + shape = (H, W) seq_out = [] for level_idx, level in enumerate(self.levels): old_shape = shape x, x_ , shape = level.forward_return_wo_downsample(x, shape=shape, level_idx=level_idx) - h, w= old_shape + h, w = old_shape seq_out.append(x_.reshape(N, h, w, -1).permute(0, 3, 1, 2)) return seq_out - @torch.jit.ignore - def forward_clip_projector(self, x): # for InternImage-H/G - xs = self.forward_features_seq_out(x) - x1, x2, x3, x4 = xs - - x1 = x1.permute(0, 3, 1, 2) # NHWC -> NCHW - x2 = x2.permute(0, 3, 1, 2) # NHWC -> NCHW - x3 = x3.permute(0, 3, 1, 2) # NHWC -> NCHW - x4 = x4.permute(0, 3, 1, 2) # NHWC -> NCHW - - x4 = self.dcnv3_head_x4(x4) - x = x4 - x3 = self.dcnv3_head_x3(x3) - x = x + x3 - - # x = x.flatten(-2).transpose(1, 2).contiguous() - # x = 
self.clip_projector(x) - # x = self.fc_norm(x) - - return x - - def forward_features(self, x): - if self.use_clip_projector: # for InternImage-H/G - x = self.forward_clip_projector(x) - else: # for InternImage-T/S/B/L/XL - x = self.forward_features_no_clip_projector(x) - return x - - def forward_head_no_clip_projector(self, x): + def forward_head(self, x): x = self.conv_head(x.permute(0, 3, 1, 2)) x = self.global_pool(x) x = self.flatten(x) @@ -1268,21 +1247,6 @@ def forward_head_no_clip_projector(self, x): x = x.permute(0, 2, 3, 1) x = self.head(x) return x - - @torch.jit.ignore - def forward_head_clip_projector(self, x): - x = x.flatten(-2).transpose(1, 2).contiguous() - x = self.clip_projector(x) - x = self.fc_norm(x) - x = self.head(x) - return x - - def forward_head(self, x): - if self.use_clip_projector: - x = self.forward_head_clip_projector(x) - else: - x = self.forward_head_no_clip_projector(x) - return x def forward(self, x): x = self.forward_features(x) @@ -1350,6 +1314,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: hf_hub_filename='flash_intern_image_l_22kto1k_384.pth', input_size=(3, 384, 384), pool_size=(12, 12), + num_classes=21841, ), 'flash_intern_image_large.384_in22k': _cfg( url='https://huggingface.co/OpenGVLab/DCNv4/blob/main/flash_intern_image_l_22k_384.pth', @@ -1497,10 +1462,10 @@ def _create_flash_intern_image(variant: str, pretrained: bool = False, **kwargs) def _check_pretrained_available(pretrained: bool): - if dcn_version == 'DCNv4': + if dcn_version == 'CUDA': return pretrained - warnings.warn('DCNv4 is not installed, cannot load pretrained weights') + warnings.warn('CUDA version of DCNv4 is not installed, cannot load pretrained weights') return False From 9a52ad57fb6c764afca0ece2e478b0c23d00a307 Mon Sep 17 00:00:00 2001 From: Pig Date: Mon, 13 May 2024 11:19:21 +0800 Subject: [PATCH 09/13] Rename some module and remove some unused variables of FlashInternImage --- timm/models/flash_intern_image.py | 133 +++++++++++++++--------------- 1 file changed, 67 insertions(+), 66 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index 3e1cc79a1e..bfc086087c 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -706,7 +706,7 @@ def __init__( self.fc2 = nn.Linear(hidden_features, out_features, bias=mlp_fc2_bias) self.drop = nn.Dropout(drop) - def forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0): + def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) @@ -715,8 +715,8 @@ def forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0): return x -class InternImageLayer(nn.Module): - r""" Basic layer of InternImage +class InternImageBlock(nn.Module): + r""" Basic Block of InternImage Args: core_op (str): core operation of InternImage channels (int): number of input channels @@ -799,43 +799,43 @@ def __init__( requires_grad=True ) - def _inner_forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int): + def _inner_forward(self, x, shape: Optional[Tuple[int, int]]): if not self.layer_scale: if self.post_norm: x = x + self.drop_path(self.norm1(self.dcn(x, shape))) - x = x + self.drop_path(self.norm2(self.mlp(x, shape, level_idx))) + x = x + self.drop_path(self.norm2(self.mlp(x))) else: x = x + self.drop_path(self.dcn(self.norm1(x), shape)) - x = x + self.drop_path(self.mlp(self.norm2(x), shape, level_idx)) + x = x + self.drop_path(self.mlp(self.norm2(x))) return x if self.post_norm: x = x + self.drop_path(self.gamma1 * 
self.norm1(self.dcn(x, shape))) - x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x, shape, level_idx))) + x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x))) else: x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x), shape)) - x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x), shape, level_idx)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) return x @torch.jit.ignore - def forward_checkpoint(self, x, shape: Optional[Tuple[int, int]], level_idx: int = 0): - x = checkpoint.checkpoint(self._inner_forward, x, shape, level_idx) + def forward_checkpoint(self, x, shape: Optional[Tuple[int, int]]): + x = checkpoint.checkpoint(self._inner_forward, x, shape) return x - def forward(self, x, shape: Optional[Tuple[int, int]], level_idx: int=0): + def forward(self, x, shape: Optional[Tuple[int, int]]): if self.with_cp: # - x = self.forward_checkpoint(x, shape, level_idx) + x = self.forward_checkpoint(x, shape) else: - x = self._inner_forward(x, shape, level_idx) + x = self._inner_forward(x, shape) return x -class InternImageBlock(nn.Module): - r""" Block of InternImage +class InternImageStage(nn.Module): + r""" Stage of InternImage Args: core_op (str): core operation of InternImage channels (int): number of input channels - depth (int): Depth of each block. + depth (int): Depth of each stage. groups (int): Groups of each block. downsample (bool): Whether to use downsample, Default: True. downsample_layer (nn.Module): Downsample layer, Default: DownsampleLayer. @@ -851,7 +851,6 @@ class InternImageBlock(nn.Module): dcn_output_bias (bool): whether to use dcn output bias, Default: False. mlp_fc2_bias (bool): whether to use mlp fc2 bias, Default: False. dw_kernel_size (int): Size of the dwconv, Default: None. - post_norm_block_ids (list): block ids for post normalization, Default: None. 
""" def __init__( self, @@ -881,7 +880,7 @@ def __init__( self.grad_checkpoint = False self.blocks = nn.ModuleList([ - InternImageLayer( + InternImageBlock( core_op=core_op, channels=channels, groups=groups, @@ -910,12 +909,12 @@ def __init__( ) if downsample else None @torch.jit.ignore - def forward_return_wo_downsample(self, x, shape: Optional[Tuple[int, int]]=None, level_idx: int=0): + def forward_return_wo_downsample(self, x, shape: Optional[Tuple[int, int]]=None): for i, blk in enumerate(self.blocks): if self.grad_checkpoint and not torch.jit.is_scripting(): x = checkpoint_seq(blk, x) else: - x = blk(x, shape=shape, level_idx=level_idx) + x = blk(x, shape=shape) if not self.post_norm: x = self.norm(x) @@ -927,12 +926,12 @@ def forward_return_wo_downsample(self, x, shape: Optional[Tuple[int, int]]=None, return x, x_, shape - def forward_shape(self, x, shape: Tuple[int, int], level_idx: int=0): + def forward_shape(self, x, shape: Tuple[int, int]): for i, blk in enumerate(self.blocks): if self.grad_checkpoint and not torch.jit.is_scripting(): x = checkpoint_seq(blk, x) else: - x = blk(x, shape=shape, level_idx=level_idx) + x = blk(x, shape=shape) if not self.post_norm: x = self.norm(x) @@ -942,14 +941,14 @@ def forward_shape(self, x, shape: Tuple[int, int], level_idx: int=0): return x, shape - def forward(self, x, shape: Optional[Tuple[int, int]]=None, level_idx: int=0): + def forward(self, x, shape: Optional[Tuple[int, int]]=None): N, H, W, C = x.shape x = x.view(N, H * W, C) for i, blk in enumerate(self.blocks): if self.grad_checkpoint and not torch.jit.is_scripting(): x = checkpoint_seq(blk, x) else: - x = blk(x, shape=shape, level_idx=level_idx) + x = blk(x, shape=shape) if not self.post_norm: x = self.norm(x) @@ -1027,10 +1026,10 @@ def __init__( core_op = 'DCNv4' self.core_op = core_op self.num_classes = num_classes - self.num_levels = len(depths) + self.num_stages = len(depths) self.depths = depths self.channels = channels - self.num_features = int(channels * 2**(self.num_levels - 1)) + self.num_features = int(channels * 2**(self.num_stages - 1)) self.post_norm = post_norm self.mlp_ratio = mlp_ratio self.act_layer = act_layer @@ -1059,9 +1058,9 @@ def __init__( for i in range(len(dpr)): dpr[i] = drop_path_rate - self.levels = nn.Sequential() - for i in range(self.num_levels): - level = InternImageBlock( + self.stages = nn.Sequential() + for i in range(self.num_stages): + stage = InternImageStage( core_op=core_op, channels=int(channels * 2**i), depth=depths[i], @@ -1072,7 +1071,7 @@ def __init__( act_layer=act_layer, norm_layer=norm_layer, post_norm=post_norm, - downsample=(i < self.num_levels - 1), + downsample=(i < self.num_stages - 1), downsample_layer = DownsampleLayer, layer_scale=layer_scale, offset_scale=offset_scale, @@ -1081,20 +1080,20 @@ def __init__( dcn_output_bias=dcn_output_bias, dw_kernel_size=dw_kernel_size, ) - self.levels.add_module(str(i), level) - if i < self.num_levels - 1: + self.stages.add_module(str(i), stage) + if i < self.num_stages - 1: self.feature_info.append( dict( num_chs=int(channels * 2 ** (i + 1)), reduction=2 ** (i + 2), - module=f'levels.{i}' + module=f'stages.{i}' )) else: self.feature_info.append( dict( num_chs=int(channels * 2 ** i), reduction=2 ** (i + 1), - module=f'levels.{i}' + module=f'stages.{i}' )) self.conv_head = nn.Sequential( @@ -1175,12 +1174,12 @@ def reset_classifier(self, num_classes, global_pool='avg'): def group_matcher(self, coarse: bool = False) -> Dict: return dict( stem=r'^patch_embed', # stem and embed - 
blocks=[(r'^levels\.(\d+)', None)] + blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)' ) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): - for l in self.levels: + for l in self.stages: l.grad_checkpointing = enable @torch.jit.ignore @@ -1193,21 +1192,21 @@ def lr_decay_keywards(self, decay_ratio=0.87): layer_num = 3 - i # 3 2 1 0 for j in range(self.depths[layer_num]): block_num = self.depths[layer_num] - j - 1 - tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num) + tag = 'stages.{}.blocks.{}.'.format(layer_num, block_num) decay = 1.0 * (decay_ratio**idx) lr_ratios[tag] = decay idx += 1 # patch_embed (before stage-1) - lr_ratios["patch_embed"] = lr_ratios['levels.0.blocks.0.'] - # levels.0.downsample (between stage-1 and stage-2) - lr_ratios["levels.0.downsample"] = lr_ratios['levels.1.blocks.0.'] - lr_ratios["levels.0.norm"] = lr_ratios['levels.1.blocks.0.'] - # levels.1.downsample (between stage-2 and stage-3) - lr_ratios["levels.1.downsample"] = lr_ratios['levels.2.blocks.0.'] - lr_ratios["levels.1.norm"] = lr_ratios['levels.2.blocks.0.'] - # levels.2.downsample (between stage-3 and stage-4) - lr_ratios["levels.2.downsample"] = lr_ratios['levels.3.blocks.0.'] - lr_ratios["levels.2.norm"] = lr_ratios['levels.3.blocks.0.'] + lr_ratios["patch_embed"] = lr_ratios['stages.0.blocks.0.'] + # stages.0.downsample (between stage-1 and stage-2) + lr_ratios["stages.0.downsample"] = lr_ratios['stages.1.blocks.0.'] + lr_ratios["stages.0.norm"] = lr_ratios['stages.1.blocks.0.'] + # stages.1.downsample (between stage-2 and stage-3) + lr_ratios["stages.1.downsample"] = lr_ratios['stages.2.blocks.0.'] + lr_ratios["stages.1.norm"] = lr_ratios['stages.2.blocks.0.'] + # stages.2.downsample (between stage-3 and stage-4) + lr_ratios["stages.2.downsample"] = lr_ratios['stages.3.blocks.0.'] + lr_ratios["stages.2.norm"] = lr_ratios['stages.3.blocks.0.'] return lr_ratios def forward_features(self, x): @@ -1215,9 +1214,9 @@ def forward_features(self, x): N, H, W, C = x.shape x = x.view(N, H * W, C) shape = (H, W) - for level_idx, level in enumerate(self.levels): + for _, stage in enumerate(self.stages): # old_shape = shape - x, shape = level.forward_shape(x, shape=shape) + x, shape = stage.forward_shape(x, shape=shape) h, w = shape x = x.view(N, h, w, -1) # x = self.conv_head(x) @@ -1232,9 +1231,9 @@ def forward_features_seq_out(self, x): # for detection or segmentation x = x.view(N, H * W, C) shape = (H, W) seq_out = [] - for level_idx, level in enumerate(self.levels): + for stage_idx, stage in enumerate(self.stages): old_shape = shape - x, x_ , shape = level.forward_return_wo_downsample(x, shape=shape, level_idx=level_idx) + x, x_ , shape = stage.forward_return_wo_downsample(x, shape=shape) h, w = old_shape seq_out.append(x_.reshape(N, h, w, -1).permute(0, 3, 1, 2)) return seq_out @@ -1267,6 +1266,8 @@ def checkpoint_filter_fn(state_dict, model): for k, v in _state_dict.items(): if k.startswith('backbone.'): k = k[9:] + if k.startswith('levels.'): + k[:7] = 'stages.' 
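#   Note: `str` is immutable in Python, so the slice assignment above
#   (`k[:7] = 'stages.'`) raises a TypeError at runtime; PATCH 11 below
#   rewrites it as `k = 'stages.' + k[7:]`.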
state_dict[k] = v if list(state_dict.keys())[0].startswith('module.'): @@ -1591,7 +1592,7 @@ def dino_4scale_flash_intern_image_tiny(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.2, - pose_norm=False, + post_norm=False, with_cp=True, output_indices=(1, 2, 3), ) @@ -1613,7 +1614,7 @@ def dino_4scale_flash_intern_image_small(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.3, - pose_norm=True, + post_norm=True, with_cp=True, dw_kernel_size=3, output_indices=(1, 2, 3), @@ -1636,7 +1637,7 @@ def dino_4scale_flash_intern_image_base(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.3, - pose_norm=True, + post_norm=True, with_cp=True, dw_kernel_size=3, output_indices=(1, 2, 3), @@ -1659,7 +1660,7 @@ def dino_4scale_flash_intern_image_large(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.4, - pose_norm=True, + post_norm=True, with_cp=True, dw_kernel_size=3, dcn_output_bias=True, @@ -1684,7 +1685,7 @@ def mask_rcnn_flash_intern_image_tiny(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.2, - pose_norm=False, + post_norm=False, with_cp=True, output_indices=(0, 1, 2, 3), ) @@ -1706,7 +1707,7 @@ def mask_rcnn_flash_intern_image_small(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.3, - pose_norm=True, + post_norm=True, with_cp=True, dw_kernel_size=3, output_indices=(0, 1, 2, 3), @@ -1729,7 +1730,7 @@ def mask_rcnn_flash_intern_image_base(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.3, - pose_norm=True, + post_norm=True, with_cp=True, dw_kernel_size=3, output_indices=(0, 1, 2, 3), @@ -1752,7 +1753,7 @@ def mask2former_flash_intern_image_tiny(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.2, - pose_norm=False, + post_norm=False, with_cp=False, output_indices=(0, 1, 2, 3), ) @@ -1774,7 +1775,7 @@ def mask2former_flash_intern_image_small(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.3, - pose_norm=True, + post_norm=True, with_cp=False, dw_kernel_size=3, output_indices=(0, 1, 2, 3), @@ -1797,7 +1798,7 @@ def mask2former_flash_intern_image_base(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.4, - pose_norm=True, + post_norm=True, with_cp=False, dw_kernel_size=3, output_indices=(0, 1, 2, 3), @@ -1820,7 +1821,7 @@ def mask2former_flash_intern_image_large(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.5, - pose_norm=True, + post_norm=True, with_cp=True, dw_kernel_size=3, dcn_output_bias=True, @@ -1845,7 +1846,7 @@ def upernet_flash_intern_image_tiny(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.2, - pose_norm=False, + post_norm=False, with_cp=True, output_indices=(0, 1, 2, 3), ) @@ -1867,7 +1868,7 @@ def upernet_flash_intern_image_small(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.3, - pose_norm=True, + post_norm=True, with_cp=True, dw_kernel_size=3, output_indices=(0, 1, 2, 3), @@ -1890,7 +1891,7 @@ def upernet_flash_intern_image_base(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.3, - pose_norm=True, + post_norm=True, with_cp=False, dw_kernel_size=3, output_indices=(0, 1, 2, 3), @@ -1913,7 +1914,7 @@ def upernet_flash_intern_image_large(pretrained=False, **kwargs): layer_scale=1., mlp_ratio=4., drop_path_rate=0.4, - pose_norm=True, + post_norm=True, with_cp=False, dw_kernel_size=3, dcn_output_bias=True, From 
c947f1d59d45c7000e9de0dc77b018568dae3e71 Mon Sep 17 00:00:00 2001 From: Pig Date: Mon, 13 May 2024 11:35:32 +0800 Subject: [PATCH 10/13] Optimize code impl of CrossAttention --- timm/models/flash_intern_image.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index bfc086087c..211277ea57 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -448,18 +448,9 @@ def __init__( self.scale = qk_scale or head_dim ** -0.5 assert all_head_dim == dim - self.q = nn.Linear(dim, all_head_dim, bias=False) - self.k = nn.Linear(dim, all_head_dim, bias=False) - self.v = nn.Linear(dim, all_head_dim, bias=False) - - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) - self.k_bias = nn.Parameter(torch.zeros(all_head_dim)) - self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) - else: - self.q_bias = None - self.k_bias = None - self.v_bias = None + self.q = nn.Linear(dim, all_head_dim, bias=qkv_bias) + self.k = nn.Linear(dim, all_head_dim, bias=qkv_bias) + self.v = nn.Linear(dim, all_head_dim, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, out_dim) @@ -470,23 +461,17 @@ def forward(self, x, k=None, v=None): N_k = k.shape[1] N_v = v.shape[1] - q_bias, k_bias, v_bias = None, None, None - if self.q_bias is not None: - q_bias = self.q_bias - k_bias = self.k_bias - v_bias = self.v_bias - - q = F.linear(input=x, weight=self.q.weight, bias=q_bias) + q = self.q(x) q = q.reshape(B, N, 1, self.num_heads, -1) \ .permute(2, 0, 3, 1, 4) \ .squeeze(0) # (B, N_head, N_q, dim) - k = F.linear(input=k, weight=self.k.weight, bias=k_bias) + k = self.k(k) k = k.reshape(B, N_k, 1, self.num_heads, -1)\ .permute(2, 0, 3, 1, 4)\ .squeeze(0) - v = F.linear(input=v, weight=self.v.weight, bias=v_bias) + v = self.v(v) v = v.reshape(B, N_v, 1, self.num_heads, -1)\ .permute(2, 0, 3, 1, 4)\ .squeeze(0) @@ -941,9 +926,12 @@ def forward_shape(self, x, shape: Tuple[int, int]): return x, shape + # duplicated implementation of forward_shape inside forward to avoid torchscript error + # and the forward function is to allow feature extraction def forward(self, x, shape: Optional[Tuple[int, int]]=None): N, H, W, C = x.shape x = x.view(N, H * W, C) + for i, blk in enumerate(self.blocks): if self.grad_checkpoint and not torch.jit.is_scripting(): x = checkpoint_seq(blk, x) From dee97c937d775fde7981d81c1af7d6da35ab080d Mon Sep 17 00:00:00 2001 From: Pig Date: Mon, 13 May 2024 11:43:43 +0800 Subject: [PATCH 11/13] Fix bug of checkpoint_filter_fn of FlashInternImage --- timm/models/flash_intern_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index 211277ea57..51de49be74 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -1255,7 +1255,7 @@ def checkpoint_filter_fn(state_dict, model): if k.startswith('backbone.'): k = k[9:] if k.startswith('levels.'): - k[:7] = 'stages.' + k = 'stages.' 
+ k[7:] state_dict[k] = v if list(state_dict.keys())[0].startswith('module.'): From 3c9b302a11dce1ad42f59071688c2ec099900d4b Mon Sep 17 00:00:00 2001 From: Pig Date: Sun, 19 May 2024 14:35:49 +0800 Subject: [PATCH 12/13] Update impl of DCNv4_pytorch --- timm/models/flash_intern_image.py | 55 +++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py index 51de49be74..909cf44a79 100644 --- a/timm/models/flash_intern_image.py +++ b/timm/models/flash_intern_image.py @@ -29,6 +29,7 @@ from typing import Dict, Any, Tuple, Optional, List import warnings import logging +import math __all__ = ['FlashInternImage'] @@ -40,6 +41,12 @@ import DCNv4 except ImportError: dcn_version = 'pytorch' + +has_yacs = True +try: + import yacs +except ImportError: + has_yacs = False class to_channels_first(nn.Module): @@ -340,21 +347,21 @@ def __init__( if dw_kernel_size is not None: self.offset_mask_dw = \ nn.Conv2d(channels, channels, dw_kernel_size, stride=1, padding=(dw_kernel_size - 1) // 2, groups=channels) - # self.offset_mask = nn.Linear(channels, int(math.ceil((self.K * 3)/8)*8)) - self.offset = nn.Linear(channels, self.K * 2) - self.mask = nn.Linear(channels, self.K) + self.offset_mask = nn.Linear(channels, int(math.ceil((self.K * 3)/8)*8)) + # self.offset = nn.Linear(channels, self.K * 2) + # self.mask = nn.Linear(channels, self.K) if not without_pointwise: self.value_proj = nn.Linear(channels, channels) self.output_proj = nn.Linear(channels, channels, bias=output_bias) self._reset_parameters() def _reset_parameters(self): - # constant_(self.offset_mask.weight.data, 0.) - # constant_(self.offset_mask.bias.data, 0.) - constant_(self.offset.weight.data, 0.) - constant_(self.offset.bias.data, 0.) - constant_(self.mask.weight.data, 0.) - constant_(self.mask.bias.data, 0.) + constant_(self.offset_mask.weight.data, 0.) + constant_(self.offset_mask.bias.data, 0.) + # constant_(self.offset.weight.data, 0.) + # constant_(self.offset.bias.data, 0.) + # constant_(self.mask.weight.data, 0.) + # constant_(self.mask.bias.data, 0.) if not self.without_pointwise: xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) 
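A quick numeric sketch (illustrative only, assuming the module defaults of kernel_size=3, group=4, remove_center=False) of the fused projection width chosen above: the K * 3 offset-plus-mask channels are padded up to the next multiple of 8, presumably for alignment with the DCNv4 CUDA kernel's layout.

    import math

    group, kernel_size, remove_center = 4, 3, 0
    K = group * (kernel_size * kernel_size - remove_center)  # 36 sampling points
    raw_width = K * 3                                        # 2*K offsets + K mask logits = 108
    padded_width = int(math.ceil(raw_width / 8) * 8)         # 112, as in self.offset_mask above
    print(K, raw_width, padded_width)                        # 36 108 112
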
@@ -383,9 +390,11 @@ def forward(self, input, shape: Optional[Tuple[int, int]] = None):
             offset_mask_input = offset_mask_input.permute(0, 2, 3, 1).view(N, L, C)
         else:
             offset_mask_input = input
-        # offset_mask = self.offset_mask(offset_mask_input).reshape(N, H, W, -1)
-        offset = self.offset(offset_mask_input).reshape(N, H, W, -1)
-        mask = self.mask(offset_mask_input).reshape(N, H, W, -1)
+        offset_mask = self.offset_mask(offset_mask_input).reshape(N, H, W, -1)
+        offset = offset_mask[:, :, :, :self.K * 2]
+        mask = offset_mask[:, :, :, self.K * 2: self.K * 3]
+        # offset = self.offset(offset_mask_input).reshape(N, H, W, -1)
+        # mask = self.mask(offset_mask_input).reshape(N, H, W, -1)
         x = dcnv4_core_pytorch(
             x,
             offset,
@@ -1002,6 +1011,7 @@ def __init__(
         dw_kernel_size=None,
         global_pool='avg',
         out_indices=(0, 1, 2, 3),
+        show_model_info=False,
         **kwargs
     ):
         super().__init__()
@@ -1024,10 +1034,19 @@ def __init__(
         self.out_indices = out_indices
         self.output_fmt = 'NHWC'
         self.feature_info = []
-        _logger.info(f'use core type: {core_op}')
-        _logger.info(f'using activation layer: {act_layer}')
-        _logger.info(f'using main norm layer: {norm_layer}')
-        _logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')
+        if show_model_info:
+            _logger.info(f'use core type: {core_op}')
+            _logger.info(f'num_classes: {num_classes}')
+            _logger.info(f'num_stages: {self.num_stages}')
+            _logger.info(f'depths: {depths}')
+            _logger.info(f'groups: {groups}')
+            _logger.info(f'channels: {channels}')
+            _logger.info(f'num_features: {self.num_features}')
+            _logger.info(f'mlp_ratio: {mlp_ratio}')
+            _logger.info(f'drop_rate: {drop_rate}')
+            _logger.info(f'using activation layer: {act_layer}')
+            _logger.info(f'using main norm layer: {norm_layer}')
+            _logger.info(f'using dpr: {drop_path_type}, {drop_path_rate}')

         in_chans = 3
         self.patch_embed = StemLayer(
@@ -1451,10 +1470,10 @@ def _create_flash_intern_image(variant: str, pretrained: bool = False, **kwargs)


 def _check_pretrained_available(pretrained: bool):
-    if dcn_version == 'CUDA':
+    if has_yacs:
         return pretrained

-    warnings.warn('CUDA version of DCNv4 is not installed, cannot load pretrained weights')
+    warnings.warn('Current pretrained weights need `yacs` to load, but it was not found in the current environment.\n')
     return False

From b70b40a2e031b573bdb7e238703ed3a8edd7a020 Mon Sep 17 00:00:00 2001
From: Pig
Date: Sun, 19 May 2024 19:46:17 +0800
Subject: [PATCH 13/13] Fix bugs of DCNv4_pytorch module

---
 timm/models/flash_intern_image.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/timm/models/flash_intern_image.py b/timm/models/flash_intern_image.py
index 909cf44a79..ffe973c61a 100644
--- a/timm/models/flash_intern_image.py
+++ b/timm/models/flash_intern_image.py
@@ -343,6 +343,7 @@ def __init__(
         self.remove_center = int(remove_center)
         self.without_pointwise = without_pointwise

+        self.P = int(kernel_size * kernel_size - self.remove_center)
         self.K = group * (kernel_size * kernel_size - self.remove_center)
         if dw_kernel_size is not None:
             self.offset_mask_dw = \
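Before the forward-pass hunk below, a small shape sketch (illustrative, with toy sizes assumed) of the layout bug this patch fixes: the fix unflattens the fused projection into (group, P * 3) before slicing, suggesting it packs, per group, a contiguous run of 2 * P offset components followed by P mask logits; the flat slices `[..., :self.K * 2]` from PATCH 12 would mix offsets and masks across groups.

    import torch

    N, H, W, group, P = 1, 2, 2, 4, 9              # toy sizes; P = kernel_size ** 2
    K = group * P
    om = torch.randn(N, H, W, K * 3)               # unpadded part of the projection
    om = om.unflatten(-1, (group, P * 3))          # (N, H, W, group, P * 3)
    offset = om[..., :P * 2].flatten(-2)           # (N, H, W, group * P * 2)
    mask = om[..., P * 2:].flatten(-2)             # (N, H, W, group * P)
    print(offset.shape, mask.shape)
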
@@ -391,8 +392,10 @@ def forward(self, input, shape: Optional[Tuple[int, int]] = None):
         else:
             offset_mask_input = input
         offset_mask = self.offset_mask(offset_mask_input).reshape(N, H, W, -1)
-        offset = offset_mask[:, :, :, :self.K * 2]
-        mask = offset_mask[:, :, :, self.K * 2: self.K * 3]
+        offset_mask_no_pad = offset_mask[:, :, :, : self.K * 3]
+        offset_mask_no_pad = offset_mask_no_pad.unflatten(-1, (self.group, self.P * 3))
+        offset = offset_mask_no_pad[:, :, :, :, : self.P * 2].flatten(-2)
+        mask = offset_mask_no_pad[:, :, :, :, self.P * 2: self.P * 3].flatten(-2)
         # offset = self.offset(offset_mask_input).reshape(N, H, W, -1)
         # mask = self.mask(offset_mask_input).reshape(N, H, W, -1)
         x = dcnv4_core_pytorch(