@@ -36,11 +36,12 @@ def __init__(
         super().__init__()
         self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
 
-    def forward(self, x: torch.Tensor, scores: torch.Tensor):  # [B, N, C], [B, N]
-        _, N, C = x.shape
+    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor):  # [B, N, C], [B, 1, 1, N], [B, N]
+        B, N, C = x.shape
         topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1]  # [B, N']
-        topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C)  # [B, N', C]
-        return x.gather(1, topk_indices)
+        x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))  # [B, N', C]
+        m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1))  # [B, 1, 1, N']
+        return (x, m)
 
 
 class ReversedAttention(nn.Module):
@@ -188,7 +189,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, m = in_tuple
         x_new, m, prune_mask = self.attn((self.norm1(x), m))
         x = x + self.drop_path1(self.ls1(x_new))
-        x = self.token_pruner(x, prune_mask) if self.token_pruner else x
+        x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m)
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)
 
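For reference, a minimal standalone sketch of the gather-based pruning step introduced above, assuming scores holds per-token relevance of shape [B, N] and m is an attention mask of shape [B, 1, 1, N]; the helper name prune_tokens and the concrete shapes below are placeholders for illustration, not part of the actual module:

import math
import torch

def prune_tokens(x, m, scores, pct_kept_tokens):
    # x: [B, N, C] token features, m: [B, 1, 1, N] attention mask, scores: [B, N]
    B, N, C = x.shape
    # Indices of the highest-scoring tokens; ordering does not matter, so sorted=False.
    topk_indices = scores.topk(math.floor(pct_kept_tokens * N), sorted=False)[1]  # [B, N']
    # Keep the same token subset in both the features and the attention mask.
    x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))  # [B, N', C]
    m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1))        # [B, 1, 1, N']
    return x, m

x = torch.randn(2, 197, 768)       # [B, N, C]
m = torch.zeros(2, 1, 1, 197)      # [B, 1, 1, N]
scores = torch.rand(2, 197)        # [B, N]
x_kept, m_kept = prune_tokens(x, m, scores, pct_kept_tokens=0.7)
print(x_kept.shape, m_kept.shape)  # torch.Size([2, 137, 768]) torch.Size([2, 1, 1, 137])

Gathering the mask along its last dimension with the same top-k indices keeps the mask aligned with the surviving tokens, which is what lets the block pass (x, m) through unchanged downstream.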