@@ -75,7 +75,7 @@ def __init__(
             bias=False,  # FIXME is there a bias term?
         )

-        self.token_pruner = None
+        # self.token_pruner = None

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -105,7 +105,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N, C)

-
+        '''
         # FIXME messy way to handle
         if self.track_dependency_mask or self.token_pruner:
             dependency_mask = attn.detach().sum(1)  # [B, N, N]
@@ -115,12 +115,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
             #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x  # dependency mask weights(abs-sum)
             #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x  # attn weights(abs-sum-abs-sum)
             #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x  # m
-
-
+        '''
+        self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None  # [B, N, N]
+
+        prune_mask = attn.detach().sum(1).sum(-1)
+        #prune_mask = attn.detach().sum(1).abs().sum(-1)
+        #prune_mask = attn.detach().abs().sum(1).sum(-1)
+        #prune_mask = m.reshape(B, N)

         x = self.proj(x)
         x = self.proj_drop(x)
-        return (x, m)
+        return (x, m, prune_mask)

 class LayerScale(nn.Module):
     def __init__(
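
For reference on the new prune_mask: attn in this forward pass has shape [B, num_heads, N, N] (the source comment notes that sum(1) gives [B, N, N]), so the head dimension is collapsed first and the trailing sum(-1) yields one score per token. A minimal shape check of that reduction, using random stand-in values and illustrative names rather than anything from the patch:

import torch

B, num_heads, N = 2, 4, 16
attn = torch.randn(B, num_heads, N, N)      # stand-in attention weights

dependency_mask = attn.detach().sum(1)      # [B, N, N], summed over heads
prune_mask = attn.detach().sum(1).sum(-1)   # [B, N], one score per token

assert dependency_mask.shape == (B, N, N)
assert prune_mask.shape == (B, N)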
@@ -166,6 +171,8 @@ def __init__(
         )
         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.token_pruner = None

         self.norm2 = norm_layer(dim)
         self.mlp = mlp_layer(
@@ -179,8 +186,9 @@ def __init__(

     def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, m = in_tuple
-        x_new, m = self.attn((self.norm1(x), m))
+        x_new, m, prune_mask = self.attn((self.norm1(x), m))
         x = x + self.drop_path1(self.ls1(x_new))
+        x = self.token_pruner(x, prune_mask) if self.token_pruner else x
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)

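
The TokenPruner module itself is not part of this diff; a minimal sketch of a score-based pruner that matches how it is constructed (TokenPruner(prune_ratio, prune_index)) and called above (token_pruner(x, prune_mask)) might look as follows. The top-k keep logic here is an assumption for illustration, not the repository's implementation:

import torch
import torch.nn as nn

class TokenPrunerSketch(nn.Module):
    # Hypothetical stand-in: keep the highest-scoring tokens, drop the rest.
    def __init__(self, prune_ratio: float, prune_index: int):
        super().__init__()
        self.prune_ratio = prune_ratio    # assumed meaning: fraction of tokens dropped at this stage
        self.prune_index = prune_index    # 1-based position in the pruning schedule

    def forward(self, x: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # x: [B, N, C] tokens; scores: [B, N] per-token importance from the attention map
        B, N, C = x.shape
        n_keep = max(1, int(N * (1.0 - self.prune_ratio)))
        keep_idx = scores.topk(n_keep, dim=1).indices            # [B, n_keep]
        keep_idx = keep_idx.unsqueeze(-1).expand(-1, -1, C)      # [B, n_keep, C]
        return x.gather(1, keep_idx)                             # pruned token sequence

Because the pruner now hangs off the Block rather than the Attention module, pruning is applied after the attention residual, so the MLP and all subsequent blocks operate on the reduced token sequence.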
@@ -217,7 +225,7 @@ def __init__(

         self.prune_layers = [x - 1 for x in self.prune_layers]  # convert counting numbers to nn.Sequential indices
         for prune_index, layer in enumerate(prune_layers, 1):
-            self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index)
+            self.blocks[layer].token_pruner = TokenPruner(self.prune_ratio, prune_index)


     def forward_features(self, x: torch.Tensor) -> torch.Tensor: