Skip to content

Commit 2dd2ec5

Browse files
committed
fix syntax, type and shape annotations
1 parent 3b1604f commit 2dd2ec5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

timm/models/dependencyvit.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ def __init__(
3232
self,
3333
prune_ratio: float,
3434
prune_index: int,
35-
):
35+
) -> None:
3636
super().__init__()
3737
self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio)
3838

39-
def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, 1, 1, N], [B, N]
39+
# [B, N, C], [B, 1, 1, N], [B, N] -> [B, N', C], [B, 1, 1, N']
40+
def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
4041
B, N, C = x.shape
4142
topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N']
4243
x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C)) # [B, N', C]
@@ -86,8 +87,8 @@ def __init__(
8687
self.proj_drop = nn.Dropout(proj_drop)
8788

8889
# m is cumulative over all layers (m = m_i * m_i-1 * ... * m_1)
89-
def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
90-
x, m = in_tuple # [B, N, C], [B, 1, 1, N]
90+
# [B, N, C], [B, 1, 1, N] -> [B, N, C], [B, 1, 1, N], [B, N]
91+
def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
9192
B, N, C = x.shape
9293
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
9394
q, k, v = qkv.unbind(0)
@@ -112,7 +113,6 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
112113

113114
#FIXME which pruning mask?
114115

115-
# [B, N]
116116
#prune_mask = attn.detach().sum(1).sum(-1)
117117
#prune_mask = attn.detach().sum(1).abs().sum(-1)
118118
#prune_mask = attn.detach().abs().sum((1, -1))
@@ -184,7 +184,7 @@ def __init__(
184184

185185
def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
186186
x, m = in_tuple
187-
x_new, m, prune_mask = self.attn((self.norm1(x), m))
187+
x_new, m, prune_mask = self.attn(self.norm1(x), m)
188188
x = x + self.drop_path1(self.ls1(x_new))
189189
x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m)
190190
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
@@ -201,7 +201,7 @@ def __init__(
201201
prune_ratio: Optional[float] = None,
202202
*args,
203203
**kwargs
204-
): -> None:
204+
) -> None:
205205
super().__init__(
206206
*args,
207207
**kwargs,
@@ -244,13 +244,13 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
244244
x = x * m.transpose(1, 3).squeeze(-1)
245245
return x
246246

247-
def track_dependency_mask(self, track: bool = True):
247+
def track_dependency_mask(self, track: bool = True) -> None:
248248
for block in self.blocks:
249249
if block.attn.track_dependency_mask is not track:
250250
block.attn.dependency_mask = None
251251
block.attn.track_dependency_mask = track
252252

253-
def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None):
253+
def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None) -> List[torch.Tensor]:
254254
# L' * [B, N, N]
255255
# L' * [B, N', N']
256256
result = []

0 commit comments

Comments
 (0)