
Commit 79bf1b2

Update dependencyvit.py
1 parent 46b204f commit 79bf1b2


timm/models/dependencyvit.py

Lines changed: 79 additions & 8 deletions
@@ -10,7 +10,8 @@
 Implementation for timm by / Copyright 2023, Fredo Guan
 """
 
-from typing import Any, Dict, Optional, Tuple
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -26,8 +27,22 @@
 
 __all__ = ['DependencyViT']
 
+class TokenPruner(nn.Module):
+    def __init__(
+        self,
+        prune_ratio: float,
+        prune_index: int,
+    ):
+        super().__init__()
+        self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
+
+    def forward(self, x: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, N]
+        _, N, C = x.shape
+        topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False).indices # [B, N']
+        topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C) # [B, N', C]
+        return x.gather(1, topk_indices)
+
 
-# FIXME there is nearly no difference between this and stock attn, allowing sdpa to be used if a workaround can be found
 class ReversedAttention(nn.Module):
     dependency_mask: Optional[torch.Tensor]
 
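Note on the pct_kept_tokens formula above (an editor's sketch, not part of the commit): the per-stage keep fractions telescope, so after the i-th pruning stage roughly 1 - i * prune_ratio of the original tokens survive, up to flooring. Assuming 196 patch tokens (a 224x224 input with 16x16 patches and no class token) and the prune_ratio=0.16 / four pruning stages used by the lite variant registered at the bottom of this diff:

import math

prune_ratio = 0.16  # from dependencyvit_lite_tiny_patch16_224 below
tokens = 196        # (224 // 16) ** 2, assuming no class token
for prune_index in range(1, 5):  # prune_layers=[2, 5, 8, 11] -> four pruning stages
    pct_kept = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
    tokens = math.floor(pct_kept * tokens)
    print(prune_index, tokens)  # 164, 132, 100, 69 tokens; overall roughly 1 - 4 * 0.16 = 0.36
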
@@ -48,9 +63,9 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.track_dependency_mask = False
         self.dependency_mask = None
-        self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0
+        self.head_selector_temperature = 0.1 # appendix D.1
 
-        self.head_selector = nn.Linear(dim, num_heads, bias=False)
+        self.head_selector = nn.Linear(dim, num_heads, bias=False) # FIXME is there a bias term?
 
         self.message_controller = Mlp(
             in_features = dim,
@@ -59,7 +74,9 @@ def __init__(
             act_layer = nn.GELU,
             bias = False, # FIXME is there a bias term?
         )
-
+
+        self.token_pruner = None
+
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -86,8 +103,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
         attn = self.attn_drop(attn).transpose(-2, -1) # this transpose prevents use of sdpa
         attn = attn * p * m # [B, n_h, N, N]
         x = attn @ v
-
-        self.dependency_mask = attn.sum(1) if self.track_dependency_mask else None
+
+        # FIXME messy way to handle
+        if self.track_dependency_mask or self.token_pruner is not None:
+            dependency_mask = attn.detach().sum(1) # [B, N, N]
+            self.dependency_mask = dependency_mask if self.track_dependency_mask else None
+            # FIXME how to prune
+            x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x # dependency mask weights (sum)
+            #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x # dependency mask weights (abs-sum)
+            #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x # attn weights (abs-sum-abs-sum)
+            #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x # m
+
 
         x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
@@ -161,7 +187,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
 # FIXME verify against reference impl
 
 class DependencyViT(VisionTransformer):
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        prune_layers: Optional[Union[List[int], Tuple[int]]] = None,
+        prune_ratio: Optional[float] = None,
+        *args,
+        **kwargs
+    ):
         super().__init__(
             *args,
             **kwargs,
@@ -172,6 +204,19 @@ def __init__(self, *args, **kwargs):
             init_values=1e-6,
             fc_norm=False,
         )
+
+        if prune_layers is not None:
+            self.prune_layers = sorted(list(dict.fromkeys(prune_layers)))
+            self.prune_ratio = prune_ratio
+
+            # FIXME reword these assertions
+            assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices are greater than model depth"
+            assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1"
+
+            self.prune_layers = [x-1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indices
+            for prune_index, layer in enumerate(self.prune_layers, 1):
+                self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index)
+
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
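
For orientation (an editor's sketch, not part of the diff): with the lite configuration registered at the bottom of this file, the constructor above dedupes and sorts prune_layers, checks that len(prune_layers) * prune_ratio stays below 1 (4 * 0.16 = 0.64 here), and then attaches a TokenPruner to blocks 1, 4, 7 and 10 (0-indexed) with prune_index 1 through 4. Assuming a timm tree that includes this commit:

from timm.models.dependencyvit import DependencyViT

# Hypothetical direct construction mirroring dependencyvit_lite_tiny_patch16_224.
model = DependencyViT(
    patch_size=16, embed_dim=192, depth=12, num_heads=12,
    prune_layers=[2, 5, 8, 11],  # 1-indexed block numbers
    prune_ratio=0.16,
)
pruned = [i for i, blk in enumerate(model.blocks) if blk.attn.token_pruner is not None]
print(pruned)  # expected: [1, 4, 7, 10]
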
@@ -191,6 +236,23 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.norm(x)
         x = x * m.transpose(1, 3).squeeze(-1)
         return x
+
+    def track_dependency_mask(self, track: bool = True):
+        for block in self.blocks:
+            if block.attn.track_dependency_mask is not track:
+                block.attn.dependency_mask = None
+                block.attn.track_dependency_mask = track
+
+    def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None):
+        # L' * [B, N, N]
+        # L' * [B, N', N']
+        result = []
+        layers = range(len(self.blocks)) if not layers else layers
+        for layer in layers:
+            result.append(self.blocks[layer].attn.dependency_mask)
+        return result
+
+
 
 
 def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
@@ -212,6 +274,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
 default_cfgs = {
     'dependencyvit_tiny_patch16_224.untrained': _cfg(url=''),
+    'dependencyvit_small_patch16_224.untrained': _cfg(url=''),
+
+    'dependencyvit_lite_tiny_patch16_224.untrained': _cfg(url=''),
 }
 
 
@@ -240,4 +305,10 @@ def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> Depend
 def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
     model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12)
     model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+@register_model
+def dependencyvit_lite_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT:
+    model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, prune_layers=[2, 5, 8, 11], prune_ratio=0.16)
+    model = _create_dependencyvit('dependencyvit_lite_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
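
As a rough usage sketch (an editor's addition, assuming a timm tree that includes this commit; the weights are untrained), the new lite variant and the dependency-mask helpers added above could be exercised like this:

import torch
import timm

model = timm.create_model('dependencyvit_lite_tiny_patch16_224', pretrained=False).eval()
model.track_dependency_mask(True)   # store per-block dependency masks on the next forward
with torch.no_grad():
    feats = model.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)                  # token count should shrink well below the initial 14 * 14 = 196

masks = model.get_dependency_mask() # per block: [B, N, N] (or [B, N', N'] after pruning), or None
model.track_dependency_mask(False)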
