
Commit 3b1604f

Update dependencyvit.py
1 parent 1107c69 commit 3b1604f

File tree

1 file changed: +15 −22 lines changed

timm/models/dependencyvit.py

Lines changed: 15 additions & 22 deletions
@@ -1,4 +1,4 @@
-""" DependencyViT (FIXME WIP)
+""" DependencyViT
 
 From-scratch implementation of DependencyViT in PyTorch
 
@@ -106,19 +106,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N, C)
 
-        '''
-        # FIXME messy way to handle
-        if self.track_dependency_mask or self.token_pruner:
-            dependency_mask = attn.detach().sum(1)  # [B, N, N]
-            self.dependency_mask = dependency_mask if self.track_dependency_mask else None
-            # FIXME how to prune
-            x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x  # dependency mask weights (sum)
-            #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x  # dependency mask weights (abs-sum)
-            #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x  # attn weights (abs-sum-abs-sum)
-            #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x  # m
-        '''
+
+        # FIXME absolute value?
         self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None  # [B, N, N]
 
+        # FIXME which pruning mask?
+
+        # [B, N]
         #prune_mask = attn.detach().sum(1).sum(-1)
         #prune_mask = attn.detach().sum(1).abs().sum(-1)
         #prune_mask = attn.detach().abs().sum((1, -1))
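For context on the three commented-out prune_mask candidates: each reduces the per-head attention maps to one score per token, and whenever the attention weights are non-negative (as with a plain softmax) all three produce identical values, so the "# FIXME absolute value?" question only matters if attn can go negative. A minimal sketch, not part of the commit, with shapes assumed from the [B, N, N] comment:

    import torch

    # Assumed shape: attn is [B, num_heads, N, N], so that sum(1) yields the
    # [B, N, N] dependency mask noted above. Random values, illustration only.
    B, H, N = 2, 4, 197
    attn = torch.softmax(torch.randn(B, H, N, N), dim=-1)

    # The three candidate reductions from the commented-out lines above.
    score_a = attn.detach().sum(1).sum(-1)        # sum over heads, then over the last token axis
    score_b = attn.detach().sum(1).abs().sum(-1)  # abs() taken after the head sum
    score_c = attn.detach().abs().sum((1, -1))    # abs() taken before any summation

    # With non-negative attention weights all three coincide exactly.
    print(torch.allclose(score_a, score_b), torch.allclose(score_a, score_c))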
@@ -196,9 +190,9 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)
 
-# FIXME lite model variants
-# FIXME toggle and retrieve dependency masks
+
 # FIXME verify against reference impl
+# FIXME train weights that meet or exceed results from paper
 
 class DependencyViT(VisionTransformer):
     def __init__(
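The forward signatures in the hunks above take and return an (x, m) tuple rather than a bare tensor. A minimal sketch of why, not from the commit: nn.Sequential forwards a single positional argument between blocks, so packing the hidden states x together with the message-controller weights m is what lets m propagate through the block stack. Shapes and the block body here are placeholders:

    from typing import Tuple
    import torch
    import torch.nn as nn

    class TupleBlock(nn.Module):
        """Placeholder block mirroring the (x, m) tuple convention above."""
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.norm = nn.LayerNorm(dim)

        def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            x, m = in_tuple
            x = x + self.norm(x)  # stand-in for the attention/MLP residual branches
            return (x, m)

    blocks = nn.Sequential(TupleBlock(64), TupleBlock(64))
    x = torch.randn(1, 196, 64)   # [B, N, C] patch tokens
    m = torch.ones(1, 1, 196, 1)  # assumed shape for the message-controller weights
    x, m = blocks((x, m))         # the tuple threads m through every block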
@@ -207,24 +201,23 @@ def __init__(
             prune_ratio: Optional[float] = None,
             *args,
             **kwargs
-    ):
+    ) -> None:
         super().__init__(
-                *args,
+            *args,
             **kwargs,
-                block_fn = DependencyViTBlock,
+            block_fn = DependencyViTBlock,
             class_token=False,
-                global_pool='avg',
-                qkv_bias=False,
-                init_values=1e-6,
+            global_pool='avg',
+            qkv_bias=False,
+            init_values=1e-6,
             fc_norm=False,
         )
 
         if prune_layers is not None:
             self.prune_layers = sorted(list(dict.fromkeys(prune_layers)))
             self.prune_ratio = prune_ratio
 
-            # FIXME reword these assertions
-            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices are greater than model depth"
+            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices exceed model depth"
             assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1"
 
             self.prune_layers = [x-1 for x in self.prune_layers]  # convert counting numbers to nn.Sequential indices
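To make the two assertions concrete, a worked example, not from the commit, assuming prune_layers holds 1-based block indices (which the final x-1 conversion implies) and that each pruning layer drops prune_ratio of the original tokens (which the len(prune_layers) * prune_ratio < 1 bound implies):

    depth = 12                 # len(self.blocks) for a 12-block model
    prune_layers = [3, 6, 9]   # 1-based: prune inside blocks 3, 6 and 9
    prune_ratio = 0.1          # fraction of the original tokens dropped per pruning layer

    assert max(prune_layers) <= depth            # every index names a real block
    assert prune_ratio * len(prune_layers) < 1   # 3 * 0.1 = 0.3 < 1: tokens can never all be pruned

    prune_layers = [x - 1 for x in prune_layers]  # [2, 5, 8] as nn.Sequential indices

    # e.g. starting from N = 196 patch tokens, the kept count shrinks as
    # 196 -> 176 -> 157 -> 137, rounding 196 * (1 - k * 0.1) at k = 1, 2, 3.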
