@@ -32,11 +32,12 @@ def __init__(
         self,
         prune_ratio: float,
         prune_index: int,
-    ):
+    ) -> None:
         super().__init__()
         self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)

-    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, 1, 1, N], [B, N]
+    # [B, N, C], [B, 1, 1, N], [B, N] -> [B, N', C], [B, 1, 1, N']
+    def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         B, N, C = x.shape
         topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N']
         x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C)) # [B, N', C]
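The keep fraction follows from staged pruning arithmetic: if stage i-1 leaves (1 - (i-1)*r)*N0 tokens, keeping a fraction (1 - i*r)/(1 - (i-1)*r) of them leaves (1 - i*r)*N0 tokens after stage i. Below is a minimal sketch of the gather-based pruning in this hunk (illustrative only, not part of the diff; the final gather of m down to [B, 1, 1, N'] is an assumption, since the hunk ends before it):

    # Illustrative sketch of TokenPruner.forward; the last line is assumed, not shown above.
    import math
    import torch

    B, N, C = 2, 10, 4
    prune_ratio, prune_index = 0.25, 1
    pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)

    x = torch.randn(B, N, C)      # token embeddings       [B, N, C]
    m = torch.ones(B, 1, 1, N)    # cumulative keep mask   [B, 1, 1, N]
    scores = torch.rand(B, N)     # per-token keep scores  [B, N]

    topk_indices = scores.topk(math.floor(pct_kept_tokens * N), sorted=False)[1]  # [B, N']
    x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C))                 # [B, N', C]
    m = m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1))                       # [B, 1, 1, N'] (assumed)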
@@ -86,8 +87,8 @@ def __init__(
         self.proj_drop = nn.Dropout(proj_drop)

     # m is cumulative over all layers (m = m_i * m_i-1 * ... * m_1)
-    def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-        x, m = in_tuple # [B, N, C], [B, 1, 1, N]
+    # [B, N, C], [B, 1, 1, N] -> [B, N, C], [B, 1, 1, N], [B, N]
+    def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
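A small sketch of the cumulative-mask convention noted in the comment above (m = m_i * m_i-1 * ... * m_1); the per-layer masks here are hypothetical placeholders, not the repo's actual mask computation:

    # Illustrative sketch only: each layer contributes a [B, 1, 1, N] keep mask and the
    # running mask is their product, so a token stays active only if every layer kept it.
    import torch

    B, N, num_layers = 2, 8, 3
    m = torch.ones(B, 1, 1, N)                        # all tokens kept initially
    for _ in range(num_layers):
        m_i = (torch.rand(B, 1, 1, N) > 0.2).float()  # hypothetical per-layer keep mask
        m = m * m_i                                   # m = m_i * ... * m_1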
@@ -112,7 +113,6 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te

         #FIXME which pruning mask?

-        # [B, N]
         #prune_mask = attn.detach().sum(1).sum(-1)
         #prune_mask = attn.detach().sum(1).abs().sum(-1)
         #prune_mask = attn.detach().abs().sum((1, -1))
@@ -184,7 +184,7 @@ def __init__(

     def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, m = in_tuple
-        x_new, m, prune_mask = self.attn((self.norm1(x), m))
+        x_new, m, prune_mask = self.attn(self.norm1(x), m)
         x = x + self.drop_path1(self.ls1(x_new))
         x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m)
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
@@ -201,7 +201,7 @@ def __init__(
         prune_ratio: Optional[float] = None,
         *args,
         **kwargs
-    ): -> None:
+    ) -> None:
         super().__init__(
             *args,
             **kwargs,
@@ -244,13 +244,13 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = x * m.transpose(1, 3).squeeze(-1)
         return x

-    def track_dependency_mask(self, track: bool = True):
+    def track_dependency_mask(self, track: bool = True) -> None:
         for block in self.blocks:
             if block.attn.track_dependency_mask is not track:
                 block.attn.dependency_mask = None
                 block.attn.track_dependency_mask = track

-    def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None):
+    def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None) -> List[torch.Tensor]:
         # L' * [B, N, N]
         # L' * [B, N', N']
         result = []
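A hypothetical usage sketch for the dependency-mask helpers above (not part of the diff; the model instance and input size are assumptions):

    # Illustrative sketch only: record dependency masks for a single forward pass.
    # `model` is assumed to be an instance of the pruned ViT defined in this file.
    import torch

    model.track_dependency_mask(True)        # clears stale masks, starts tracking
    _ = model(torch.randn(1, 3, 224, 224))   # forward pass populates block.attn.dependency_mask
    masks = model.get_dependency_mask()      # list of per-layer masks, L' * [B, N', N']
    model.track_dependency_mask(False)       # stops tracking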