Skip to content

Commit 5970607

Browse files
committed
Update dependencyvit.py
1 parent 5ae9513 commit 5970607

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

timm/models/dependencyvit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
180180
x = self._pos_embed(x)
181181
x = self.patch_drop(x)
182182
x = self.norm_pre(x)
183-
m = torch.Tensor(1).to(x)
183+
B, N, _ = x.shape
184+
m = torch.ones(B, 1, 1, N).to(x)
184185
if self.grad_checkpointing and not torch.jit.is_scripting():
185186
x, m = checkpoint_seq(self.blocks, (x, m))
186187
else:

0 commit comments

Comments
 (0)