We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b06907b commit 832c155Copy full SHA for 832c155
timm/models/cvt.py
@@ -378,11 +378,14 @@ def __init__(
378
trunc_normal_(self.cls_token, std=.02)
379
380
def forward(self, x: torch.Tensor) -> torch.Tensor:
381
- x = self.probe(x)
+
382
x = self.conv_embed(x)
383
+ x = self.probe(x)
384
x = self.embed_drop(x)
385
- cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.cls_token is not None else None
386
+ cls_token = self.embed_drop(
387
+ self.cls_token.expand(x.shape[0], -1, -1)
388
+ ) if self.cls_token is not None else None
389
for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
390
x, cls_token = block(x, cls_token)
391
0 commit comments