@@ -78,6 +78,12 @@ def build_sincos2d_pos_embed(
78
78
return pos_emb .to (dtype = dtype )
79
79
80
80
81
+ def swap_shape_xy (seq : List [int ]) -> List [int ]:
82
+ if len (seq ) < 2 :
83
+ return seq
84
+ return [seq [1 ], seq [0 ]] + seq [2 :]
85
+
86
+
81
87
def build_fourier_pos_embed (
82
88
feat_shape : List [int ],
83
89
bands : Optional [torch .Tensor ] = None ,
@@ -134,6 +140,11 @@ def build_fourier_pos_embed(
134
140
if dtype is None :
135
141
dtype = bands .dtype
136
142
143
+ if grid_indexing == 'xy' :
144
+ feat_shape = swap_shape_xy (feat_shape )
145
+ if ref_feat_shape is not None :
146
+ ref_feat_shape = swap_shape_xy (ref_feat_shape )
147
+
137
148
if in_pixels :
138
149
t = [
139
150
torch .linspace (- 1. , 1. , steps = s , device = device , dtype = torch .float32 )
@@ -516,15 +527,16 @@ def init_random_2d_freqs(
516
527
@torch .fx .wrap
517
528
@register_notrace_function
518
529
def get_mixed_grid (
519
- height : int ,
520
- width : int ,
530
+ shape : List [int ],
521
531
grid_indexing : str = 'ij' ,
522
532
device : Optional [torch .device ] = None ,
523
533
dtype : torch .dtype = torch .float32 ,
524
534
) -> Tuple [torch .Tensor , torch .Tensor ]:
535
+ if grid_indexing == 'xy' :
536
+ shape = swap_shape_xy (shape )
525
537
x_pos , y_pos = torch .meshgrid (
526
- torch .arange (height , dtype = dtype , device = device ),
527
- torch .arange (width , dtype = dtype , device = device ),
538
+ torch .arange (shape [ 0 ] , dtype = dtype , device = device ),
539
+ torch .arange (shape [ 1 ] , dtype = dtype , device = device ),
528
540
indexing = grid_indexing ,
529
541
)
530
542
t_x = x_pos .flatten ()
@@ -599,8 +611,7 @@ def __init__(
599
611
if feat_shape is not None :
600
612
# cache pre-computed grid
601
613
t_x , t_y = get_mixed_grid (
602
- feat_shape [0 ],
603
- feat_shape [1 ],
614
+ feat_shape ,
604
615
grid_indexing = grid_indexing ,
605
616
device = self .freqs .device
606
617
)
@@ -620,8 +631,7 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
620
631
"""
621
632
if shape is not None :
622
633
t_x , t_y = get_mixed_grid (
623
- shape [0 ],
624
- shape [1 ],
634
+ shape ,
625
635
grid_indexing = self .grid_indexing ,
626
636
device = self .freqs .device
627
637
)
0 commit comments