Commit 0b5c9b4

add value residual based simple vit
1 parent e300cdd commit 0b5c9b4

File tree

3 files changed: +161 -1 lines changed


README.md

Lines changed: 9 additions & 0 deletions
@@ -2152,4 +2152,13 @@ Coming from computer vision and new to transformers? Here are some resources tha

Appended after the existing citations:

```bibtex
@inproceedings{Zhou2024ValueRL,
    title   = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
    author  = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.5',
+  version = '1.8.6',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
New file: Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@

import torch
from torch import nn
from torch.nn import Module, ModuleList

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

# classes

def FeedForward(dim, hidden_dim):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    )

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, value_residual = None):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        # value residual: add the first layer's values to this layer's values
        if exists(value_residual):
            v = v + value_residual

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        # return the values alongside the output so the caller can cache them
        return self.to_out(out), v

class Transformer(Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])
        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))

    def forward(self, x):
        value_residual = None

        for attn, ff in self.layers:

            attn_out, values = attn(x, value_residual = value_residual)

            # keep the first attention layer's values as the residual for all later layers
            value_residual = default(value_residual, values)

            x = attn_out + x
            x = ff(x) + x

        return self.norm(x)

class SimpleViT(Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device = img.device

        x = self.to_patch_embedding(img)
        x += self.pos_embedding.to(device, dtype=x.dtype)

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

# quick test

if __name__ == '__main__':
    v = SimpleViT(
        num_classes = 1000,
        image_size = 256,
        patch_size = 8,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048,
    )

    images = torch.randn(2, 3, 256, 256)

    logits = v(images)
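Below is a minimal sketch of how the value residual threads through two attention calls, using only the `Attention` class defined in the new file above. It assumes those definitions are importable or pasted into the same script (the new file's path inside the package is not shown in this diff), and the shapes are purely illustrative:

```python
import torch

# assumes the Attention class from the new file above is in scope;
# a real Transformer stack uses a separate Attention module per layer,
# a single one is reused here only for brevity
attn = Attention(dim = 256, heads = 8, dim_head = 64)

x = torch.randn(2, 64, 256)                     # (batch, num_patches, dim)

# first layer: no residual yet; the layer returns its values alongside its output
out1, v1 = attn(x)                              # v1: (batch, heads, num_patches, dim_head)

# later layers: the cached first-layer values are added to their own values before attention
out2, v2 = attn(out1 + x, value_residual = v1)

print(out2.shape, v2.shape)                     # torch.Size([2, 64, 256]) torch.Size([2, 8, 64, 64])
```

Note that `Transformer.forward` caches only the first layer's values via `default(value_residual, values)`, so every later layer sees its own values plus the first layer's, rather than an accumulating sum.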
