Skip to content

Commit 6c896b1

Browse files
committed
Update cvt.py
1 parent e3e3b3f commit 6c896b1

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

timm/models/cvt.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,13 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.T
288288

289289
def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
290290
B, C, H, W = x.shape
291-
292-
x = (torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2)) \
293-
+ self.drop_path1(self.ls1(self.fw_attn(self.norm1(x), cls_token)))
291+
res = torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2)
292+
293+
x = self.norm1(torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2))
294+
if self.use_cls_token:
295+
cls_token, x = torch.split(x, [1, H*W], 1)
296+
297+
x = res + self.drop_path1(self.ls1(self.fw_attn(x.transpose(1, 2).reshape(B, C, H, W), cls_token)))
294298
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
295299

296300
if self.use_cls_token:

0 commit comments

Comments
 (0)