diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index d0978c5610..9d91e8c1f4 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -78,6 +78,12 @@ def build_sincos2d_pos_embed( return pos_emb.to(dtype=dtype) +def swap_shape_xy(seq: List[int]) -> List[int]: + if len(seq) < 2: + return seq + return [seq[1], seq[0]] + list(seq[2:]) + + def build_fourier_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, @@ -134,6 +140,11 @@ def build_fourier_pos_embed( if dtype is None: dtype = bands.dtype + if grid_indexing == 'xy': + feat_shape = swap_shape_xy(feat_shape) + if ref_feat_shape is not None: + ref_feat_shape = swap_shape_xy(ref_feat_shape) + if in_pixels: t = [ torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) @@ -516,15 +527,16 @@ def init_random_2d_freqs( @torch.fx.wrap @register_notrace_function def get_mixed_grid( - height: int, - width: int, + shape: List[int], grid_indexing: str = 'ij', device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ) -> Tuple[torch.Tensor, torch.Tensor]: + if grid_indexing == 'xy': + shape = swap_shape_xy(shape) x_pos, y_pos = torch.meshgrid( - torch.arange(height, dtype=dtype, device=device), - torch.arange(width, dtype=dtype, device=device), + torch.arange(shape[0], dtype=dtype, device=device), + torch.arange(shape[1], dtype=dtype, device=device), indexing=grid_indexing, ) t_x = x_pos.flatten() @@ -599,8 +611,7 @@ def __init__( if feat_shape is not None: # cache pre-computed grid t_x, t_y = get_mixed_grid( - feat_shape[0], - feat_shape[1], + feat_shape, grid_indexing=grid_indexing, device=self.freqs.device ) @@ -620,8 +631,7 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor: """ if shape is not None: t_x, t_y = get_mixed_grid( - shape[0], - shape[1], + shape, grid_indexing=self.grid_indexing, device=self.freqs.device )