Skip to content

Commit 832c155

Browse files
committed
Update cvt.py
1 parent b06907b commit 832c155

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

timm/models/cvt.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,14 @@ def __init__(
378378
trunc_normal_(self.cls_token, std=.02)
379379

380380
def forward(self, x: torch.Tensor) -> torch.Tensor:
381-
x = self.probe(x)
381+
382382
x = self.conv_embed(x)
383+
x = self.probe(x)
383384
x = self.embed_drop(x)
384385

385-
cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.cls_token is not None else None
386+
cls_token = self.embed_drop(
387+
self.cls_token.expand(x.shape[0], -1, -1)
388+
) if self.cls_token is not None else None
386389
for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
387390
x, cls_token = block(x, cls_token)
388391

0 commit comments

Comments
 (0)