Skip to content

Commit 0aadb30

Browse files
committed
Update cvt.py
1 parent 187208f commit 0aadb30

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

timm/models/cvt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,11 +372,13 @@ def __init__(
372372
)
373373
blocks.append(block)
374374
self.blocks = nn.ModuleList(blocks)
375+
self.probe = nn.Identity()
375376

376377
if self.cls_token is not None:
377378
trunc_normal_(self.cls_token, std=.02)
378379

379380
def forward(self, x: torch.Tensor) -> torch.Tensor:
381+
self.probe(x)
380382
x = self.conv_embed(x)
381383
x = self.embed_drop(x)
382384

0 commit comments

Comments
 (0)