Skip to content

Commit 1cdedea

Browse files
committed
Update cvt.py
1 parent c63ee94 commit 1cdedea

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

timm/models/cvt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def __init__(
268268
)
269269
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
270270
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
271+
self.probe = nn.Identity()
271272

272273
def add_cls_token(
273274
self,
@@ -296,6 +297,7 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t
296297
cls_token, x = torch.split(x, [1, H*W], 1)
297298

298299
x = x.transpose(1, 2).reshape(B, C, H, W)
300+
self.probe(x)
299301

300302
return x, cls_token
301303

0 commit comments

Comments
 (0)