Commit 7a33ff4

Update cvt.py
1 parent a1c4c1e commit 7a33ff4

File tree

1 file changed: +41 -2 lines changed

timm/models/cvt.py

Lines changed: 41 additions & 2 deletions
@@ -8,6 +8,7 @@
 Implementation for timm by / Copyright 2024, Fredo Guan
 """
 
+from collections import OrderedDict
 from functools import partial
 from typing import List, Final, Optional, Tuple
 
@@ -51,6 +52,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, C, H, W] -> [B, C, H, W]
         x = self.norm(x)
         return x
 
+
+
 class ConvProj(nn.Module):
     def __init__(
         self,
@@ -65,7 +68,9 @@ def __init__(
     ) -> None:
         super().__init__()
         self.dim = dim
-
+
+        # FIXME not working, bn layer outputs are incorrect
+        '''
         self.conv_q = ConvNormAct(
             dim,
             dim,
@@ -78,7 +83,7 @@ def __init__(
             act_layer=act_layer
         )
 
-        # TODO fuse kv conv?
+        # TODO fuse kv conv? don't wanna do weight remap
         # TODO if act_layer is id and not cls_token (gap model?), is later projection in attn necessary?
 
         self.conv_k = ConvNormAct(
@@ -104,6 +109,40 @@ def __init__(
             norm_layer=norm_layer,
             act_layer=act_layer
         )
+        '''
+        self.conv_q = nn.Sequential(OrderedDict([
+            ('conv', nn.Conv2d(
+                dim,
+                dim,
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride_q,
+                bias=bias,
+                groups=dim
+            )),
+            ('bn', nn.BatchNorm2d(dim)),]))
+        self.conv_k = nn.Sequential(OrderedDict([
+            ('conv', nn.Conv2d(
+                dim,
+                dim,
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride_kv,
+                bias=bias,
+                groups=dim
+            )),
+            ('bn', nn.BatchNorm2d(dim)),]))
+        self.conv_v = nn.Sequential(OrderedDict([
+            ('conv', nn.Conv2d(
+                dim,
+                dim,
+                kernel_size=kernel_size,
+                padding=padding,
+                stride=stride_kv,
+                bias=bias,
+                groups=dim
+            )),
+            ('bn', nn.BatchNorm2d(dim)),]))
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         B, C, H, W = x.shape
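
The named Sequential modules are what make the "don't wanna do weight remap" note work: timm's ConvNormAct exposes its submodules as `conv` and `bn`, so a plain Conv2d + BatchNorm2d stack with the same names yields an identical state_dict layout and pretrained weights load without key remapping. A minimal sketch (dim=64 and the 3x3 settings are illustrative, not from the commit):

    import torch.nn as nn
    from collections import OrderedDict

    dim = 64  # illustrative; the model passes its embedding dim here

    proj = nn.Sequential(OrderedDict([
        ('conv', nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim)),
        ('bn', nn.BatchNorm2d(dim)),
    ]))

    # -> ['conv.weight', 'bn.weight', 'bn.bias', 'bn.running_mean',
    #     'bn.running_var', 'bn.num_batches_tracked']
    print(list(proj.state_dict().keys()))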

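The FIXME says the ConvNormAct version produced incorrect bn outputs, but not why. One way to localize that kind of discrepancy (a debugging sketch under assumed shapes, not part of the commit) is to load identical weights into both variants and compare them in eval mode:

    import torch
    import torch.nn as nn
    from collections import OrderedDict
    from timm.layers import ConvNormAct

    dim = 64  # illustrative

    plain = nn.Sequential(OrderedDict([
        ('conv', nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, bias=False, groups=dim)),
        ('bn', nn.BatchNorm2d(dim)),
    ]))
    fused = ConvNormAct(
        dim, dim, kernel_size=3, padding=1, stride=1, groups=dim, bias=False,
        apply_act=False,  # norm only, no activation
    )

    # Same state_dict layout, so the weights transfer directly.
    fused.load_state_dict(plain.state_dict())

    plain.eval()
    fused.eval()
    x = torch.randn(2, dim, 14, 14)
    # A mismatch here would point at the norm/act wrapper rather than the conv.
    print(torch.allclose(plain(x), fused(x), atol=1e-6))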