Commit bbb24e3

give a learned bias to and from registers for maxvit + register token variant
1 parent df8733d commit bbb24e3

File tree

2 files changed: +14 -13 lines changed


setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.5.2',
+  version = '1.5.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/max_vit_with_registers.py

Lines changed: 13 additions & 12 deletions
@@ -119,9 +119,11 @@ def __init__(
         dim,
         dim_head = 32,
         dropout = 0.,
-        window_size = 7
+        window_size = 7,
+        num_registers = 1
     ):
         super().__init__()
+        assert num_registers > 0
         assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
 
         self.heads = dim // dim_head
@@ -142,7 +144,9 @@ def __init__(
 
         # relative positional bias
 
-        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
+        num_rel_pos_bias = (2 * window_size - 1) ** 2
+
+        self.rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, self.heads)
 
         pos = torch.arange(window_size)
         grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
@@ -151,10 +155,11 @@ def __init__(
         rel_pos += window_size - 1
         rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
 
+        rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)
         self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
 
     def forward(self, x):
-        device, h = x.device, self.heads
+        device, h, bias_indices = x.device, self.heads, self.rel_pos_indices
 
         x = self.norm(x)
 
@@ -176,13 +181,8 @@ def forward(self, x):
 
         # add positional bias
 
-        bias = self.rel_pos_bias(self.rel_pos_indices)
-        bias = rearrange(bias, 'i j h -> h i j')
-
-        num_registers = sim.shape[-1] - bias.shape[-1]
-        bias = F.pad(bias, (num_registers, 0, num_registers, 0), value = 0.)
-
-        sim = sim + bias
+        bias = self.rel_pos_bias(bias_indices)
+        sim = sim + rearrange(bias, 'i j h -> h i j')
 
         # attention
 
@@ -215,6 +215,7 @@ def __init__(
     ):
         super().__init__()
         assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage'
+        assert num_register_tokens > 0
 
         # convolutional stem
 
@@ -256,10 +257,10 @@ def __init__(
                     shrinkage_rate = mbconv_shrinkage_rate
                 )
 
-                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+                block_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                 block_ff = FeedForward(dim = layer_dim, dropout = dropout)
 
-                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size)
+                grid_attn = Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = window_size, num_registers = num_register_tokens)
                 grid_ff = FeedForward(dim = layer_dim, dropout = dropout)
 
                 register_tokens = nn.Parameter(torch.randn(num_register_tokens, layer_dim))
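To make the change concrete, here is a minimal standalone sketch of the indexing trick the diff introduces, assuming the same names as above (window_size, num_registers, num_rel_pos_bias, rel_pos_indices); the head count and the module-free setup are illustrative, not the library code itself. The bias table gains one extra slot, and every attention position that involves a register token is routed to that shared slot, so registers receive a learned bias to and from all other tokens instead of the zero padding used before.

import torch
import torch.nn.functional as F
from torch import nn

window_size = 7        # tokens per window side
num_registers = 1      # register tokens prepended to each window
heads = 4              # illustrative head count

# one bias slot per distinct 2d relative offset, plus one shared slot for registers
num_rel_pos_bias = (2 * window_size - 1) ** 2
rel_pos_bias = nn.Embedding(num_rel_pos_bias + 1, heads)

# standard relative position indices for the window_size ** 2 window tokens
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
grid = grid.reshape(2, -1).t()                                  # (w * w, 2) coordinates
rel_pos = grid[:, None, :] - grid[None, :, :]                   # pairwise 2d offsets
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

# pad the index table so the first num_registers rows and columns all point at the
# extra embedding slot, i.e. a single learned bias to and from the registers
rel_pos_indices = F.pad(rel_pos_indices, (num_registers, 0, num_registers, 0), value = num_rel_pos_bias)

bias = rel_pos_bias(rel_pos_indices)                            # (n, n, heads), n = num_registers + window_size ** 2
bias = bias.permute(2, 0, 1)                                    # (heads, n, n), added to the attention logits
print(bias.shape)                                               # torch.Size([4, 50, 50]) with the values above

With the padded index table precomputed in __init__, the forward pass reduces to a single embedding lookup, which is why the sim.shape based padding previously done in forward could be dropped.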
