
Commit 9393921

Cache t_x/t_y for rope mixed, use more like other rope embeds, rename rope_mixed -> mrope in model names
1 parent: b9a336f

File tree (2 files changed: +67 −36 lines)

  timm/layers/pos_embed_sincos.py
  timm/models/eva.py


timm/layers/pos_embed_sincos.py

Lines changed: 48 additions & 17 deletions
@@ -365,9 +365,8 @@ def __init__(
         )
 
     def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None:
+        if shape is not None and self.bands is not None:
             # rebuild embeddings every call, use if target shape changes
-            assert shape is not None
             return build_rotary_pos_embed(
                 shape,
                 self.bands,
@@ -376,8 +375,10 @@ def get_embed(self, shape: Optional[List[int]] = None):
                 grid_offset=self.grid_offset,
                 grid_indexing=self.grid_indexing,
             )
-        else:
+        elif self.pos_embed_sin is not None and self.pos_embed_cos is not None:
             return self.pos_embed_sin, self.pos_embed_cos
+        else:
+            assert False, "get_embed() requires pre-computed pos embeds or valid shape w/ pre-computed bands"
 
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
@@ -456,7 +457,7 @@ def __init__(
         )
 
     def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None and shape is not None:
+        if shape is not None and self.bands is not None:
             # rebuild embeddings every call, use if target shape changes
             embeds = build_rotary_pos_embed(
                 shape,
@@ -470,7 +471,7 @@ def get_embed(self, shape: Optional[List[int]] = None):
         elif self.pos_embed is not None:
             return self.pos_embed
         else:
-            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
+            assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"
 
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
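Note: both RotaryEmbedding and RotaryEmbeddingCat now resolve get_embed() with the same precedence — an explicit shape plus stored bands triggers a rebuild, cached sin/cos tensors are the fallback, and anything else asserts. A minimal standalone sketch of that order; toy_axial_embed is a hypothetical stand-in for timm's build_rotary_pos_embed, not its real signature:

from typing import List, Optional, Tuple

import torch


def toy_axial_embed(shape: List[int], bands: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # toy stand-in: angles from flattened 2D positions projected through bands
    y, x = torch.meshgrid(
        torch.arange(shape[0], dtype=torch.float32),
        torch.arange(shape[1], dtype=torch.float32),
        indexing='ij',
    )
    pos = torch.stack([y.flatten(), x.flatten()], dim=-1)  # (N, 2)
    ang = pos @ bands                                      # (N, num_bands), bands: (2, num_bands)
    return ang.sin(), ang.cos()


def get_embed(
        shape: Optional[List[int]] = None,
        bands: Optional[torch.Tensor] = None,
        cached_sin: Optional[torch.Tensor] = None,
        cached_cos: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if shape is not None and bands is not None:
        # dynamic path: rebuild embeds for the requested feature shape
        return toy_axial_embed(shape, bands)
    elif cached_sin is not None and cached_cos is not None:
        # static path: embeds pre-computed at init for a fixed feat_shape
        return cached_sin, cached_cos
    else:
        raise AssertionError("requires pre-computed embeds or valid shape w/ bands")


sin, cos = get_embed(shape=[14, 14], bands=torch.randn(2, 8))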
@@ -514,31 +515,39 @@ def init_random_2d_freqs(
 
 @torch.fx.wrap
 @register_notrace_function
-def get_mixed_freqs(
-    freqs: torch.Tensor,
+def get_mixed_grid(
     height: int,
     width: int,
     grid_indexing: str = 'ij',
-):
-    """Compute mixed (learnable) frequencies."""
-    # Create position indices
-    device = freqs.device
-    dtype = freqs.dtype
+    device: Optional[torch.device] = None,
+    dtype: torch.dtype = torch.float32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     x_pos, y_pos = torch.meshgrid(
         torch.arange(height, dtype=dtype, device=device),
         torch.arange(width, dtype=dtype, device=device),
         indexing=grid_indexing,
     )
     t_x = x_pos.flatten()
     t_y = y_pos.flatten()
+    return t_x, t_y
+
+
+def get_mixed_freqs(
+    freqs: torch.Tensor,
+    t_x: torch.Tensor,
+    t_y: torch.Tensor,
+) -> torch.Tensor:
+    """Compute mixed (learnable) frequencies."""
+    dtype = freqs.dtype
+    freqs = freqs.float()
     freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2))
     freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2))
     combined = freqs_x + freqs_y  # shape: (num_heads, N, dim//4)
     sin_emb = torch.sin(combined).repeat_interleave(2, -1)  # (N, dim//2)
     cos_emb = torch.cos(combined).repeat_interleave(2, -1)  # (N, dim//2)
     rope_embeds = torch.cat([sin_emb, cos_emb], dim=-1)  # (num_heads, H*W, head_dim)
-
-    return rope_embeds
+    return rope_embeds.to(dtype)
 
 
 class RotaryEmbeddingMixed(nn.Module):
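Note: the split means get_mixed_grid() produces the flattened t_x/t_y coordinates (cacheable per feature shape) while get_mixed_freqs() only does the frequency projection, upcasting the learned freqs to float32 for the sin/cos math and casting back to the parameter dtype at the end. A rough standalone trace of the two steps, with shapes chosen purely for illustration:

import torch

H, W, depth, num_heads, head_dim = 14, 14, 12, 6, 64

# step 1 (get_mixed_grid): flattened positions, computed once per feature shape
x_pos, y_pos = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing='ij',
)
t_x, t_y = x_pos.flatten(), y_pos.flatten()  # each (H*W,)

# learnable per-axis frequencies, kept low precision as a parameter might be
freqs = torch.randn(2, depth, num_heads, head_dim // 4, dtype=torch.bfloat16)

# step 2 (get_mixed_freqs): project positions through freqs in float32
dtype = freqs.dtype
f = freqs.float()
freqs_x = t_x.unsqueeze(-1) @ f[0].unsqueeze(-2)   # (depth, num_heads, H*W, head_dim//4)
freqs_y = t_y.unsqueeze(-1) @ f[1].unsqueeze(-2)
combined = freqs_x + freqs_y
sin_emb = combined.sin().repeat_interleave(2, -1)  # (..., head_dim//2)
cos_emb = combined.cos().repeat_interleave(2, -1)
rope_embeds = torch.cat([sin_emb, cos_emb], dim=-1).to(dtype)  # (..., head_dim)
print(rope_embeds.shape)  # torch.Size([12, 6, 196, 64])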
@@ -584,6 +593,18 @@ def __init__(
             rotate=True,
         )  # (2, depth, num_heads, head_dim//2)
         self.freqs = nn.Parameter(freqs)
+        if feat_shape is not None:
+            # cache pre-computed grid
+            t_x, t_y = get_mixed_grid(
+                feat_shape[0],
+                feat_shape[1],
+                grid_indexing=grid_indexing,
+                device=self.freqs.device
+            )
+            self.register_buffer('t_x', t_x, persistent=False)
+            self.register_buffer('t_y', t_y, persistent=False)
+        else:
+            self.t_x = self.t_y = None
 
     def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
         """Generate rotary embeddings for the given spatial shape.
@@ -594,9 +615,19 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
         Returns:
             Tensor of shape (depth, H*W, dim) containing concatenated sin/cos embeddings
         """
-        assert shape is not None, "shape must be provided"
-        H, W = shape
-        return get_mixed_freqs(self.freqs, height=H, width=W, grid_indexing=self.grid_indexing)
+        if shape is not None:
+            t_x, t_y = get_mixed_grid(
+                shape[0],
+                shape[1],
+                grid_indexing=self.grid_indexing,
+                device=self.freqs.device
+            )
+        elif self.t_x is not None and self.t_y is not None:
+            t_x, t_y = self.t_x, self.t_y
+        else:
+            assert False, "get_embed() requires pre-computed t_x/t_y or valid shape"
+
+        return get_mixed_freqs(self.freqs, t_x, t_y)
 
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
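Note: with feat_shape known at construction, RotaryEmbeddingMixed caches t_x/t_y as non-persistent buffers, so the fixed-resolution path can call get_embed() with no arguments. A condensed toy version of the cache-at-init, rebuild-on-demand pattern (not the full timm class):

from typing import List, Optional, Tuple

import torch
import torch.nn as nn


class CachedGridRope(nn.Module):
    """Toy cache-at-init / rebuild-on-demand rope module."""

    def __init__(self, dim: int, depth: int, feat_shape: Optional[List[int]] = None):
        super().__init__()
        self.freqs = nn.Parameter(torch.randn(2, depth, dim // 4) * 0.02)
        if feat_shape is not None:
            t_x, t_y = self._make_grid(feat_shape)
            # non-persistent: moves with .to(device) but stays out of state_dict
            self.register_buffer('t_x', t_x, persistent=False)
            self.register_buffer('t_y', t_y, persistent=False)
        else:
            self.t_x = self.t_y = None

    @staticmethod
    def _make_grid(shape: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        y, x = torch.meshgrid(
            torch.arange(shape[0], dtype=torch.float32),
            torch.arange(shape[1], dtype=torch.float32),
            indexing='ij',
        )
        return x.flatten(), y.flatten()

    def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
        if shape is not None:
            t_x, t_y = self._make_grid(shape)  # dynamic shape: rebuild grid
        elif self.t_x is not None:
            t_x, t_y = self.t_x, self.t_y      # fixed shape: cached fast path
        else:
            raise AssertionError("need a cached grid or an explicit shape")
        ang = (t_x.unsqueeze(-1) @ self.freqs[0].unsqueeze(-2)
               + t_y.unsqueeze(-1) @ self.freqs[1].unsqueeze(-2))  # (depth, N, dim//4)
        sin_emb = ang.sin().repeat_interleave(2, -1)               # (depth, N, dim//2)
        cos_emb = ang.cos().repeat_interleave(2, -1)
        return torch.cat([sin_emb, cos_emb], dim=-1)               # (depth, N, dim)


rope = CachedGridRope(dim=64, depth=12, feat_shape=[14, 14])
print(rope.get_embed().shape)                # cached: torch.Size([12, 196, 64])
print(rope.get_embed(shape=[16, 16]).shape)  # rebuilt: torch.Size([12, 256, 64])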

timm/models/eva.py

Lines changed: 19 additions & 19 deletions
@@ -708,7 +708,7 @@ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
             rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
         else:
             pos_embed = self.pos_embed
-            rot_pos_embed = self.rope.get_embed(shape=self.patch_embed.grid_size) if self.rope is not None else None
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
 
         to_cat = []
         if self.cls_token is not None:
@@ -1392,19 +1392,19 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
     ),
-    'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg(
+    'vit_small_patch16_mrope_224.naver_in1k': _cfg(
         hf_hub_id='naver-ai/rope_mixed_deit_small_patch16_LS',
         hf_hub_filename='pytorch_model.bin',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
     ),
-    'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg(
+    'vit_base_patch16_mrope_224.naver_in1k': _cfg(
         hf_hub_id='naver-ai/rope_mixed_deit_base_patch16_LS',
         hf_hub_filename='pytorch_model.bin',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
     ),
-    'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg(
+    'vit_large_patch16_mrope_224.naver_in1k': _cfg(
         hf_hub_id='naver-ai/rope_mixed_deit_large_patch16_LS',
         hf_hub_filename='pytorch_model.bin',
         mean=IMAGENET_DEFAULT_MEAN,
@@ -1428,19 +1428,19 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
     ),
-    'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
+    'vit_small_patch16_mrope_ape_224.naver_in1k': _cfg(
         hf_hub_id='naver-ai/rope_mixed_ape_deit_small_patch16_LS',
         hf_hub_filename='pytorch_model.bin',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
     ),
-    'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
+    'vit_base_patch16_mrope_ape_224.naver_in1k': _cfg(
         hf_hub_id='naver-ai/rope_mixed_ape_deit_base_patch16_LS',
         hf_hub_filename='pytorch_model.bin',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
     ),
-    'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
+    'vit_large_patch16_mrope_ape_224.naver_in1k': _cfg(
         hf_hub_id='naver-ai/rope_mixed_ape_deit_large_patch16_LS',
         hf_hub_filename='pytorch_model.bin',
         mean=IMAGENET_DEFAULT_MEAN,
@@ -2023,7 +2023,7 @@ def vit_large_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
 
 
 @register_model
-def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_small_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
     """RoPE-Mixed ViT-S/16 from https://github.com/naver-ai/rope-vit"""
     model_args = dict(
         patch_size=16,
@@ -2042,12 +2042,12 @@ def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
         rope_temperature=10.0,
         rope_mixed_mode=True,
     )
-    model = _create_eva('vit_small_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_small_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_base_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
     """RoPE-Mixed ViT-B/16 from https://github.com/naver-ai/rope-vit"""
     model_args = dict(
         patch_size=16,
@@ -2066,12 +2066,12 @@ def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
         rope_temperature=10.0,
         rope_mixed_mode=True,
     )
-    model = _create_eva('vit_base_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_base_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_large_patch16_mrope_224(pretrained: bool = False, **kwargs) -> Eva:
     """RoPE-Mixed ViT-L/16 from https://github.com/naver-ai/rope-vit"""
     model_args = dict(
         patch_size=16,
@@ -2090,7 +2090,7 @@ def vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
         rope_temperature=10.0,
         rope_mixed_mode=True,
     )
-    model = _create_eva('vit_large_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_large_patch16_mrope_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2170,7 +2170,7 @@ def vit_large_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
 
 
 @register_model
-def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_small_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
     """RoPE-Mixed + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
     model_args = dict(
         patch_size=16,
@@ -2191,12 +2191,12 @@ def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
         rope_mixed_mode=True,
     )
 
-    model = _create_eva('vit_small_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_small_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_base_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
     """RoPE-Mixed + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
     model_args = dict(
         patch_size=16,
@@ -2216,12 +2216,12 @@ def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
         rope_temperature=10.0,
         rope_mixed_mode=True,
     )
-    model = _create_eva('vit_base_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_base_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
+def vit_large_patch16_mrope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
     """RoPE-Mixed + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
     model_args = dict(
         patch_size=16,
@@ -2241,6 +2241,6 @@ def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
         rope_temperature=10.0,
         rope_mixed_mode=True,
    )
-    model = _create_eva('vit_large_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_large_patch16_mrope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
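Note: the old *_rope_mixed_* entrypoints are renamed, not duplicated, so models are created under the mrope names. Assuming a timm build that includes this commit, usage looks like:

import timm
import torch

# new entrypoint name after this commit (was vit_small_patch16_rope_mixed_224)
model = timm.create_model('vit_small_patch16_mrope_224', pretrained=False)
model.eval()

# at the native 224x224 resolution, Eva._pos_embed calls rope.get_embed() with
# no shape argument, so the cached t_x/t_y grid is used instead of a rebuild
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])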
