@@ -462,55 +462,54 @@ def _interp2d(size):
     def _apply_learned_naflex_pos_embed_grid_sample(
             self,
             x: torch.Tensor,
-            naflex_grid_sizes: List[Tuple[int, int]],
+            patch_coord: torch.Tensor,
+            patch_valid: Optional[torch.Tensor] = None,
     ):
         """ NaFlex 2D position embedding interpolation using F.grid_sample.
 
         Based on proposal by https://github.com/stas-sl
         """
         device = x.device
         B, N, C = x.shape
+        shapes = patch_coord.max(dim=1).values + 1  # (B, 2) containing [h_i, w_i]
 
-        def _make_coords(h, w):
-            _y, _x = torch.meshgrid(
-                torch.arange(h, device=device),
-                torch.arange(w, device=device),
-                indexing='ij',
-            )
-            coord = torch.stack([_y.flatten(), _x.flatten()], dim=1)
-            return coord
-
-        coords = torch.zeros(B, N, 2, dtype=torch.long, device=device)
-        for i, (h, w) in enumerate(naflex_grid_sizes):
-            coords_i = _make_coords(h, w)  # (h*w, 2)
-            coords[i, :coords_i.shape[0]] = coords_i  # pad with zeros past h*w
-            # FIXME should we be masking?
-
-        shapes = coords.amax(1) + 1
-        theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
         if self.pos_embed_ar_preserving:
-            L = shapes.amax(1)
-            grid_max = L.amax()
-            grid_size = (grid_max, grid_max)
-            theta[:, 0, 0] = grid_size[1] / L  # scale x
-            theta[:, 1, 1] = grid_size[0] / L  # scale y
+            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
+            L_global = L_i.amax()
+            grid_size = (L_global, L_global)
+            s_x = s_y = L_global / L_i  # uniform zoom (B,)
         else:
-            grid_size = shapes.amax(0)
-            theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
-            theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
+            grid_size = shapes.amax(dim=0)
+            s_x = grid_size[1] / shapes[:, 1]  # horizontal zoom (B,)
+            s_y = grid_size[0] / shapes[:, 0]  # vertical zoom (B,)
+
+        theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
+        theta[:, 0, 0] = s_x  # scale x
+        theta[:, 1, 1] = s_y  # scale y
         theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
         theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y
+
         grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False)
         pos_embed = F.grid_sample(
             self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
             grid,
             mode=self.pos_embed_interp_mode,
             align_corners=False,
             padding_mode='border',
-        ).to(dtype=x.dtype)
+        ).to(dtype=x.dtype)  # (B, C, H_out, W_out)
+
+        # NOTE if we bring in patch_valid, can explicitly mask padding tokens
+        # more experimentation at train time needed
+        # lin_idx = patch_coord[..., 0] * grid_size[1] + patch_coord[..., 1]  # (B, N)
+        # pos_flat = pos_embed.flatten(2).transpose(1, 2)
+        # pos_flat = pos_flat.gather(1, lin_idx.unsqueeze(2).expand(-1, -1, C))  # (B, N, C)
+        # if patch_valid is not None:
+        #     pos_flat.mul_(patch_valid.unsqueeze(2))
+        # idx_vec = torch.arange(N, device=device)  # (N,)
+        # x.index_add_(1, idx_vec, pos_flat)
+
         bi = torch.arange(B, device=device).unsqueeze(1)
-        # NOTE leave as '+=', do not change to .add_(...)
-        x += pos_embed[bi, :, coords[..., 0], coords[..., 1]]
+        x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]  # NOTE leave as '+='
 
     def _apply_learned_pos_embed(
             self,
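A quick standalone check of the affine math above (my own illustration, not part of this commit): theta scales and shifts the sampling grid so that, for sample i, the top-left h_i x w_i block of the shared (H_max, W_max) output canvas covers the whole pos_embed table, while positions past that block are clamped by padding_mode='border' and never gathered. The sizes below are made up and bilinear mode is assumed in place of self.pos_embed_interp_mode; for a single sample, the non-AR-preserving branch then reduces to a plain F.interpolate resize.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, H, W = 8, 16, 16                      # learned pos_embed table, (1, C, H, W)
pos_embed = torch.randn(1, C, H, W)
h_i, w_i = 6, 10                         # this sample's patch grid (hypothetical)
H_max, W_max = 12, 14                    # batch-wide max grid, the shared output canvas

s_y, s_x = H_max / h_i, W_max / w_i      # per-sample zoom, as in the non-AR branch
theta = torch.tensor([[[s_x, 0.0, s_x - 1.0],    # scale x, translate x
                       [0.0, s_y, s_y - 1.0]]])  # scale y, translate y

grid = F.affine_grid(theta, (1, C, H_max, W_max), align_corners=False)
out = F.grid_sample(pos_embed, grid, mode='bilinear',
                    align_corners=False, padding_mode='border')

# The first h_i x w_i block equals a direct bilinear resize of the table;
# rows/columns beyond it only repeat the border and are ignored downstream.
ref = F.interpolate(pos_embed, size=(h_i, w_i), mode='bilinear', align_corners=False)
print(torch.allclose(out[:, :, :h_i, :w_i], ref, atol=1e-5))  # True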
@@ -605,6 +604,7 @@ def forward(
             self,
             x: torch.Tensor,
             patch_coord: Optional[torch.Tensor] = None,
+            patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for patch embedding with position encoding.
 
@@ -676,7 +676,11 @@ def forward(
         if self.pos_embed_type == 'learned':
             if naflex_grid_sizes is not None:
                 if self.pos_embed_use_grid_sample:
-                    self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes)
+                    self._apply_learned_naflex_pos_embed_grid_sample(
+                        x,
+                        patch_coord=patch_coord,
+                        patch_valid=patch_valid,
+                    )
                 else:
                     self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
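For context on the new arguments (a hedged sketch, not taken from the repo): patch_coord is assumed to carry (y, x) grid indices per token and patch_valid to flag real versus padding tokens, mirroring what the removed _make_coords helper produced per sample, so that patch_coord.max(dim=1).values + 1 recovers each sample's (h_i, w_i) as used above. The helper name make_naflex_coords and the sizes are illustrative only.

import torch

def make_naflex_coords(grid_sizes, seq_len, device='cpu'):
    # Hypothetical helper: padded (y, x) coords plus a validity mask per sample.
    B = len(grid_sizes)
    patch_coord = torch.zeros(B, seq_len, 2, dtype=torch.long, device=device)
    patch_valid = torch.zeros(B, seq_len, dtype=torch.bool, device=device)
    for i, (h, w) in enumerate(grid_sizes):
        yy, xx = torch.meshgrid(
            torch.arange(h, device=device),
            torch.arange(w, device=device),
            indexing='ij',
        )
        patch_coord[i, :h * w] = torch.stack([yy.flatten(), xx.flatten()], dim=1)
        patch_valid[i, :h * w] = True  # everything past h*w stays zero/False (padding)
    return patch_coord, patch_valid

patch_coord, patch_valid = make_naflex_coords([(6, 10), (8, 8)], seq_len=64)
print(patch_coord.max(dim=1).values + 1)  # per-sample [h_i, w_i]: [[6, 10], [8, 8]]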
@@ -1146,7 +1150,7 @@ def forward_intermediates(
             mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype)
 
         # Forward pass through embedding
-        x = self.embeds(patches, patch_coord=patch_coord)
+        x = self.embeds(patches, patch_coord=patch_coord, patch_valid=patch_valid)
         x = self.norm_pre(x)
 
         # Forward pass through blocks
@@ -1219,7 +1223,7 @@ def forward_features(
             )
 
         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord)
+        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
         x = self.norm_pre(x)
         # Apply transformer blocks with masked attention if mask provided
         if attn_mask is not None:
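A closing note on the commented-out gather in the first hunk (again my own illustration with made-up shapes): the flatten-and-gather path is assumed to select the same embeddings as the advanced indexing the code actually uses, with the extra option of zeroing the position embedding at padding tokens via patch_valid.

import torch

torch.manual_seed(0)
B, N, C, H_out, W_out = 2, 5, 4, 3, 3                  # hypothetical shapes
pos_embed = torch.randn(B, C, H_out, W_out)            # stand-in for the grid_sample output
patch_coord = torch.randint(0, 3, (B, N, 2))           # (y, x) per token
patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 0]], dtype=torch.bool)

# Path the diff uses: advanced indexing; padding tokens still pick up some embedding.
bi = torch.arange(B).unsqueeze(1)
via_index = pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]   # (B, N, C)

# Commented-out alternative: flatten + gather, then zero the padded tokens.
lin_idx = patch_coord[..., 0] * W_out + patch_coord[..., 1]              # (B, N)
via_gather = pos_embed.flatten(2).transpose(1, 2)                        # (B, H*W, C)
via_gather = via_gather.gather(1, lin_idx.unsqueeze(2).expand(-1, -1, C))
assert torch.equal(via_index, via_gather)
via_gather = via_gather * patch_valid.unsqueeze(2)                       # masked variant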