@@ -138,14 +138,25 @@ def forward(
         return self.norm(x)

 class NaViT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = 0.):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = None):
         super().__init__()
         image_height, image_width = pair(image_size)

         # what percent of tokens to dropout
-        # in paper, they found this should vary depending on resolution (todo - figure out how to do this, maybe with callback?)
+        # if an int or float is given, assume a constant dropout prob
+        # otherwise accept a callback that calculates the dropout prob from the image height and width

-        self.token_dropout_prob = token_dropout_prob
+        self.calc_token_dropout = None
+
+        if callable(token_dropout_prob):
+            self.calc_token_dropout = token_dropout_prob
+
+        elif isinstance(token_dropout_prob, (float, int)):
+            assert 0. < token_dropout_prob < 1.
+            token_dropout_prob = float(token_dropout_prob)
+            self.calc_token_dropout = lambda height, width: token_dropout_prob
+
+        # calculate patching related stuff


         assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
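
For context, the two accepted forms of `token_dropout_prob` after this change look like the sketch below (the import path follows the repo's usual `vit_pytorch.na_vit` module; all other constructor values are illustrative placeholders, not from the diff):

```python
from vit_pytorch.na_vit import NaViT

# constant rate: drop 10% of patch tokens for every image, regardless of resolution
v = NaViT(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 16, mlp_dim = 2048,
    token_dropout_prob = 0.1
)

# resolution-dependent rate: the callback receives the (height, width) of each image
v = NaViT(
    image_size = 256, patch_size = 32, num_classes = 1000,
    dim = 1024, depth = 6, heads = 16, mlp_dim = 2048,
    token_dropout_prob = lambda height, width: 0.5 * height * width / (256 ** 2)
)
```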
@@ -190,7 +201,7 @@ def forward(
         self,
         batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
     ):
-        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, self.token_dropout_prob > 0.
+        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)

         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
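
For readers outside the repo: `exists` above is presumably the codebase's usual None-check helper, along the lines of this assumed definition:

```python
def exists(val):
    # helper assumed from the surrounding codebase: a value "exists" if it is not None
    return val is not None
```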
@@ -227,8 +238,10 @@ def forward(
                 seq_len = seq.shape[-2]

                 if has_token_dropout:
-                    num_keep = max(1, int(seq_len * (1 - self.token_dropout_prob)))
+                    token_dropout = self.calc_token_dropout(*image_dims)
+                    num_keep = max(1, int(seq_len * (1 - token_dropout)))
                     keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
+
                     seq = seq[keep_indices]
                     pos = pos[keep_indices]

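
For intuition, the changed dropout step amounts to the standalone sketch below. The function name `drop_tokens` and the tensor shapes are illustrative, not from the diff; the `randn` + `topk` trick samples `num_keep` indices uniformly without replacement:

```python
import torch

def drop_tokens(seq, pos, image_dims, calc_token_dropout):
    # seq: (seq_len, dim) patch tokens for one image; pos: matching patch coordinates
    seq_len = seq.shape[-2]
    token_dropout = calc_token_dropout(*image_dims)        # dropout prob from (height, width)
    num_keep = max(1, int(seq_len * (1 - token_dropout)))  # always keep at least one token
    # random scores + topk = a uniform sample of num_keep indices without replacement
    keep_indices = torch.randn((seq_len,), device = seq.device).topk(num_keep, dim = -1).indices
    return seq[keep_indices], pos[keep_indices]

seq = torch.randn(64, 128)          # 64 patch tokens of dim 128
pos = torch.randint(0, 8, (64, 2))  # (row, col) position of each patch
kept_seq, kept_pos = drop_tokens(seq, pos, (256, 256), lambda h, w: 0.25)
assert kept_seq.shape == (48, 128) and kept_pos.shape == (48, 2)
```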