
Commit cd21090

one can pass a callback to token_dropout_prob for NaViT that takes in height and width and calculates the appropriate dropout rate
1 parent 17675e0 commit cd21090

2 files changed: 19 additions, 6 deletions

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.7',
+  version = '1.2.8',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/na_vit.py

Lines changed: 18 additions & 5 deletions
@@ -138,14 +138,25 @@ def forward(
         return self.norm(x)
 
 class NaViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
         super().__init__()
         image_height, image_width = pair(image_size)
 
         # what percent of tokens to dropout
-        # in paper, they found this should vary depending on resolution (todo - figure out how to do this, maybe with callback?)
+        # if int or float given, then assume constant dropout prob
+        # otherwise accept a callback that in turn calculates dropout prob from height and width
 
-        self.token_dropout_prob = token_dropout_prob
+        self.calc_token_dropout = calc_token_dropout = None
+
+        if callable(token_dropout_prob):
+            self.calc_token_dropout = token_dropout_prob
+
+        elif isinstance(token_dropout_prob, (float, int)):
+            assert 0. < token_dropout_prob < 1.
+            token_dropout_prob = float(token_dropout_prob)
+            self.calc_token_dropout = lambda height, width: token_dropout_prob
+
+        # calculate patching related stuff
 
         assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
 
@@ -190,7 +201,7 @@ def forward(
         self,
         batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
     ):
-        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, self.token_dropout_prob > 0.
+        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
 
         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -227,8 +238,10 @@ def forward(
                 seq_len = seq.shape[-2]
 
                 if has_token_dropout:
-                    num_keep = max(1, int(seq_len * (1 - self.token_dropout_prob)))
+                    token_dropout = self.calc_token_dropout(*image_dims)
+                    num_keep = max(1, int(seq_len * (1 - token_dropout)))
                     keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
+
                     seq = seq[keep_indices]
                     pos = pos[keep_indices]
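
As a usage sketch (not part of the commit itself): after this change, token_dropout_prob can be either a constant float or a callable that receives an image's height and width and returns the dropout probability for that image. The dropout schedule and model hyperparameters below are illustrative assumptions only.

    import torch
    from vit_pytorch.na_vit import NaViT

    # illustrative schedule (an assumption for this example): drop more tokens
    # for larger images, relative to a 256 x 256 reference resolution
    def token_dropout_by_resolution(height, width):
        return min(0.9, 0.5 * (height * width) / (256 * 256))

    v = NaViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        token_dropout_prob = token_dropout_by_resolution  # callback instead of a constant float
    )

    # variable resolution images, grouped per batch element; dims must be divisible by patch_size
    images = [
        [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
        [torch.randn(3, 64, 256)]
    ]

    preds = v(images)  # one prediction per image -> (3, 1000)

Internally, a constant float is simply wrapped into a lambda, so both forms go through the same calc_token_dropout path in forward.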
