
Commit 94e13b7

Merge pull request #2529 from huggingface/rope_vit
Adding Naver rope-vit compatibility to EVA ViT
2 parents 8d41071 + cec7290 commit 94e13b7

16 files changed, +873 −167 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -508,6 +508,7 @@ All model architecture families include variants with pretrained weights. There
 * Res2Net - https://arxiv.org/abs/1904.01169
 * ResNeSt - https://arxiv.org/abs/2004.08955
 * ReXNet - https://arxiv.org/abs/2007.00992
+* ROPE-ViT - https://arxiv.org/abs/2403.13298
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
 * Sequencer2D - https://arxiv.org/abs/2205.01972
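
With ROPE-ViT added to the supported-architecture list, models would be created through the usual `timm.create_model` entry point. A minimal sketch, assuming a timm install that includes this commit; the model name below is illustrative, not taken from the commit, so check `timm.list_models('*rope*')` for the names actually registered:

```python
import timm
import torch

# List RoPE-related architectures registered in this install
print(timm.list_models('*rope*'))

# Illustrative name only; substitute one returned by list_models above
model = timm.create_model('vit_base_patch16_rope_224', pretrained=False, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])
```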

timm/layers/__init__.py

Lines changed: 86 additions & 16 deletions
@@ -1,16 +1,40 @@
+from ._fx import (
+    create_feature_extractor,
+    get_graph_node_names,
+    register_notrace_function,
+    register_notrace_module,
+    is_notrace_module,
+    is_notrace_function,
+    get_notrace_modules,
+    get_notrace_functions,
+)
 from .activations import *
-from .adaptive_avgmax_pool import \
-    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .adaptive_avgmax_pool import (
+    adaptive_avgmax_pool2d,
+    select_adaptive_pool2d,
+    AdaptiveAvgMaxPool2d,
+    SelectAdaptivePool2d,
+)
 from .attention import Attention, AttentionRope, maybe_add_mask
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d, create_aa
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
-    set_reentrant_ckpt, use_reentrant_ckpt
+from .config import (
+    is_exportable,
+    is_scriptable,
+    is_no_jit,
+    use_fused_attn,
+    set_exportable,
+    set_scriptable,
+    set_no_jit,
+    set_layer_config,
+    set_fused_attn,
+    set_reentrant_ckpt,
+    use_reentrant_ckpt,
+)
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
@@ -20,8 +44,17 @@
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
-from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
-    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .evo_norm import (
+    EvoNorm2dB0,
+    EvoNorm2dB1,
+    EvoNorm2dB2,
+    EvoNorm2dS0,
+    EvoNorm2dS0a,
+    EvoNorm2dS1,
+    EvoNorm2dS1a,
+    EvoNorm2dS2,
+    EvoNorm2dS2a,
+)
 from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
@@ -37,19 +70,50 @@
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
-    SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
+from .norm_act import (
+    BatchNormAct2d,
+    GroupNormAct,
+    GroupNorm1Act,
+    LayerNormAct,
+    LayerNormAct2d,
+    SyncBatchNormAct,
+    convert_sync_batchnorm,
+    FrozenBatchNormAct2d,
+    freeze_batch_norm_2d,
+    unfreeze_batch_norm_2d,
+)
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
 from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
-from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
-    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
-from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
-    build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
-    FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
+from .pos_embed_rel import (
+    RelPosMlp,
+    RelPosBias,
+    RelPosBiasTf,
+    gen_relative_position_index,
+    gen_relative_log_coords,
+    resize_rel_pos_bias_table,
+    resize_rel_pos_bias_table_simple,
+    resize_rel_pos_bias_table_levit,
+)
+from .pos_embed_sincos import (
+    pixel_freq_bands,
+    freq_bands,
+    build_sincos2d_pos_embed,
+    build_fourier_pos_embed,
+    build_rotary_pos_embed,
+    apply_rot_embed,
+    apply_rot_embed_cat,
+    apply_rot_embed_list,
+    apply_keep_indices_nlc,
+    FourierEmbed,
+    RotaryEmbedding,
+    RotaryEmbeddingCat,
+    RotaryEmbeddingMixed,
+    get_mixed_freqs,
+)
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
@@ -60,5 +124,11 @@
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType
-from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
-    init_weight_jax, init_weight_vit
+from .weight_init import (
+    trunc_normal_,
+    trunc_normal_tf_,
+    variance_scaling_,
+    lecun_normal_,
+    init_weight_jax,
+    init_weight_vit,
+)
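
Beyond the multi-line reformatting, the reworked imports surface the new FX helpers and two rotary-embedding symbols, `RotaryEmbeddingMixed` and `get_mixed_freqs`, at the package level. A quick sanity check of the new public surface, assuming a timm install that includes this commit:

```python
# Verify the new exports resolve from the timm.layers namespace
from timm.layers import (
    RotaryEmbeddingMixed,
    get_mixed_freqs,
    register_notrace_module,
    create_feature_extractor,
)

print(RotaryEmbeddingMixed, get_mixed_freqs)
```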

timm/layers/_fx.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+from typing import Callable, Dict, List, Optional, Union, Tuple, Type
+
+import torch
+from torch import nn
+
+try:
+    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
+    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+    from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
+    has_fx_feature_extraction = True
+except ImportError:
+    has_fx_feature_extraction = False
+
+
+__all__ = [
+    'register_notrace_module',
+    'is_notrace_module',
+    'get_notrace_modules',
+    'register_notrace_function',
+    'is_notrace_function',
+    'get_notrace_functions',
+    'create_feature_extractor',
+    'get_graph_node_names',
+]
+
+# modules to treat as leafs when tracing
+_leaf_modules = set()
+
+
+def register_notrace_module(module: Type[nn.Module]):
+    """
+    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
+    """
+    _leaf_modules.add(module)
+    return module
+
+
+def is_notrace_module(module: Type[nn.Module]):
+    return module in _leaf_modules
+
+
+def get_notrace_modules():
+    return list(_leaf_modules)
+
+
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
+
+
+def register_notrace_function(name_or_fn):
+    _autowrap_functions.add(name_or_fn)
+    return name_or_fn
+
+
+def is_notrace_function(func: Callable):
+    return func in _autowrap_functions
+
+
+def get_notrace_functions():
+    return list(_autowrap_functions)
+
+
+def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
+    return _get_graph_node_names(
+        model,
+        tracer_kwargs={
+            'leaf_modules': list(_leaf_modules),
+            'autowrap_functions': list(_autowrap_functions)
+        }
+    )
+
+
+def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
+    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+    return _create_feature_extractor(
+        model, return_nodes,
+        tracer_kwargs={
+            'leaf_modules': list(_leaf_modules),
+            'autowrap_functions': list(_autowrap_functions)
+        }
+    )
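
These helpers wrap torchvision's FX feature-extraction utilities so that modules and functions registered via the decorators are treated as leaves instead of being traced through. A minimal sketch of how the registry might be used, assuming a timm install with this commit; the `FusedGate` module and the node names passed to `return_nodes` are illustrative, not part of the commit:

```python
import torch
from torch import nn
from timm.layers import register_notrace_module, create_feature_extractor

@register_notrace_module  # keep this module as a single leaf node when tracing
class FusedGate(nn.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        return x * torch.sigmoid(x)

model = nn.Sequential(nn.Conv2d(3, 8, 3), FusedGate(), nn.Conv2d(8, 8, 3))
# return_nodes maps graph node names to output keys; names depend on the model
extractor = create_feature_extractor(model, return_nodes={'0': 'stem', '2': 'late'})
feats = extractor(torch.randn(1, 3, 32, 32))
print({k: v.shape for k, v in feats.items()})
```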

timm/layers/attention.py

Lines changed: 3 additions & 0 deletions
@@ -4,10 +4,13 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_function
 from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
+@torch.fx.wrap
+@register_notrace_function
 def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
     return scores if attn_mask is None else scores + attn_mask
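
`maybe_add_mask` is now both wrapped for `torch.fx` and registered as a no-trace function, so the optional-mask branch stays a single leaf call during FX tracing. A quick illustration of its behavior, assuming a timm install with this commit (shapes chosen arbitrarily):

```python
import torch
from timm.layers import maybe_add_mask

scores = torch.randn(2, 4, 16, 16)  # (batch, heads, q_len, k_len) attention logits
# causal mask: -inf above the diagonal, broadcast over batch and heads
mask = torch.zeros(16, 16).masked_fill(
    torch.triu(torch.ones(16, 16, dtype=torch.bool), diagonal=1), float('-inf'))

out = maybe_add_mask(scores, mask)   # mask broadcast and added to scores
same = maybe_add_mask(scores)        # attn_mask=None -> scores returned unchanged
assert torch.equal(same, scores)
```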

timm/layers/attention_pool2d.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 import torch
 import torch.nn as nn
 
-from. config import use_fused_attn
+from .config import use_fused_attn
 from .helpers import to_2tuple
 from .pos_embed import resample_abs_pos_embed
 from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding

timm/layers/cond_conv2d.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_module
 from .helpers import to_2tuple
 from .conv2d_same import conv2d_same
 from .padding import get_padding_value
@@ -30,6 +31,7 @@ def condconv_initializer(weight):
     return condconv_initializer
 
 
+@register_notrace_module
 class CondConv2d(nn.Module):
     """ Conditionally Parameterized Convolution
     Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

timm/layers/conv2d_same.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 from typing import Tuple, Optional
 
+from ._fx import register_notrace_module
 from .config import is_exportable, is_scriptable
 from .padding import pad_same, pad_same_arg, get_padding_value
 
@@ -27,6 +28,7 @@ def conv2d_same(
     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
 
 
+@register_notrace_module
 class Conv2dSame(nn.Conv2d):
     """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
     """

timm/layers/inplace_abn.py

Lines changed: 3 additions & 0 deletions
@@ -15,7 +15,10 @@ def inplace_abn(x, weight, bias, running_mean, running_var,
 def inplace_abn_sync(**kwargs):
     inplace_abn(**kwargs)
 
+from ._fx import register_notrace_module
 
+
+@register_notrace_module
 class InplaceAbn(nn.Module):
     """Activated Batch Normalization
 
timm/layers/non_local_attn.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from torch import nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_module
 from .conv_bn_act import ConvNormAct
 from .helpers import make_divisible
 from .trace_utils import _assert
@@ -69,6 +70,7 @@ def reset_parameters(self):
                 nn.init.constant_(m.bias, 0)
 
 
+@register_notrace_module
 class BilinearAttnTransform(nn.Module):
 
     def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):

timm/layers/norm_act.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 from torchvision.ops.misc import FrozenBatchNorm2d
 
+from ._fx import register_notrace_module
 from .create_act import create_act_layer
 from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d
 from .norm import RmsNorm, RmsNorm2d
@@ -39,6 +40,7 @@ def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
     return nn.Identity() if act is None else act
 
 
+@register_notrace_module
 class BatchNormAct2d(nn.BatchNorm2d):
     """BatchNorm + Activation
 
@@ -134,6 +136,7 @@ def forward(self, x):
         return x
 
 
+@register_notrace_module
 class SyncBatchNormAct(nn.SyncBatchNorm):
     # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
     # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
@@ -191,6 +194,7 @@ def convert_sync_batchnorm(module, process_group=None):
     return module_output
 
 
+@register_notrace_module
 class FrozenBatchNormAct2d(torch.nn.Module):
     """
     BatchNormAct2d where the batch statistics and the affine parameters are fixed
