
Commit 25a501d

Update dependencyvit.py
1 parent 5f3b70b commit 25a501d

File tree: 1 file changed (+6, -5)

timm/models/dependencyvit.py

Lines changed: 6 additions & 5 deletions
@@ -36,11 +36,12 @@ def __init__(
         super().__init__()
         self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
 
-    def forward(self, x: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, N]
-        _, N, C = x.shape
+    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, 1, 1, N], [B, N]
+        B, N, C = x.shape
         topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N']
-        topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C) # [B, N', C]
-        return x.gather(1, topk_indices)
+        x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C)) # [B, N', C]
+        m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1)) # [B, 1, 1, N']
+        return (x, m)
 
 
 class ReversedAttention(nn.Module):
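
A minimal standalone sketch of the pruning step above, with made-up sizes (B=2, N=8, C=16) and a hypothetical keep ratio of 0.75; it shows how the same topk_indices are reshaped to gather the token tensor x along dim 1 and the broadcastable mask m along dim 3:

import math
import torch

B, N, C = 2, 8, 16           # hypothetical batch, token, and channel sizes
pct_kept_tokens = 0.75       # hypothetical keep ratio

x = torch.randn(B, N, C)     # tokens               [B, N, C]
m = torch.rand(B, 1, 1, N)   # broadcastable mask   [B, 1, 1, N]
scores = torch.rand(B, N)    # per-token scores     [B, N]

# Indices of the highest-scoring tokens per batch element.
topk_indices = scores.topk(math.floor(pct_kept_tokens * N), sorted=False)[1]  # [B, N']

# Prune tokens: broadcast the indices over the channel dim, gather along dim 1.
x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))                 # [B, N', C]

# Prune the mask with the same indices, reshaped to [B, 1, 1, N'], along dim 3.
m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1))                       # [B, 1, 1, N']

print(x.shape, m.shape)      # torch.Size([2, 6, 16]) torch.Size([2, 1, 1, 6])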
@@ -188,7 +189,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, m = in_tuple
         x_new, m, prune_mask = self.attn((self.norm1(x), m))
         x = x + self.drop_path1(self.ls1(x_new))
-        x = self.token_pruner(x, prune_mask) if self.token_pruner else x
+        x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m)
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)

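
For context, a rough sketch of the conditional tuple-threading the block-level change relies on; DummyPruner and the fixed shapes are illustrative stand-ins, not the model's actual modules. The point is that x and m are pruned together, so the mask seen by later blocks always matches the reduced token count:

import torch
import torch.nn as nn

class DummyPruner(nn.Module):
    # Illustrative stand-in: keep the top half of the tokens and the matching mask entries.
    def forward(self, x, m, scores):
        keep = x.shape[1] // 2
        idx = scores.topk(keep, sorted=False)[1]                          # [B, N']
        x = x.gather(1, idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]))    # [B, N', C]
        m = m.gather(3, idx.unsqueeze(1).unsqueeze(1))                    # [B, 1, 1, N']
        return (x, m)

token_pruner = DummyPruner()   # would be None in blocks that do not prune
x = torch.randn(2, 8, 16)
m = torch.rand(2, 1, 1, 8)
prune_mask = torch.rand(2, 8)

# Same conditional as in the block: prune both x and m, or pass both through unchanged.
x, m = token_pruner(x, m, prune_mask) if token_pruner else (x, m)
print(x.shape, m.shape)        # torch.Size([2, 4, 16]) torch.Size([2, 1, 1, 4])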