Implementation for timm by / Copyright 2023, Fredo Guan
"""
- from typing import Any, Dict, Optional, Tuple
+ import math
+ from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
__all__ = ['DependencyViT']
+ class TokenPruner(nn.Module):
+     def __init__(
+             self,
+             prune_ratio: float,
+             prune_index: int,
+     ):
+         super().__init__()
+         self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
+
+     def forward(self, x: torch.Tensor, scores: torch.Tensor):  # [B, N, C], [B, N]
+         _, N, C = x.shape
+         topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False).indices  # [B, N']
+         topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C)  # [B, N', C]
+         return x.gather(1, topk_indices)
+
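+ # Note (illustrative, not from the reference impl): pct_kept_tokens is the fraction of the
+ # *current* tokens kept at pruning stage prune_index, chosen so that each stage removes roughly
+ # prune_ratio of the *original* token count. E.g. with prune_ratio=0.16, stage 1 keeps
+ # (1 - 0.16) / 1 = 84% of the tokens, stage 2 keeps (1 - 0.32) / (1 - 0.16) ≈ 81% of the
+ # remainder, and after four stages about 1 - 4 * 0.16 = 36% of the original tokens are left.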
- # FIXME there is nearly no difference between this and stock attn, allowing sdpa to be used if a workaround can be found
class ReversedAttention(nn.Module):
dependency_mask: Optional[torch.Tensor]
@@ -48,9 +63,9 @@ def __init__(
self.scale = self.head_dim ** -0.5
self.track_dependency_mask = False
self.dependency_mask = None
- self.head_selector_temperature = 0.1  # appendix D.1, causes nan when 0.1, 0 when 10.0
+ self.head_selector_temperature = 0.1  # appendix D.1

- self.head_selector = nn.Linear(dim, num_heads, bias=False)
+ self.head_selector = nn.Linear(dim, num_heads, bias=False)  # FIXME is there a bias term?

self.message_controller = Mlp(
in_features=dim,
@@ -59,7 +74,9 @@ def __init__(
act_layer=nn.GELU,
bias=False,  # FIXME is there a bias term?
)
-
+
+ self.token_pruner = None  # assigned by DependencyViT when prune_layers is set
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -86,8 +103,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
attn = self.attn_drop(attn).transpose(-2, -1)  # this transpose prevents use of sdpa
attn = attn * p * m  # [B, n_h, N, N]
x = attn @ v
-
- self.dependency_mask = attn.sum(1) if self.track_dependency_mask else None
+
+ # FIXME messy way to handle
+ if self.track_dependency_mask or self.token_pruner is not None:
+ dependency_mask = attn.detach().sum(1)  # [B, N, N]
+ self.dependency_mask = dependency_mask if self.track_dependency_mask else None
+ # FIXME how to prune
+ x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x  # dependency mask weights (sum)
+ # x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x  # dependency mask weights (abs-sum)
+ # x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x  # attn weights (abs-sum-abs-sum)
+ # x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x  # m
+

x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
@@ -161,7 +187,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
# FIXME verify against reference impl

class DependencyViT(VisionTransformer):
- def __init__(self, *args, **kwargs):
+ def __init__(
+ self,
+ prune_layers: Optional[Union[List[int], Tuple[int]]] = None,
+ prune_ratio: Optional[float] = None,
+ *args,
+ **kwargs
+ ):
super().__init__(
*args,
**kwargs,
@@ -172,6 +204,19 @@ def __init__(self, *args, **kwargs):
init_values=1e-6,
fc_norm=False,
)
+
+ if prune_layers is not None:
+ self.prune_layers = sorted(list(dict.fromkeys(prune_layers)))
+ self.prune_ratio = prune_ratio
+
+ assert max(self.prune_layers) <= len(self.blocks), "one or more pruned layer indices exceed the model depth"
+ assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too large, len(prune_layers) * prune_ratio must be less than 1"
+
+ self.prune_layers = [x - 1 for x in self.prune_layers]  # convert 1-based layer numbers to nn.Sequential indices
+ for prune_index, layer in enumerate(self.prune_layers, 1):
+ self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index)
+
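+ # Note (illustrative): with the dependencyvit_lite_tiny config below (prune_layers=[2, 5, 8, 11],
+ # prune_ratio=0.16) and a 224x224 input (196 patch tokens, ignoring any class token), roughly
+ # 0.16 * 196 ≈ 31 tokens are dropped at each of blocks 2, 5, 8 and 11, leaving about 70 tokens.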
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
@@ -191,6 +236,23 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.norm(x)
x = x * m.transpose(1, 3).squeeze(-1)
return x
+
+ def track_dependency_mask(self, track: bool = True):
+ for block in self.blocks:
+ if block.attn.track_dependency_mask is not track:
+ block.attn.dependency_mask = None
+ block.attn.track_dependency_mask = track
+
+ def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None):
+ # returns L' * [B, N, N] (or L' * [B, N', N'] for pruned layers)
+ result = []
+ layers = range(len(self.blocks)) if not layers else layers
+ for layer in layers:
+ result.append(self.blocks[layer].attn.dependency_mask)
+ return result
+
+
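+ # Illustrative usage of the dependency-mask helpers (model name taken from the registrations below):
+ #   model = dependencyvit_tiny_patch16_224()
+ #   model.track_dependency_mask(True)
+ #   _ = model(torch.randn(1, 3, 224, 224))
+ #   masks = model.get_dependency_mask()  # one [B, N, N] mask per block (None before a forward pass)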
def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
@@ -212,6 +274,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
default_cfgs = {
'dependencyvit_tiny_patch16_224.untrained': _cfg(url=''),
+ 'dependencyvit_small_patch16_224.untrained': _cfg(url=''),
+
+ 'dependencyvit_lite_tiny_patch16_224.untrained': _cfg(url=''),
}
@@ -240,4 +305,10 @@ def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> Depend
def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12)
model = _create_dependencyvit('dependencyvit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+ @register_model
+ def dependencyvit_lite_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
+ model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, prune_layers=[2, 5, 8, 11], prune_ratio=0.16)
+ model = _create_dependencyvit('dependencyvit_lite_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model