26
26
year={2025}
27
27
}
28
28
29
+ @inproceedings{heo2024rotary,
30
+ title={Rotary position embedding for vision transformer},
31
+ author={Heo, Byeongho and Park, Song and Han, Dongyoon and Yun, Sangdoo},
32
+ booktitle={European Conference on Computer Vision},
33
+ pages={289--305},
34
+ year={2024},
35
+ organization={Springer}
36
+ }
37
+
29
38
This file contains a number of ViT variants the utilise ROPE position embeddings, SwiGLU and other additions:
30
39
* EVA & EVA02 model implementations that evolved from BEiT, additional models in vision_transformer.py.
31
40
* `timm` original SBB ViT w/ ROPE position embeddings
32
41
* Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)
42
+ * ROPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298)
33
43
34
44
Modifications by / Copyright 2023 Ross Wightman, original copyrights below
35
45
"""
@@ -773,7 +783,7 @@ def forward_intermediates(
773
783
else :
774
784
blocks = self .blocks [:max_index + 1 ]
775
785
# Handle depth-dependent embeddings for mixed mode
776
- if self . rope_mixed and rot_pos_embed is not None :
786
+ if getattr ( self , ' rope_mixed' , False ) and rot_pos_embed is not None :
777
787
for i , blk in enumerate (blocks ):
778
788
if self .grad_checkpointing and not torch .jit .is_scripting ():
779
789
x = checkpoint (blk , x , rope = rot_pos_embed [i ])
@@ -850,7 +860,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
850
860
x = self .norm_pre (x )
851
861
852
862
# Handle depth-dependent embeddings for mixed mode
853
- if self . rope_mixed and rot_pos_embed is not None :
863
+ if getattr ( self , ' rope_mixed' , False ) and rot_pos_embed is not None :
854
864
# rot_pos_embed has shape (depth, H*W, dim) for mixed mode
855
865
for i , blk in enumerate (self .blocks ):
856
866
if self .grad_checkpointing and not torch .jit .is_scripting ():
@@ -991,23 +1001,6 @@ def checkpoint_filter_fn(
991
1001
state_dict = state_dict .get ('module' , state_dict )
992
1002
state_dict = state_dict .get ('state_dict' , state_dict )
993
1003
994
- # FIXME remove after conversion, check if this is a rope-vit checkpoint
995
- if 'freqs' in state_dict :
996
- # Handle rope-vit specific conversions
997
- for k , v in state_dict .items ():
998
- # Skip rope-vit specific buffers
999
- if any ([kk in k for kk in ('freqs_t_x' , 'freqs_t_y' )]):
1000
- continue
1001
- # Handle mixed mode frequency parameters
1002
- if k == 'freqs' :
1003
- # Check if model uses mixed mode by looking at other keys or freqs shape
1004
- # Mixed mode has learnable freqs, axial mode doesn't use them
1005
- k = 'rope.freqs'
1006
- model_shape = model .state_dict ().get (k ).shape
1007
- v = v .reshape (model_shape )
1008
- out_dict [k ] = v
1009
- return out_dict
1010
-
1011
1004
# Loading Meta PE (Perception Encoder) weights
1012
1005
if 'visual.conv1.weight' in state_dict :
1013
1006
return _convert_pe (state_dict , model )
@@ -1031,7 +1024,7 @@ def checkpoint_filter_fn(
1031
1024
continue
1032
1025
k = k [len_prefix :]
1033
1026
1034
- if 'rope' in k :
1027
+ if 'rope' in k and not k == 'rope.freqs' :
1035
1028
# fixed embedding no need to load buffer from checkpoint
1036
1029
continue
1037
1030
@@ -1375,76 +1368,76 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
1375
1368
1376
1369
# RoPE-ViT models from Naver
1377
1370
'vit_small_patch16_rope_224.naver_in1k' : _cfg (
1378
- hf_hub_id = 'naver-ai/rope_axial_deit_small_patch16_LS' ,
1379
- hf_hub_filename = 'pytorch_model.bin' ,
1371
+ hf_hub_id = 'timm/' ,
1380
1372
mean = IMAGENET_DEFAULT_MEAN ,
1381
1373
std = IMAGENET_DEFAULT_STD ,
1374
+ license = 'apache-2.0' ,
1382
1375
),
1383
1376
'vit_base_patch16_rope_224.naver_in1k' : _cfg (
1384
- hf_hub_id = 'naver-ai/rope_axial_deit_base_patch16_LS' ,
1385
- hf_hub_filename = 'pytorch_model.bin' ,
1377
+ hf_hub_id = 'timm/' ,
1386
1378
mean = IMAGENET_DEFAULT_MEAN ,
1387
1379
std = IMAGENET_DEFAULT_STD ,
1380
+ license = 'apache-2.0' ,
1388
1381
),
1389
1382
'vit_large_patch16_rope_224.naver_in1k' : _cfg (
1390
- hf_hub_id = 'naver-ai/rope_axial_deit_large_patch16_LS' ,
1391
- hf_hub_filename = 'pytorch_model.bin' ,
1383
+ hf_hub_id = 'timm/' ,
1392
1384
mean = IMAGENET_DEFAULT_MEAN ,
1393
1385
std = IMAGENET_DEFAULT_STD ,
1386
+ license = 'apache-2.0' ,
1394
1387
),
1395
- 'vit_small_patch16_mrope_224.naver_in1k' : _cfg (
1396
- hf_hub_id = 'naver-ai/rope_mixed_deit_small_patch16_LS' ,
1397
- hf_hub_filename = 'pytorch_model.bin' ,
1388
+ 'vit_small_patch16_rope_mixed_224.naver_in1k' : _cfg (
1389
+ hf_hub_id = 'timm/' ,
1398
1390
mean = IMAGENET_DEFAULT_MEAN ,
1399
1391
std = IMAGENET_DEFAULT_STD ,
1392
+ license = 'apache-2.0' ,
1400
1393
),
1401
- 'vit_base_patch16_mrope_224.naver_in1k' : _cfg (
1402
- hf_hub_id = 'naver-ai/rope_mixed_deit_base_patch16_LS' ,
1403
- hf_hub_filename = 'pytorch_model.bin' ,
1394
+ 'vit_base_patch16_rope_mixed_224.naver_in1k' : _cfg (
1395
+ hf_hub_id = 'timm/' ,
1404
1396
mean = IMAGENET_DEFAULT_MEAN ,
1405
1397
std = IMAGENET_DEFAULT_STD ,
1398
+ license = 'apache-2.0' ,
1406
1399
),
1407
- 'vit_large_patch16_mrope_224.naver_in1k' : _cfg (
1408
- hf_hub_id = 'naver-ai/rope_mixed_deit_large_patch16_LS' ,
1409
- hf_hub_filename = 'pytorch_model.bin' ,
1400
+ 'vit_large_patch16_rope_mixed_224.naver_in1k' : _cfg (
1401
+ hf_hub_id = 'timm/' ,
1410
1402
mean = IMAGENET_DEFAULT_MEAN ,
1411
1403
std = IMAGENET_DEFAULT_STD ,
1404
+ license = 'apache-2.0' ,
1412
1405
),
1413
1406
'vit_small_patch16_rope_ape_224.naver_in1k' : _cfg (
1414
- hf_hub_id = 'naver-ai/rope_axial_ape_deit_small_patch16_LS' ,
1415
- hf_hub_filename = 'pytorch_model.bin' ,
1407
+ hf_hub_id = 'timm/' ,
1416
1408
mean = IMAGENET_DEFAULT_MEAN ,
1417
1409
std = IMAGENET_DEFAULT_STD ,
1410
+ license = 'apache-2.0' ,
1418
1411
),
1419
1412
'vit_base_patch16_rope_ape_224.naver_in1k' : _cfg (
1420
- hf_hub_id = 'naver-ai/rope_axial_ape_deit_base_patch16_LS' ,
1421
- hf_hub_filename = 'pytorch_model.bin' ,
1413
+ hf_hub_id = 'timm/' ,
1422
1414
mean = IMAGENET_DEFAULT_MEAN ,
1423
1415
std = IMAGENET_DEFAULT_STD ,
1416
+ license = 'apache-2.0' ,
1424
1417
),
1425
1418
'vit_large_patch16_rope_ape_224.naver_in1k' : _cfg (
1426
- hf_hub_id = 'naver-ai/rope_axial_ape_deit_large_patch16_LS' ,
1427
- hf_hub_filename = 'pytorch_model.bin' ,
1419
+ hf_hub_id = 'timm/' ,
1428
1420
mean = IMAGENET_DEFAULT_MEAN ,
1429
1421
std = IMAGENET_DEFAULT_STD ,
1422
+ license = 'apache-2.0' ,
1430
1423
),
1431
- 'vit_small_patch16_mrope_ape_224.naver_in1k' : _cfg (
1432
- hf_hub_id = 'naver-ai/rope_mixed_ape_deit_small_patch16_LS' ,
1433
- hf_hub_filename = 'pytorch_model.bin' ,
1424
+ 'vit_small_patch16_rope_mixed_ape_224.naver_in1k' : _cfg (
1425
+ hf_hub_id = 'timm/' ,
1434
1426
mean = IMAGENET_DEFAULT_MEAN ,
1435
1427
std = IMAGENET_DEFAULT_STD ,
1428
+ license = 'apache-2.0' ,
1436
1429
),
1437
- 'vit_base_patch16_mrope_ape_224.naver_in1k' : _cfg (
1438
- hf_hub_id = 'naver-ai/rope_mixed_ape_deit_base_patch16_LS' ,
1439
- hf_hub_filename = 'pytorch_model.bin' ,
1430
+ 'vit_base_patch16_rope_mixed_ape_224.naver_in1k' : _cfg (
1431
+ hf_hub_id = 'timm/' ,
1440
1432
mean = IMAGENET_DEFAULT_MEAN ,
1441
1433
std = IMAGENET_DEFAULT_STD ,
1434
+ license = 'apache-2.0' ,
1442
1435
),
1443
- 'vit_large_patch16_mrope_ape_224.naver_in1k' : _cfg (
1444
- hf_hub_id = 'naver-ai/rope_mixed_ape_deit_large_patch16_LS' ,
1445
- hf_hub_filename = 'pytorch_model.bin' ,
1436
+ 'vit_large_patch16_rope_mixed_ape_224.naver_in1k' : _cfg (
1437
+ hf_hub_id = 'timm/' ,
1446
1438
mean = IMAGENET_DEFAULT_MEAN ,
1447
1439
std = IMAGENET_DEFAULT_STD ,
1440
+ license = 'apache-2.0' ,
1448
1441
),
1449
1442
})
1450
1443
@@ -2023,7 +2016,7 @@ def vit_large_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
2023
2016
2024
2017
2025
2018
@register_model
2026
- def vit_small_patch16_mrope_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2019
+ def vit_small_patch16_rope_mixed_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2027
2020
"""RoPE-Mixed ViT-S/16 from https://github.com/naver-ai/rope-vit"""
2028
2021
model_args = dict (
2029
2022
patch_size = 16 ,
@@ -2042,12 +2035,12 @@ def vit_small_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
2042
2035
rope_temperature = 10.0 ,
2043
2036
rope_mixed_mode = True ,
2044
2037
)
2045
- model = _create_eva ('vit_small_patch16_mrope_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2038
+ model = _create_eva ('vit_small_patch16_rope_mixed_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2046
2039
return model
2047
2040
2048
2041
2049
2042
@register_model
2050
- def vit_base_patch16_mrope_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2043
+ def vit_base_patch16_rope_mixed_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2051
2044
"""RoPE-Mixed ViT-B/16 from https://github.com/naver-ai/rope-vit"""
2052
2045
model_args = dict (
2053
2046
patch_size = 16 ,
@@ -2066,12 +2059,12 @@ def vit_base_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
2066
2059
rope_temperature = 10.0 ,
2067
2060
rope_mixed_mode = True ,
2068
2061
)
2069
- model = _create_eva ('vit_base_patch16_mrope_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2062
+ model = _create_eva ('vit_base_patch16_rope_mixed_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2070
2063
return model
2071
2064
2072
2065
2073
2066
@register_model
2074
- def vit_large_patch16_mrope_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2067
+ def vit_large_patch16_rope_mixed_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2075
2068
"""RoPE-Mixed ViT-L/16 from https://github.com/naver-ai/rope-vit"""
2076
2069
model_args = dict (
2077
2070
patch_size = 16 ,
@@ -2090,7 +2083,7 @@ def vit_large_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
2090
2083
rope_temperature = 10.0 ,
2091
2084
rope_mixed_mode = True ,
2092
2085
)
2093
- model = _create_eva ('vit_large_patch16_mrope_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2086
+ model = _create_eva ('vit_large_patch16_rope_mixed_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2094
2087
return model
2095
2088
2096
2089
@@ -2170,7 +2163,7 @@ def vit_large_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
2170
2163
2171
2164
2172
2165
@register_model
2173
- def vit_small_patch16_mrope_ape_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2166
+ def vit_small_patch16_rope_mixed_ape_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2174
2167
"""RoPE-Mixed + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
2175
2168
model_args = dict (
2176
2169
patch_size = 16 ,
@@ -2191,12 +2184,12 @@ def vit_small_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
2191
2184
rope_mixed_mode = True ,
2192
2185
)
2193
2186
2194
- model = _create_eva ('vit_small_patch16_mrope_ape_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2187
+ model = _create_eva ('vit_small_patch16_rope_mixed_ape_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2195
2188
return model
2196
2189
2197
2190
2198
2191
@register_model
2199
- def vit_base_patch16_mrope_ape_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2192
+ def vit_base_patch16_rope_mixed_ape_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2200
2193
"""RoPE-Mixed + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
2201
2194
model_args = dict (
2202
2195
patch_size = 16 ,
@@ -2216,12 +2209,12 @@ def vit_base_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
2216
2209
rope_temperature = 10.0 ,
2217
2210
rope_mixed_mode = True ,
2218
2211
)
2219
- model = _create_eva ('vit_base_patch16_mrope_ape_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2212
+ model = _create_eva ('vit_base_patch16_rope_mixed_ape_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2220
2213
return model
2221
2214
2222
2215
2223
2216
@register_model
2224
- def vit_large_patch16_mrope_ape_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2217
+ def vit_large_patch16_rope_mixed_ape_224 (pretrained : bool = False , ** kwargs ) -> Eva :
2225
2218
"""RoPE-Mixed + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
2226
2219
model_args = dict (
2227
2220
patch_size = 16 ,
@@ -2241,6 +2234,6 @@ def vit_large_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
2241
2234
rope_temperature = 10.0 ,
2242
2235
rope_mixed_mode = True ,
2243
2236
)
2244
- model = _create_eva ('vit_large_patch16_mrope_ape_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2237
+ model = _create_eva ('vit_large_patch16_rope_mixed_ape_224 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2245
2238
return model
2246
2239
0 commit comments