Skip to content

Commit 9ccf009

Browse files
committed
Update dependencyvit.py
1 parent e988a79 commit 9ccf009

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

timm/models/dependencyvit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
103103
attn = self.attn_drop(attn).transpose(-2, -1) # this transpose prevents use of sdpa
104104
attn = attn * p * m # [B, n_h, N, N]
105105
x = attn @ v
106+
x = x.transpose(1, 2).reshape(B, N, C)
107+
106108

107109
# FIXME messy way to handle
108110
if self.track_dependency_mask or self.token_pruner:
@@ -115,7 +117,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
115117
#x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x # m
116118

117119

118-
x = x.transpose(1, 2).reshape(B, N, C)
120+
119121
x = self.proj(x)
120122
x = self.proj_drop(x)
121123
return (x, m)

0 commit comments

Comments
 (0)