|
| 1 | +from random import randrange |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch import nn, einsum |
| 5 | +from torch.nn import Module, ModuleList |
| 6 | +import torch.nn.functional as F |
| 7 | + |
| 8 | +from einops import rearrange, repeat, pack, unpack |
| 9 | +from einops.layers.torch import Rearrange |
| 10 | + |
| 11 | +# helpers |
| 12 | + |
| 13 | +def exists(val): |
| 14 | + return val is not None |
| 15 | + |
| 16 | +def pack_one(t, pattern): |
| 17 | + return pack([t], pattern) |
| 18 | + |
| 19 | +def unpack_one(t, ps, pattern): |
| 20 | + return unpack(t, ps, pattern)[0] |
| 21 | + |
| 22 | +def l2norm(t): |
| 23 | + return F.normalize(t, dim = -1, p = 2) |
| 24 | + |
| 25 | +def dropout_layers(layers, dropout): |
| 26 | + if dropout == 0: |
| 27 | + return layers |
| 28 | + |
| 29 | + num_layers = len(layers) |
| 30 | + to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout |
| 31 | + |
| 32 | + # make sure at least one layer makes it |
| 33 | + if all(to_drop): |
| 34 | + rand_index = randrange(num_layers) |
| 35 | + to_drop[rand_index] = False |
| 36 | + |
| 37 | + layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop] |
| 38 | + return layers |
| 39 | + |
| 40 | +# classes |
| 41 | + |
| 42 | +class LayerScale(Module): |
| 43 | + def __init__(self, dim, fn, depth): |
| 44 | + super().__init__() |
| 45 | + if depth <= 18: |
| 46 | + init_eps = 0.1 |
| 47 | + elif 18 > depth <= 24: |
| 48 | + init_eps = 1e-5 |
| 49 | + else: |
| 50 | + init_eps = 1e-6 |
| 51 | + |
| 52 | + self.fn = fn |
| 53 | + self.scale = nn.Parameter(torch.full((dim,), init_eps)) |
| 54 | + |
| 55 | + def forward(self, x, **kwargs): |
| 56 | + return self.fn(x, **kwargs) * self.scale |
| 57 | + |
| 58 | +class FeedForward(Module): |
| 59 | + def __init__(self, dim, hidden_dim, dropout = 0.): |
| 60 | + super().__init__() |
| 61 | + self.net = nn.Sequential( |
| 62 | + nn.LayerNorm(dim), |
| 63 | + nn.Linear(dim, hidden_dim), |
| 64 | + nn.GELU(), |
| 65 | + nn.Dropout(dropout), |
| 66 | + nn.Linear(hidden_dim, dim), |
| 67 | + nn.Dropout(dropout) |
| 68 | + ) |
| 69 | + def forward(self, x): |
| 70 | + return self.net(x) |
| 71 | + |
| 72 | +class Attention(Module): |
| 73 | + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): |
| 74 | + super().__init__() |
| 75 | + inner_dim = dim_head * heads |
| 76 | + self.heads = heads |
| 77 | + self.scale = dim_head ** -0.5 |
| 78 | + |
| 79 | + self.norm = nn.LayerNorm(dim) |
| 80 | + self.to_q = nn.Linear(dim, inner_dim, bias = False) |
| 81 | + self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) |
| 82 | + |
| 83 | + self.attend = nn.Softmax(dim = -1) |
| 84 | + self.dropout = nn.Dropout(dropout) |
| 85 | + |
| 86 | + self.to_out = nn.Sequential( |
| 87 | + nn.Linear(inner_dim, dim), |
| 88 | + nn.Dropout(dropout) |
| 89 | + ) |
| 90 | + |
| 91 | + def forward(self, x, context = None): |
| 92 | + h = self.heads |
| 93 | + |
| 94 | + x = self.norm(x) |
| 95 | + context = x if not exists(context) else torch.cat((x, context), dim = 1) |
| 96 | + |
| 97 | + qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) |
| 98 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) |
| 99 | + |
| 100 | + sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale |
| 101 | + |
| 102 | + attn = self.attend(sim) |
| 103 | + attn = self.dropout(attn) |
| 104 | + |
| 105 | + out = einsum('b h i j, b h j d -> b h i d', attn, v) |
| 106 | + out = rearrange(out, 'b h n d -> b n (h d)') |
| 107 | + return self.to_out(out) |
| 108 | + |
| 109 | +class XCAttention(Module): |
| 110 | + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): |
| 111 | + super().__init__() |
| 112 | + inner_dim = dim_head * heads |
| 113 | + self.heads = heads |
| 114 | + self.norm = nn.LayerNorm(dim) |
| 115 | + |
| 116 | + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) |
| 117 | + |
| 118 | + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) |
| 119 | + |
| 120 | + self.attend = nn.Softmax(dim = -1) |
| 121 | + self.dropout = nn.Dropout(dropout) |
| 122 | + |
| 123 | + self.to_out = nn.Sequential( |
| 124 | + nn.Linear(inner_dim, dim), |
| 125 | + nn.Dropout(dropout) |
| 126 | + ) |
| 127 | + |
| 128 | + def forward(self, x): |
| 129 | + h = self.heads |
| 130 | + x, ps = pack_one(x, 'b * d') |
| 131 | + |
| 132 | + x = self.norm(x) |
| 133 | + q, k, v = self.to_qkv(x).chunk(3, dim = -1) |
| 134 | + |
| 135 | + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h = h), (q, k, v)) |
| 136 | + |
| 137 | + q, k = map(l2norm, (q, k)) |
| 138 | + |
| 139 | + sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp() |
| 140 | + |
| 141 | + attn = self.attend(sim) |
| 142 | + attn = self.dropout(attn) |
| 143 | + |
| 144 | + out = einsum('b h i j, b h j n -> b h i n', attn, v) |
| 145 | + out = rearrange(out, 'b h d n -> b n (h d)') |
| 146 | + |
| 147 | + out = unpack_one(out, ps, 'b * d') |
| 148 | + return self.to_out(out) |
| 149 | + |
| 150 | +class LocalPatchInteraction(Module): |
| 151 | + def __init__(self, dim, kernel_size = 3): |
| 152 | + super().__init__() |
| 153 | + assert (kernel_size % 2) == 1 |
| 154 | + padding = kernel_size // 2 |
| 155 | + |
| 156 | + self.net = nn.Sequential( |
| 157 | + nn.LayerNorm(dim), |
| 158 | + Rearrange('b h w c -> b c h w'), |
| 159 | + nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim), |
| 160 | + nn.BatchNorm2d(dim), |
| 161 | + nn.GELU(), |
| 162 | + nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim), |
| 163 | + Rearrange('b c h w -> b h w c'), |
| 164 | + ) |
| 165 | + |
| 166 | + def forward(self, x): |
| 167 | + return self.net(x) |
| 168 | + |
| 169 | +class Transformer(Module): |
| 170 | + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.): |
| 171 | + super().__init__() |
| 172 | + self.layers = ModuleList([]) |
| 173 | + self.layer_dropout = layer_dropout |
| 174 | + |
| 175 | + for ind in range(depth): |
| 176 | + layer = ind + 1 |
| 177 | + self.layers.append(ModuleList([ |
| 178 | + LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer), |
| 179 | + LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer) |
| 180 | + ])) |
| 181 | + |
| 182 | + def forward(self, x, context = None): |
| 183 | + layers = dropout_layers(self.layers, dropout = self.layer_dropout) |
| 184 | + |
| 185 | + for attn, ff in layers: |
| 186 | + x = attn(x, context = context) + x |
| 187 | + x = ff(x) + x |
| 188 | + |
| 189 | + return x |
| 190 | + |
| 191 | +class XCATransformer(Module): |
| 192 | + def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.): |
| 193 | + super().__init__() |
| 194 | + self.layers = ModuleList([]) |
| 195 | + self.layer_dropout = layer_dropout |
| 196 | + |
| 197 | + for ind in range(depth): |
| 198 | + layer = ind + 1 |
| 199 | + self.layers.append(ModuleList([ |
| 200 | + LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer), |
| 201 | + LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer), |
| 202 | + LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer) |
| 203 | + ])) |
| 204 | + |
| 205 | + def forward(self, x): |
| 206 | + layers = dropout_layers(self.layers, dropout = self.layer_dropout) |
| 207 | + |
| 208 | + for cross_covariance_attn, local_patch_interaction, ff in layers: |
| 209 | + x = cross_covariance_attn(x) + x |
| 210 | + x = local_patch_interaction(x) + x |
| 211 | + x = ff(x) + x |
| 212 | + |
| 213 | + return x |
| 214 | + |
| 215 | +class XCiT(Module): |
| 216 | + def __init__( |
| 217 | + self, |
| 218 | + *, |
| 219 | + image_size, |
| 220 | + patch_size, |
| 221 | + num_classes, |
| 222 | + dim, |
| 223 | + depth, |
| 224 | + cls_depth, |
| 225 | + heads, |
| 226 | + mlp_dim, |
| 227 | + dim_head = 64, |
| 228 | + dropout = 0., |
| 229 | + emb_dropout = 0., |
| 230 | + local_patch_kernel_size = 3, |
| 231 | + layer_dropout = 0. |
| 232 | + ): |
| 233 | + super().__init__() |
| 234 | + assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' |
| 235 | + |
| 236 | + num_patches = (image_size // patch_size) ** 2 |
| 237 | + patch_dim = 3 * patch_size ** 2 |
| 238 | + |
| 239 | + self.to_patch_embedding = nn.Sequential( |
| 240 | + Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size), |
| 241 | + nn.LayerNorm(patch_dim), |
| 242 | + nn.Linear(patch_dim, dim), |
| 243 | + nn.LayerNorm(dim) |
| 244 | + ) |
| 245 | + |
| 246 | + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) |
| 247 | + self.cls_token = nn.Parameter(torch.randn(dim)) |
| 248 | + |
| 249 | + self.dropout = nn.Dropout(emb_dropout) |
| 250 | + |
| 251 | + self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout) |
| 252 | + |
| 253 | + self.final_norm = nn.LayerNorm(dim) |
| 254 | + |
| 255 | + self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout) |
| 256 | + |
| 257 | + self.mlp_head = nn.Sequential( |
| 258 | + nn.LayerNorm(dim), |
| 259 | + nn.Linear(dim, num_classes) |
| 260 | + ) |
| 261 | + |
| 262 | + def forward(self, img): |
| 263 | + x = self.to_patch_embedding(img) |
| 264 | + |
| 265 | + x, ps = pack_one(x, 'b * d') |
| 266 | + |
| 267 | + b, n, _ = x.shape |
| 268 | + x += self.pos_embedding[:, :n] |
| 269 | + |
| 270 | + x = unpack_one(x, ps, 'b * d') |
| 271 | + |
| 272 | + x = self.dropout(x) |
| 273 | + |
| 274 | + x = self.xcit_transformer(x) |
| 275 | + |
| 276 | + x = self.final_norm(x) |
| 277 | + |
| 278 | + cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b) |
| 279 | + |
| 280 | + x = rearrange(x, 'b ... d -> b (...) d') |
| 281 | + cls_tokens = self.cls_transformer(cls_tokens, context = x) |
| 282 | + |
| 283 | + return self.mlp_head(cls_tokens[:, 0]) |
0 commit comments