- """ DependencyViT (FIXME WIP)
+ """ DependencyViT

From-scratch implementation of DependencyViT in PyTorch
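A minimal sketch of the reversed-attention step that the forward() hunks below modify, assuming the formulation from the DependencyViT paper (the softmax attention matrix is transposed so messages flow child-to-parent, gated per sender token by a message controller m); the function name and shapes are illustrative assumptions, not this repo's exact code:

    import torch
    import torch.nn.functional as F

    def reversed_attention(q, k, v, m):
        # q, k, v: [B, H, N, D]; m: [B, 1, 1, N], per-token message controller in (0, 1)
        attn = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5  # [B, H, N, N]
        attn = F.softmax(attn, dim=-1).transpose(-2, -1)        # reverse the edge direction
        attn = attn * m                                         # gate each sender (child) token
        return attn @ v, attn                                   # attn is what forward() reduces over heads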
@@ -106,19 +106,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)

-        '''
-        # FIXME messy way to handle
-        if self.track_dependency_mask or self.token_pruner:
-            dependency_mask = attn.detach().sum(1) # [B, N, N]
-            self.dependency_mask = dependency_mask if self.track_dependency_mask else None
-            #FIXME how to prune
-            x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x # dependency mask weights(sum)
-            #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x # dependency mask weights(abs-sum)
-            #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x # attn weights(abs-sum-abs-sum)
-            #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x # m
-        '''
+
+        #FIXME absolute value?
        self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None # [B, N, N]

+        #FIXME which pruning mask?
+
+        # [B, N]
        #prune_mask = attn.detach().sum(1).sum(-1)
        #prune_mask = attn.detach().sum(1).abs().sum(-1)
        #prune_mask = attn.detach().abs().sum((1, -1))
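The commented prune_mask candidates above are alternative per-token scores of shape [B, N] derived from the reversed attention. A minimal sketch of a pruner that could consume such a score; the TokenPruner name and the hard top-k gather are assumptions for illustration (the paper's lite variant describes soft pruning), not this repo's implementation:

    import torch
    import torch.nn as nn

    class TokenPruner(nn.Module):
        """Hypothetical pruner: keep the k highest-scoring tokens per image."""

        def __init__(self, keep: int):
            super().__init__()
            self.keep = keep

        def forward(self, x: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
            # x: [B, N, C] tokens; scores: [B, N], e.g. attn.detach().sum(1).sum(-1)
            idx = scores.topk(self.keep, dim=1).indices          # [B, keep]
            idx = idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])  # [B, keep, C]
            return torch.gather(x, 1, idx)                       # [B, keep, C]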
@@ -196,9 +190,9 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return (x, m)

-# FIXME lite model variants
-# FIXME toggle and retrieve dependency masks
+
# FIXME verify against reference impl
+# FIXME train weights that meet or exceed results from paper

class DependencyViT(VisionTransformer):
    def __init__(
@@ -207,24 +201,23 @@ def __init__(
            prune_ratio: Optional[float] = None,
            *args,
            **kwargs
-    ):
+    ) -> None:
        super().__init__(
-            *args,
+            *args,
            **kwargs,
-            block_fn=DependencyViTBlock,
+            block_fn=DependencyViTBlock,
            class_token=False,
-            global_pool='avg',
-            qkv_bias=False,
-            init_values=1e-6,
+            global_pool='avg',
+            qkv_bias=False,
+            init_values=1e-6,
            fc_norm=False,
        )

        if prune_layers is not None:
            self.prune_layers = sorted(list(dict.fromkeys(prune_layers)))
            self.prune_ratio = prune_ratio

-            # FIXME reword these assertions
-            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices are greater than model depth"
+            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices exceed model depth"
            assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1"

            self.prune_layers = [x - 1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indices
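A usage sketch of the pruning configuration handled above; the argument values are hypothetical: prune_layers is given as 1-based layer positions, deduplicated and sorted, validated against depth and total prune budget, then shifted to 0-based block indices.

    # Hypothetical instantiation; only prune_layers/prune_ratio come from this diff.
    model = DependencyViT(
        depth=12,
        prune_layers=[2, 5, 8, 11],  # 1-based layer positions
        prune_ratio=0.2,             # 4 layers * 0.2 = 0.8 < 1, passes the assertion
    )
    # After __init__: model.prune_layers == [1, 4, 7, 10] (nn.Sequential indices)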
0 commit comments