
Commit 5f3b70b

Update dependencyvit.py
1 parent 007cd95

File tree

1 file changed: +15 −7 lines changed

timm/models/dependencyvit.py

Lines changed: 15 additions & 7 deletions
@@ -75,7 +75,7 @@ def __init__(
             bias = False,  # FIXME is there a bias term?
         )
 
-        self.token_pruner = None
+        #self.token_pruner = None
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -105,7 +105,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N, C)
 
-
+        '''
         # FIXME messy way to handle
         if self.track_dependency_mask or self.token_pruner:
             dependency_mask = attn.detach().sum(1)  # [B, N, N]
@@ -115,12 +115,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x  # dependency mask weights (abs-sum)
         #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x  # attn weights (abs-sum-abs-sum)
         #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x  # m
-
-
+        '''
+        self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None  # [B, N, N]
+
+        prune_mask = attn.detach().sum(1).sum(-1)
+        #prune_mask = attn.detach().sum(1).abs().sum(-1)
+        #prune_mask = attn.detach().abs().sum(1).sum(-1)
+        #prune_mask = m.reshape(B, N)
 
         x = self.proj(x)
         x = self.proj_drop(x)
-        return (x, m)
+        return (x, m, prune_mask)
 
 class LayerScale(nn.Module):
     def __init__(
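
For context on the shapes in this hunk: attn is [B, heads, N, N], so summing over the head dimension yields the [B, N, N] dependency mask, and summing that over its last dimension yields one score per token. A minimal shape sketch (sizes are illustrative, not from the diff):

    import torch

    B, H, N = 2, 12, 197                     # batch, heads, tokens (illustrative)
    attn = torch.softmax(torch.randn(B, H, N, N), dim=-1)

    dependency_mask = attn.detach().sum(1)   # [B, N, N]: head-summed dependencies
    prune_mask = dependency_mask.sum(-1)     # [B, N]: one importance score per token
    assert prune_mask.shape == (B, N)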
@@ -166,6 +171,8 @@ def __init__(
         )
         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.token_pruner = None
 
         self.norm2 = norm_layer(dim)
         self.mlp = mlp_layer(
@@ -179,8 +186,9 @@ def __init__(
 
     def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
         x, m = in_tuple
-        x_new, m = self.attn((self.norm1(x), m))
+        x_new, m, prune_mask = self.attn((self.norm1(x), m))
         x = x + self.drop_path1(self.ls1(x_new))
+        x = self.token_pruner(x, prune_mask) if self.token_pruner else x
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return (x, m)
 
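
The TokenPruner definition is not part of this diff; below is a hypothetical sketch of the interface the new block-level call implies (a simple top-k pruner under assumed prune_ratio semantics, not the repository's actual implementation):

    import torch
    import torch.nn as nn

    class TokenPruner(nn.Module):
        # Hypothetical sketch only: the real class lives elsewhere in the file.
        def __init__(self, prune_ratio: float, prune_index: int):
            super().__init__()
            self.prune_ratio = prune_ratio  # assumed: fraction of tokens dropped per stage
            self.prune_index = prune_index  # 1-based position among the pruning blocks

        def forward(self, x: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
            B, N, C = x.shape
            keep = max(1, int(N * (1. - self.prune_ratio)))
            idx = scores.topk(keep, dim=-1).indices                  # [B, keep]
            return x.gather(1, idx.unsqueeze(-1).expand(-1, -1, C))  # [B, keep, C]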

@@ -217,7 +225,7 @@ def __init__(
 
         self.prune_layers = [x-1 for x in self.prune_layers]  # convert counting numbers to nn.Sequential indices
         for prune_index, layer in enumerate(prune_layers, 1):
-            self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index)
+            self.blocks[layer].token_pruner = TokenPruner(self.prune_ratio, prune_index)
 
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
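
As a quick check of the index conversion and the 1-based prune_index pairing (values below are hypothetical):

    prune_layers = [3, 6, 9]                 # hypothetical 1-based layer numbers
    indices = [x - 1 for x in prune_layers]  # 0-based nn.Sequential indices: [2, 5, 8]
    for prune_index, layer in enumerate(indices, 1):
        print((prune_index, layer))          # (1, 2), (2, 5), (3, 8)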
