Skip to content

Commit cec7290

Browse files
committed
Final rope-vit update, mrope back to rope_mixed for clarity, upload weights to hub, add attribution
1 parent 9393921 commit cec7290

File tree

3 files changed

+59
-62
lines changed

3 files changed

+59
-62
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ All model architecture families include variants with pretrained weights. There
508508
* Res2Net - https://arxiv.org/abs/1904.01169
509509
* ResNeSt - https://arxiv.org/abs/2004.08955
510510
* ReXNet - https://arxiv.org/abs/2007.00992
511+
* ROPE-ViT - https://arxiv.org/abs/2403.13298
511512
* SelecSLS - https://arxiv.org/abs/1907.00837
512513
* Selective Kernel Networks - https://arxiv.org/abs/1903.06586
513514
* Sequencer2D - https://arxiv.org/abs/2205.01972

timm/layers/pos_embed_sincos.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,9 @@ class RotaryEmbeddingMixed(nn.Module):
555555
556556
This implementation supports mixed (learnable) ROPE. In mixed mode,
557557
each transformer block has its own set of learnable frequency parameters.
558+
559+
Based on 'Rotary Position Embedding for Vision Transformer' (https://arxiv.org/abs/2403.13298)
560+
Compatible with original at https://github.com/naver-ai/rope-vit
558561
"""
559562
def __init__(
560563
self,

timm/models/eva.py

Lines changed: 55 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,20 @@
2626
year={2025}
2727
}
2828
29+
@inproceedings{heo2024rotary,
30+
title={Rotary position embedding for vision transformer},
31+
author={Heo, Byeongho and Park, Song and Han, Dongyoon and Yun, Sangdoo},
32+
booktitle={European Conference on Computer Vision},
33+
pages={289--305},
34+
year={2024},
35+
organization={Springer}
36+
}
37+
2938
This file contains a number of ViT variants that utilise ROPE position embeddings, SwiGLU and other additions:
3039
* EVA & EVA02 model implementations that evolved from BEiT, additional models in vision_transformer.py.
3140
* `timm` original SBB ViT w/ ROPE position embeddings
3241
* Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)
42+
* ROPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298)
3343
3444
Modifications by / Copyright 2023 Ross Wightman, original copyrights below
3545
"""
@@ -773,7 +783,7 @@ def forward_intermediates(
773783
else:
774784
blocks = self.blocks[:max_index + 1]
775785
# Handle depth-dependent embeddings for mixed mode
776-
if self.rope_mixed and rot_pos_embed is not None:
786+
if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None:
777787
for i, blk in enumerate(blocks):
778788
if self.grad_checkpointing and not torch.jit.is_scripting():
779789
x = checkpoint(blk, x, rope=rot_pos_embed[i])
@@ -850,7 +860,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
850860
x = self.norm_pre(x)
851861

852862
# Handle depth-dependent embeddings for mixed mode
853-
if self.rope_mixed and rot_pos_embed is not None:
863+
if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None:
854864
# rot_pos_embed has shape (depth, H*W, dim) for mixed mode
855865
for i, blk in enumerate(self.blocks):
856866
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -991,23 +1001,6 @@ def checkpoint_filter_fn(
9911001
state_dict = state_dict.get('module', state_dict)
9921002
state_dict = state_dict.get('state_dict', state_dict)
9931003

994-
# FIXME remove after conversion, check if this is a rope-vit checkpoint
995-
if 'freqs' in state_dict:
996-
# Handle rope-vit specific conversions
997-
for k, v in state_dict.items():
998-
# Skip rope-vit specific buffers
999-
if any([kk in k for kk in ('freqs_t_x', 'freqs_t_y')]):
1000-
continue
1001-
# Handle mixed mode frequency parameters
1002-
if k == 'freqs':
1003-
# Check if model uses mixed mode by looking at other keys or freqs shape
1004-
# Mixed mode has learnable freqs, axial mode doesn't use them
1005-
k = 'rope.freqs'
1006-
model_shape = model.state_dict().get(k).shape
1007-
v = v.reshape(model_shape)
1008-
out_dict[k] = v
1009-
return out_dict
1010-
10111004
# Loading Meta PE (Perception Encoder) weights
10121005
if 'visual.conv1.weight' in state_dict:
10131006
return _convert_pe(state_dict, model)
@@ -1031,7 +1024,7 @@ def checkpoint_filter_fn(
10311024
continue
10321025
k = k[len_prefix:]
10331026

1034-
if 'rope' in k:
1027+
if 'rope' in k and not k == 'rope.freqs':
10351028
# fixed embedding no need to load buffer from checkpoint
10361029
continue
10371030

@@ -1375,76 +1368,76 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
13751368

13761369
# RoPE-ViT models from Naver
13771370
'vit_small_patch16_rope_224.naver_in1k': _cfg(
1378-
hf_hub_id='naver-ai/rope_axial_deit_small_patch16_LS',
1379-
hf_hub_filename='pytorch_model.bin',
1371+
hf_hub_id='timm/',
13801372
mean=IMAGENET_DEFAULT_MEAN,
13811373
std=IMAGENET_DEFAULT_STD,
1374+
license='apache-2.0',
13821375
),
13831376
'vit_base_patch16_rope_224.naver_in1k': _cfg(
1384-
hf_hub_id='naver-ai/rope_axial_deit_base_patch16_LS',
1385-
hf_hub_filename='pytorch_model.bin',
1377+
hf_hub_id='timm/',
13861378
mean=IMAGENET_DEFAULT_MEAN,
13871379
std=IMAGENET_DEFAULT_STD,
1380+
license='apache-2.0',
13881381
),
13891382
'vit_large_patch16_rope_224.naver_in1k': _cfg(
1390-
hf_hub_id='naver-ai/rope_axial_deit_large_patch16_LS',
1391-
hf_hub_filename='pytorch_model.bin',
1383+
hf_hub_id='timm/',
13921384
mean=IMAGENET_DEFAULT_MEAN,
13931385
std=IMAGENET_DEFAULT_STD,
1386+
license='apache-2.0',
13941387
),
1395-
'vit_small_patch16_mrope_224.naver_in1k': _cfg(
1396-
hf_hub_id='naver-ai/rope_mixed_deit_small_patch16_LS',
1397-
hf_hub_filename='pytorch_model.bin',
1388+
'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg(
1389+
hf_hub_id='timm/',
13981390
mean=IMAGENET_DEFAULT_MEAN,
13991391
std=IMAGENET_DEFAULT_STD,
1392+
license='apache-2.0',
14001393
),
1401-
'vit_base_patch16_mrope_224.naver_in1k': _cfg(
1402-
hf_hub_id='naver-ai/rope_mixed_deit_base_patch16_LS',
1403-
hf_hub_filename='pytorch_model.bin',
1394+
'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg(
1395+
hf_hub_id='timm/',
14041396
mean=IMAGENET_DEFAULT_MEAN,
14051397
std=IMAGENET_DEFAULT_STD,
1398+
license='apache-2.0',
14061399
),
1407-
'vit_large_patch16_mrope_224.naver_in1k': _cfg(
1408-
hf_hub_id='naver-ai/rope_mixed_deit_large_patch16_LS',
1409-
hf_hub_filename='pytorch_model.bin',
1400+
'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg(
1401+
hf_hub_id='timm/',
14101402
mean=IMAGENET_DEFAULT_MEAN,
14111403
std=IMAGENET_DEFAULT_STD,
1404+
license='apache-2.0',
14121405
),
14131406
'vit_small_patch16_rope_ape_224.naver_in1k': _cfg(
1414-
hf_hub_id='naver-ai/rope_axial_ape_deit_small_patch16_LS',
1415-
hf_hub_filename='pytorch_model.bin',
1407+
hf_hub_id='timm/',
14161408
mean=IMAGENET_DEFAULT_MEAN,
14171409
std=IMAGENET_DEFAULT_STD,
1410+
license='apache-2.0',
14181411
),
14191412
'vit_base_patch16_rope_ape_224.naver_in1k': _cfg(
1420-
hf_hub_id='naver-ai/rope_axial_ape_deit_base_patch16_LS',
1421-
hf_hub_filename='pytorch_model.bin',
1413+
hf_hub_id='timm/',
14221414
mean=IMAGENET_DEFAULT_MEAN,
14231415
std=IMAGENET_DEFAULT_STD,
1416+
license='apache-2.0',
14241417
),
14251418
'vit_large_patch16_rope_ape_224.naver_in1k': _cfg(
1426-
hf_hub_id='naver-ai/rope_axial_ape_deit_large_patch16_LS',
1427-
hf_hub_filename='pytorch_model.bin',
1419+
hf_hub_id='timm/',
14281420
mean=IMAGENET_DEFAULT_MEAN,
14291421
std=IMAGENET_DEFAULT_STD,
1422+
license='apache-2.0',
14301423
),
1431-
'vit_small_patch16_mrope_ape_224.naver_in1k': _cfg(
1432-
hf_hub_id='naver-ai/rope_mixed_ape_deit_small_patch16_LS',
1433-
hf_hub_filename='pytorch_model.bin',
1424+
'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
1425+
hf_hub_id='timm/',
14341426
mean=IMAGENET_DEFAULT_MEAN,
14351427
std=IMAGENET_DEFAULT_STD,
1428+
license='apache-2.0',
14361429
),
1437-
'vit_base_patch16_mrope_ape_224.naver_in1k': _cfg(
1438-
hf_hub_id='naver-ai/rope_mixed_ape_deit_base_patch16_LS',
1439-
hf_hub_filename='pytorch_model.bin',
1430+
'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
1431+
hf_hub_id='timm/',
14401432
mean=IMAGENET_DEFAULT_MEAN,
14411433
std=IMAGENET_DEFAULT_STD,
1434+
license='apache-2.0',
14421435
),
1443-
'vit_large_patch16_mrope_ape_224.naver_in1k': _cfg(
1444-
hf_hub_id='naver-ai/rope_mixed_ape_deit_large_patch16_LS',
1445-
hf_hub_filename='pytorch_model.bin',
1436+
'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
1437+
hf_hub_id='timm/',
14461438
mean=IMAGENET_DEFAULT_MEAN,
14471439
std=IMAGENET_DEFAULT_STD,
1440+
license='apache-2.0',
14481441
),
14491442
})
14501443

@@ -2023,7 +2016,7 @@ def vit_large_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
20232016

20242017

20252018
@register_model
2026-
def vit_small_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
2019+
def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
20272020
"""RoPE-Mixed ViT-S/16 from https://github.com/naver-ai/rope-vit"""
20282021
model_args = dict(
20292022
patch_size=16,
@@ -2042,12 +2035,12 @@ def vit_small_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
20422035
rope_temperature=10.0,
20432036
rope_mixed_mode=True,
20442037
)
2045-
model = _create_eva('vit_small_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
2038+
model = _create_eva('vit_small_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
20462039
return model
20472040

20482041

20492042
@register_model
2050-
def vit_base_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
2043+
def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
20512044
"""RoPE-Mixed ViT-B/16 from https://github.com/naver-ai/rope-vit"""
20522045
model_args = dict(
20532046
patch_size=16,
@@ -2066,12 +2059,12 @@ def vit_base_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
20662059
rope_temperature=10.0,
20672060
rope_mixed_mode=True,
20682061
)
2069-
model = _create_eva('vit_base_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
2062+
model = _create_eva('vit_base_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
20702063
return model
20712064

20722065

20732066
@register_model
2074-
def vit_large_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
2067+
def vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
20752068
"""RoPE-Mixed ViT-L/16 from https://github.com/naver-ai/rope-vit"""
20762069
model_args = dict(
20772070
patch_size=16,
@@ -2090,7 +2083,7 @@ def vit_large_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
20902083
rope_temperature=10.0,
20912084
rope_mixed_mode=True,
20922085
)
2093-
model = _create_eva('vit_large_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
2086+
model = _create_eva('vit_large_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
20942087
return model
20952088

20962089

@@ -2170,7 +2163,7 @@ def vit_large_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
21702163

21712164

21722165
@register_model
2173-
def vit_small_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
2166+
def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
21742167
"""RoPE-Mixed + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
21752168
model_args = dict(
21762169
patch_size=16,
@@ -2191,12 +2184,12 @@ def vit_small_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
21912184
rope_mixed_mode=True,
21922185
)
21932186

2194-
model = _create_eva('vit_small_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
2187+
model = _create_eva('vit_small_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
21952188
return model
21962189

21972190

21982191
@register_model
2199-
def vit_base_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
2192+
def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
22002193
"""RoPE-Mixed + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
22012194
model_args = dict(
22022195
patch_size=16,
@@ -2216,12 +2209,12 @@ def vit_base_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
22162209
rope_temperature=10.0,
22172210
rope_mixed_mode=True,
22182211
)
2219-
model = _create_eva('vit_base_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
2212+
model = _create_eva('vit_base_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
22202213
return model
22212214

22222215

22232216
@register_model
2224-
def vit_large_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
2217+
def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
22252218
"""RoPE-Mixed + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
22262219
model_args = dict(
22272220
patch_size=16,
@@ -2241,6 +2234,6 @@ def vit_large_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
22412234
rope_temperature=10.0,
22422235
rope_mixed_mode=True,
22432236
)
2244-
model = _create_eva('vit_large_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
2237+
model = _create_eva('vit_large_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
22452238
return model
22462239

0 commit comments

Comments
 (0)