Skip to content

Commit 33791c8

Browse files
committed
Fix H, W ordering for xy indexing in ROPE, impacts models w/ xy indexing and non-square images
1 parent a7c5368 commit 33791c8

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

timm/layers/pos_embed_sincos.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ def build_sincos2d_pos_embed(
7878
return pos_emb.to(dtype=dtype)
7979

8080

81+
def swap_shape_xy(seq: List[int]) -> List[int]:
82+
if len(seq) < 2:
83+
return seq
84+
return [seq[1], seq[0]] + seq[2:]
85+
86+
8187
def build_fourier_pos_embed(
8288
feat_shape: List[int],
8389
bands: Optional[torch.Tensor] = None,
@@ -134,6 +140,11 @@ def build_fourier_pos_embed(
134140
if dtype is None:
135141
dtype = bands.dtype
136142

143+
if grid_indexing == 'xy':
144+
feat_shape = swap_shape_xy(feat_shape)
145+
if ref_feat_shape is not None:
146+
ref_feat_shape = swap_shape_xy(ref_feat_shape)
147+
137148
if in_pixels:
138149
t = [
139150
torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32)
@@ -516,15 +527,16 @@ def init_random_2d_freqs(
516527
@torch.fx.wrap
517528
@register_notrace_function
518529
def get_mixed_grid(
519-
height: int,
520-
width: int,
530+
shape: List[int],
521531
grid_indexing: str = 'ij',
522532
device: Optional[torch.device] = None,
523533
dtype: torch.dtype = torch.float32,
524534
) -> Tuple[torch.Tensor, torch.Tensor]:
535+
if grid_indexing == 'xy':
536+
shape = swap_shape_xy(shape)
525537
x_pos, y_pos = torch.meshgrid(
526-
torch.arange(height, dtype=dtype, device=device),
527-
torch.arange(width, dtype=dtype, device=device),
538+
torch.arange(shape[0], dtype=dtype, device=device),
539+
torch.arange(shape[1], dtype=dtype, device=device),
528540
indexing=grid_indexing,
529541
)
530542
t_x = x_pos.flatten()
@@ -599,8 +611,7 @@ def __init__(
599611
if feat_shape is not None:
600612
# cache pre-computed grid
601613
t_x, t_y = get_mixed_grid(
602-
feat_shape[0],
603-
feat_shape[1],
614+
feat_shape,
604615
grid_indexing=grid_indexing,
605616
device=self.freqs.device
606617
)
@@ -620,8 +631,7 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
620631
"""
621632
if shape is not None:
622633
t_x, t_y = get_mixed_grid(
623-
shape[0],
624-
shape[1],
634+
shape,
625635
grid_indexing=self.grid_indexing,
626636
device=self.freqs.device
627637
)

0 commit comments

Comments
 (0)