From 60059a07b885c3f6de44096a7adcd44c5e8e949f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 9 Jul 2025 09:37:43 -0700 Subject: [PATCH 1/2] Fix H, W ordering for xy indexing in ROPE, impacts models w/ xy indexing and non-square images --- timm/layers/pos_embed_sincos.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index d0978c5610..478310078e 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -78,6 +78,12 @@ def build_sincos2d_pos_embed( return pos_emb.to(dtype=dtype) +def swap_shape_xy(seq: List[int]) -> List[int]: + if len(seq) < 2: + return seq + return [seq[1], seq[0]] + seq[2:] + + def build_fourier_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, @@ -134,6 +140,11 @@ def build_fourier_pos_embed( if dtype is None: dtype = bands.dtype + if grid_indexing == 'xy': + feat_shape = swap_shape_xy(feat_shape) + if ref_feat_shape is not None: + ref_feat_shape = swap_shape_xy(ref_feat_shape) + if in_pixels: t = [ torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) @@ -516,15 +527,16 @@ def init_random_2d_freqs( @torch.fx.wrap @register_notrace_function def get_mixed_grid( - height: int, - width: int, + shape: List[int], grid_indexing: str = 'ij', device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ) -> Tuple[torch.Tensor, torch.Tensor]: + if grid_indexing == 'xy': + shape = swap_shape_xy(shape) x_pos, y_pos = torch.meshgrid( - torch.arange(height, dtype=dtype, device=device), - torch.arange(width, dtype=dtype, device=device), + torch.arange(shape[0], dtype=dtype, device=device), + torch.arange(shape[1], dtype=dtype, device=device), indexing=grid_indexing, ) t_x = x_pos.flatten() @@ -599,8 +611,7 @@ def __init__( if feat_shape is not None: # cache pre-computed grid t_x, t_y = get_mixed_grid( - feat_shape[0], - feat_shape[1], + feat_shape, grid_indexing=grid_indexing, device=self.freqs.device ) @@ -620,8 +631,7 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor: """ if shape is not None: t_x, t_y = get_mixed_grid( - shape[0], - shape[1], + shape, grid_indexing=self.grid_indexing, device=self.freqs.device ) From 893aa5d147bb690a0c371fb7b9aff7e30fdbc8c8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 9 Jul 2025 11:19:06 -0700 Subject: [PATCH 2/2] Fix scripting? --- timm/layers/pos_embed_sincos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 478310078e..9d91e8c1f4 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -81,7 +81,7 @@ def build_sincos2d_pos_embed( def swap_shape_xy(seq: List[int]) -> List[int]: if len(seq) < 2: return seq - return [seq[1], seq[0]] + seq[2:] + return [seq[1], seq[0]] + list(seq[2:]) def build_fourier_pos_embed(