
Commit 1616288

add xcit (#284)
* add xcit
* use Rearrange layers
* give cross correlation transformer a final norm at end
* document
1 parent 9e1e824 commit 1616288


3 files changed: +327 -1 lines changed


README.md

Lines changed: 43 additions & 0 deletions
@@ -25,6 +25,7 @@
 - [MaxViT](#maxvit)
 - [NesT](#nest)
 - [MobileViT](#mobilevit)
+- [XCiT](#xcit)
 - [Masked Autoencoder](#masked-autoencoder)
 - [Simple Masked Image Modeling](#simple-masked-image-modeling)
 - [Masked Patch Prediction](#masked-patch-prediction)
@@ -772,6 +773,38 @@ img = torch.randn(1, 3, 256, 256)
 pred = mbvit_xs(img) # (1, 1000)
 ```

+## XCiT
+
+<img src="./images/xcit.png" width="400px"></img>
+
+This <a href="https://arxiv.org/abs/2106.09681">paper</a> introduces cross-covariance attention (abbreviated XCA). One can think of it as doing attention across the feature dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being the attention map defined by spatial correlations).
+
+Technically, this amounts to simply transposing the queries, keys and values before executing cosine-similarity attention with a learned temperature.
+
+```python
+import torch
+from vit_pytorch.xcit import XCiT
+
+v = XCiT(
+    image_size = 256,
+    patch_size = 32,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 12,                     # depth of xcit transformer
+    cls_depth = 2,                  # depth of cross attention of CLS tokens to patch tokens (attention pooling at the end)
+    heads = 16,
+    mlp_dim = 2048,
+    dropout = 0.1,
+    emb_dropout = 0.1,
+    layer_dropout = 0.05,           # randomly dropout 5% of the layers
+    local_patch_kernel_size = 3     # kernel size of the local patch interaction module (depthwise convs)
+)
+
+img = torch.randn(1, 3, 256, 256)
+
+preds = v(img) # (1, 1000)
+```
+
 ## Simple Masked Image Modeling

 <img src="./images/simmim.png" width="400px"/>
@@ -2029,4 +2062,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```

+```bibtex
+@inproceedings{ElNouby2021XCiTCI,
+    title     = {XCiT: Cross-Covariance Image Transformers},
+    author    = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
+    booktitle = {Neural Information Processing Systems},
+    year      = {2021},
+    url       = {https://api.semanticscholar.org/CorpusID:235458262}
+}
+```
+
 *I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.5.3',
+  version = '1.6.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/xcit.py

Lines changed: 283 additions & 0 deletions
@@ -0,0 +1,283 @@
from random import randrange

import torch
from torch import nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def l2norm(t):
    return F.normalize(t, dim = -1, p = 2)

def dropout_layers(layers, dropout):
    # layer dropout (stochastic depth) - randomly skip whole residual blocks
    if dropout == 0:
        return layers

    num_layers = len(layers)
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) < dropout

    # make sure at least one layer makes it
    if all(to_drop):
        rand_index = randrange(num_layers)
        to_drop[rand_index] = False

    layers = [layer for (layer, drop) in zip(layers, to_drop) if not drop]
    return layers

# classes

class LayerScale(Module):
    # scales the residual branch output by a learned per-channel value,
    # initialized smaller for deeper networks
    def __init__(self, dim, fn, depth):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        self.fn = fn
        self.scale = nn.Parameter(torch.full((dim,), init_eps))

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class FeedForward(Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(Module):
    # standard (spatial) attention, used for the CLS token attention pooling at the end
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context = None):
        h = self.heads

        x = self.norm(x)
        context = x if not exists(context) else torch.cat((x, context), dim = 1)

        qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class XCAttention(Module):
    # cross covariance attention - attention over the feature dimension,
    # with l2-normalized (cosine similarity) queries / keys and a learned temperature
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        h = self.heads
        x, ps = pack_one(x, 'b * d')

        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h d n', h = h), (q, k, v))

        q, k = map(l2norm, (q, k))

        sim = einsum('b h i n, b h j n -> b h i j', q, k) * self.temperature.exp()

        attn = self.attend(sim)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j n -> b h i n', attn, v)
        out = rearrange(out, 'b h d n -> b n (h d)')

        out = unpack_one(out, ps, 'b * d')
        return self.to_out(out)

class LocalPatchInteraction(Module):
    # two depthwise convolutions over the 2d grid of patch tokens
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        assert (kernel_size % 2) == 1
        padding = kernel_size // 2

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b h w c -> b c h w'),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            nn.BatchNorm2d(dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, kernel_size, padding = padding, groups = dim),
            Rearrange('b c h w -> b h w c'),
        )

    def forward(self, x):
        return self.net(x)

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x, context = None):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for attn, ff in layers:
            x = attn(x, context = context) + x
            x = ff(x) + x

        return x

class XCATransformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size = 3, dropout = 0., layer_dropout = 0.):
        super().__init__()
        self.layers = ModuleList([])
        self.layer_dropout = layer_dropout

        for ind in range(depth):
            layer = ind + 1
            self.layers.append(ModuleList([
                LayerScale(dim, XCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = layer),
                LayerScale(dim, LocalPatchInteraction(dim, local_patch_kernel_size), depth = layer),
                LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = layer)
            ]))

    def forward(self, x):
        layers = dropout_layers(self.layers, dropout = self.layer_dropout)

        for cross_covariance_attn, local_patch_interaction, ff in layers:
            x = cross_covariance_attn(x) + x
            x = local_patch_interaction(x) + x
            x = ff(x) + x

        return x

class XCiT(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        cls_depth,
        heads,
        mlp_dim,
        dim_head = 64,
        dropout = 0.,
        emb_dropout = 0.,
        local_patch_kernel_size = 3,
        layer_dropout = 0.
    ):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.xcit_transformer = XCATransformer(dim, depth, heads, dim_head, mlp_dim, local_patch_kernel_size, dropout, layer_dropout)

        self.final_norm = nn.LayerNorm(dim)

        self.cls_transformer = Transformer(dim, cls_depth, heads, dim_head, mlp_dim, dropout, layer_dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)

        x, ps = pack_one(x, 'b * d')

        b, n, _ = x.shape
        x += self.pos_embedding[:, :n]

        x = unpack_one(x, ps, 'b * d')

        x = self.dropout(x)

        x = self.xcit_transformer(x)

        x = self.final_norm(x)

        cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b = b)

        x = rearrange(x, 'b ... d -> b (...) d')
        cls_tokens = self.cls_transformer(cls_tokens, context = x)

        return self.mlp_head(cls_tokens[:, 0])
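As a quick sanity check of the new module (a usage sketch, not part of the commit; the tiny hyperparameter values are illustrative only), the model can be instantiated and run on a dummy batch:

```python
import torch
from vit_pytorch.xcit import XCiT

# deliberately small configuration, just to exercise the forward pass
v = XCiT(
    image_size = 64,
    patch_size = 16,
    num_classes = 10,
    dim = 128,
    depth = 2,
    cls_depth = 1,
    heads = 4,
    mlp_dim = 256
)

img = torch.randn(2, 3, 64, 64)
assert v(img).shape == (2, 10)   # one logit vector per image
```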
