8
8
Implementation for timm by / Copyright 2024, Fredo Guan
9
9
"""
10
10
11
+ from collections import OrderedDict
11
12
from functools import partial
12
13
from typing import List , Final , Optional , Tuple
13
14
@@ -51,6 +52,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, C, H, W] -> [B, C, H,
51
52
x = self .norm (x )
52
53
return x
53
54
55
+
56
+
54
57
class ConvProj (nn .Module ):
55
58
def __init__ (
56
59
self ,
@@ -65,7 +68,9 @@ def __init__(
65
68
) -> None :
66
69
super ().__init__ ()
67
70
self .dim = dim
68
-
71
+
72
+ # FIXME: the ConvNormAct-based q/k/v projections below are disabled (not working; their BN layer outputs are incorrect)
73
+ '''
69
74
self.conv_q = ConvNormAct(
70
75
dim,
71
76
dim,
@@ -78,7 +83,7 @@ def __init__(
78
83
act_layer=act_layer
79
84
)
80
85
81
- # TODO fuse kv conv?
86
+ # TODO fuse kv conv? don't wanna do weight remap
82
87
# TODO if act_layer is id and not cls_token (gap model?), is later projection in attn necessary?
83
88
84
89
self.conv_k = ConvNormAct(
@@ -104,6 +109,40 @@ def __init__(
104
109
norm_layer=norm_layer,
105
110
act_layer=act_layer
106
111
)
112
+ '''
113
+ self .conv_q = nn .Sequential (OrderedDict ([
114
+ ('conv' , nn .Conv2d (
115
+ dim ,
116
+ dim ,
117
+ kernel_size = kernel_size ,
118
+ padding = padding ,
119
+ stride = stride_q ,
120
+ bias = bias ,
121
+ groups = dim
122
+ )),
123
+ ('bn' , nn .BatchNorm2d (dim )),]))
124
+ self .conv_k = nn .Sequential (OrderedDict ([
125
+ ('conv' , nn .Conv2d (
126
+ dim ,
127
+ dim ,
128
+ kernel_size = kernel_size ,
129
+ padding = padding ,
130
+ stride = stride_kv ,
131
+ bias = bias ,
132
+ groups = dim
133
+ )),
134
+ ('bn' , nn .BatchNorm2d (dim )),]))
135
+ self .conv_v = nn .Sequential (OrderedDict ([
136
+ ('conv' , nn .Conv2d (
137
+ dim ,
138
+ dim ,
139
+ kernel_size = kernel_size ,
140
+ padding = padding ,
141
+ stride = stride_kv ,
142
+ bias = bias ,
143
+ groups = dim
144
+ )),
145
+ ('bn' , nn .BatchNorm2d (dim )),]))
107
146
108
147
def forward (self , x : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
109
148
B , C , H , W = x .shape
0 commit comments