
Commit ab0c06c

Fix up grid_sample, did not make sense to rebuild patch coords, duh
1 parent 4e3cba8 commit ab0c06c


timm/models/naflexvit.py

Lines changed: 36 additions & 32 deletions
@@ -462,55 +462,54 @@ def _interp2d(size):
     def _apply_learned_naflex_pos_embed_grid_sample(
             self,
             x: torch.Tensor,
-            naflex_grid_sizes: List[Tuple[int, int]],
+            patch_coord: torch.Tensor,
+            patch_valid: Optional[torch.Tensor] = None,
     ):
         """ NaFlex 2D position embedding interpolation using F.grid_sample.

         Based on proposal by https://github.com/stas-sl
         """
         device = x.device
         B, N, C = x.shape
+        shapes = patch_coord.max(dim=1).values + 1  # (B, 2) containing [h_i, w_i]

-        def _make_coords(h, w):
-            _y, _x = torch.meshgrid(
-                torch.arange(h, device=device),
-                torch.arange(w, device=device),
-                indexing='ij',
-            )
-            coord = torch.stack([_y.flatten(), _x.flatten()], dim=1)
-            return coord
-
-        coords = torch.zeros(B, N, 2, dtype=torch.long, device=device)
-        for i, (h, w) in enumerate(naflex_grid_sizes):
-            coords_i = _make_coords(h, w)  # (h*w, 2)
-            coords[i, :coords_i.shape[0]] = coords_i  # pad with zeros past h*w
-            # FIXME should we be masking?
-
-        shapes = coords.amax(1) + 1
-        theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
         if self.pos_embed_ar_preserving:
-            L = shapes.amax(1)
-            grid_max = L.amax()
-            grid_size = (grid_max, grid_max)
-            theta[:, 0, 0] = grid_size[1] / L  # scale x
-            theta[:, 1, 1] = grid_size[0] / L  # scale y
+            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
+            L_global = L_i.amax()
+            grid_size = (L_global, L_global)
+            s_x = s_y = L_global / L_i  # uniform zoom (B,)
         else:
-            grid_size = shapes.amax(0)
-            theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
-            theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
+            grid_size = shapes.amax(dim=0)
+            s_x = grid_size[1] / shapes[:, 1]  # horizontal zoom (B,)
+            s_y = grid_size[0] / shapes[:, 0]  # vertical zoom (B,)
+
+        theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
+        theta[:, 0, 0] = s_x  # scale x
+        theta[:, 1, 1] = s_y  # scale y
         theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
         theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y
+
         grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False)
         pos_embed = F.grid_sample(
             self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
             grid,
             mode=self.pos_embed_interp_mode,
             align_corners=False,
             padding_mode='border',
-        ).to(dtype=x.dtype)
+        ).to(dtype=x.dtype)  # (B, C, H_out, W_out)
+
+        # NOTE if we bring in patch_valid, can explicitly mask padding tokens
+        # more experimentation at train time needed
+        # lin_idx = patch_coord[..., 0] * grid_size[1] + patch_coord[..., 1]  # (B, N)
+        # pos_flat = pos_embed.flatten(2).transpose(1, 2)
+        # pos_flat = pos_flat.gather(1, lin_idx.unsqueeze(2).expand(-1, -1, C))  # (B, N, C)
+        # if patch_valid is not None:
+        #     pos_flat.mul_(patch_valid.unsqueeze(2))
+        # idx_vec = torch.arange(N, device=device)  # (N,)
+        # x.index_add_(1, idx_vec, pos_flat)
+
         bi = torch.arange(B, device=device).unsqueeze(1)
-        # NOTE leave as '+=', do not change to .add_(...)
-        x += pos_embed[bi, :, coords[..., 0], coords[..., 1]]
+        x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]  # NOTE leave as '+='

     def _apply_learned_pos_embed(
             self,
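To make the new scale and translate terms concrete, here is a toy worked example for the non-aspect-preserving branch (grid sizes assumed for illustration, not taken from this commit):

import torch

# Sample 0 has a 3x4 patch grid, sample 1 a 2x2 grid (assumed toy values).
shapes = torch.tensor([[3, 4], [2, 2]])   # what patch_coord.max(dim=1).values + 1 yields

grid_size = shapes.amax(dim=0)            # tensor([3, 4]): shared output canvas
s_x = grid_size[1] / shapes[:, 1]         # tensor([1.0, 2.0]): horizontal zoom per sample
s_y = grid_size[0] / shapes[:, 0]         # tensor([1.0, 1.5]): vertical zoom per sample
print(grid_size.tolist(), s_x.tolist(), s_y.tolist())

The translate terms s_x - 1 and s_y - 1 shift each zoomed sampling window so that the interpolated embedding for sample i ends up top-left aligned on the shared canvas, which is what lets the final line index the result directly with patch_coord.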
@@ -605,6 +604,7 @@ def forward(
             self,
             x: torch.Tensor,
             patch_coord: Optional[torch.Tensor] = None,
+            patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for patch embedding with position encoding.

@@ -676,7 +676,11 @@ def forward(
         if self.pos_embed_type == 'learned':
            if naflex_grid_sizes is not None:
                if self.pos_embed_use_grid_sample:
-                    self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes)
+                    self._apply_learned_naflex_pos_embed_grid_sample(
+                        x,
+                        patch_coord=patch_coord,
+                        patch_valid=patch_valid,
+                    )
                 else:
                     self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
@@ -1146,7 +1150,7 @@ def forward_intermediates(
             mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype)

         # Forward pass through embedding
-        x = self.embeds(patches, patch_coord=patch_coord)
+        x = self.embeds(patches, patch_coord=patch_coord, patch_valid=patch_valid)
         x = self.norm_pre(x)

         # Forward pass through blocks
@@ -1219,7 +1223,7 @@ def forward_features(
             )

         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord)
+        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
         x = self.norm_pre(x)
         # Apply transformer blocks with masked attention if mask provided
         if attn_mask is not None:
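For readers who want to run the resampling path end to end, the following is a minimal standalone sketch of the same idea (toy shapes, random tensors, bilinear mode, and only the non-aspect-preserving branch; it mirrors the diff but is not the timm implementation):

import torch
import torch.nn.functional as F

B, N, C = 2, 12, 16                        # batch size, max patches per sample, embed dim
H, W = 7, 7                                # stored learned pos-embed grid (assumed)
pos_embed = torch.randn(1, H, W, C)        # learned table, NHWC layout as in the diff
x = torch.zeros(B, N, C)                   # patch tokens (zeros just for illustration)

def grid_coords(h, w, n):
    """(y, x) coordinates of an h x w patch grid, zero-padded to n rows."""
    yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    c = torch.stack([yy.flatten(), xx.flatten()], dim=1)
    return torch.cat([c, torch.zeros(n - h * w, 2, dtype=torch.long)], dim=0)

patch_coord = torch.stack([grid_coords(3, 4, N), grid_coords(2, 3, N)])   # (B, N, 2)
patch_valid = torch.stack([torch.arange(N) < 12, torch.arange(N) < 6])    # (B, N) padding mask

# Per-sample grid sizes come straight from the coords (the point of this commit).
shapes = patch_coord.max(dim=1).values + 1         # (B, 2) = [h_i, w_i]
grid_size = shapes.amax(dim=0)                     # shared canvas (H_out, W_out)
s_x = grid_size[1] / shapes[:, 1]                  # per-sample horizontal zoom
s_y = grid_size[0] / shapes[:, 0]                  # per-sample vertical zoom

theta = torch.zeros(B, 2, 3)
theta[:, 0, 0] = s_x
theta[:, 1, 1] = s_y
theta[:, 0, 2] = s_x - 1                           # keep each zoomed grid top-left aligned
theta[:, 1, 2] = s_y - 1

H_out, W_out = int(grid_size[0]), int(grid_size[1])
grid = F.affine_grid(theta, (B, C, H_out, W_out), align_corners=False)
resampled = F.grid_sample(
    pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1),   # (B, C, H, W)
    grid, mode='bilinear', align_corners=False, padding_mode='border',
)                                                  # (B, C, H_out, W_out)

# Per-patch lookup by (y, x); padded patches read cell (0, 0) and are zeroed below,
# which is the explicit masking the commented-out block in the diff hints at.
bi = torch.arange(B).unsqueeze(1)
x = x + resampled[bi, :, patch_coord[..., 0], patch_coord[..., 1]]   # (B, N, C)
x = x * patch_valid.unsqueeze(-1)
print(x.shape)                                     # torch.Size([2, 12, 16])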
