Skip to content

Commit 007cd95

Browse files
committed
Update dependencyvit.py
1 parent 9ccf009 commit 007cd95

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/models/dependencyvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838

3939
def forward(self, x: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, N]
4040
_, N, C = x.shape
41-
topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False) # [B, N']
41+
topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N']
4242
topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C) # [B, N', C]
4343
return x.gather(1, topk_indices)
4444

0 commit comments

Comments
 (0)