Skip to content

Commit 1e8beb1

Browse files
committed
Update dependencyvit.py
1 parent 30c370e commit 1e8beb1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

timm/models/dependencyvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
8383
q = q * self.scale
8484
attn = q @ k.transpose(-2, -1)
8585
attn = attn.softmax(dim=-1)
86-
attn = self.attn_drop(attn).transpose(-2, -1)
86+
attn = self.attn_drop(attn).transpose(-2, -1) # this transpose prevents use of sdpa
8787
attn = attn * p * m # [B, n_h, N, N]
8888
x = attn @ v
8989

0 commit comments

Comments
 (0)