1
1
from functools import partial
2
- from typing import List
2
+ from typing import List , Union
3
3
4
4
import torch
5
5
import torch .nn .functional as F
@@ -17,12 +17,58 @@ def exists(val):
17
17
def default(val, d):
    """Return `val` when it passes `exists`, otherwise the fallback `d`."""
    if exists(val):
        return val
    return d
19
19
20
def always(val):
    """Return a function that ignores any positional arguments and always yields `val`."""
    def _constant(*_args):
        return val
    return _constant
22
+
20
23
def pair(t):
    """Promote a scalar to a 2-tuple; tuples pass through unchanged."""
    if isinstance(t, tuple):
        return t
    return (t, t)
22
25
23
26
def divisible_by(numer, denom):
    """True when `numer` is an exact multiple of `denom`."""
    return not (numer % denom)
25
28
29
+ # auto grouping images
30
+
31
def group_images_by_max_seq_len(
    images: List[Tensor],
    patch_size: int,
    calc_token_dropout = None,
    max_seq_len = 2048
) -> List[List[Tensor]]:
    """Greedily pack images into groups whose summed patch-token counts fit `max_seq_len`.

    Each image contributes (H // patch_size) * (W // patch_size) tokens, reduced by the
    dropout fraction returned by `calc_token_dropout(H, W)`. `calc_token_dropout` may be
    None (treated as 0.), a constant float/int, or a callable taking the image dims.
    Raises AssertionError if any single image alone exceeds `max_seq_len`.
    """

    # normalize calc_token_dropout into a callable (h, w) -> dropout fraction
    if calc_token_dropout is None:
        calc_token_dropout = 0.

    if isinstance(calc_token_dropout, (float, int)):
        # bind the constant as a default arg so the closure is not late-bound
        calc_token_dropout = lambda *_args, _rate = calc_token_dropout: _rate

    groups = []
    current_group = []
    current_len = 0

    for image in images:
        assert isinstance(image, Tensor)

        image_dims = image.shape[-2:]
        ph, pw = (dim // patch_size for dim in image_dims)

        # token count after applying the (possibly zero) dropout fraction
        image_seq_len = int((ph * pw) * (1 - calc_token_dropout(*image_dims)))

        assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'

        # start a fresh group whenever this image would overflow the current one
        if (current_len + image_seq_len) > max_seq_len:
            groups.append(current_group)
            current_group = []
            current_len = 0

        current_group.append(image)
        current_len += image_seq_len

    # flush the trailing, non-empty group
    if current_group:
        groups.append(current_group)

    return groups
71
+
26
72
# normalization
27
73
# they use layernorm without bias, something that pytorch does not offer
28
74
@@ -199,13 +245,25 @@ def device(self):
199
245
200
246
def forward (
201
247
self ,
202
- batched_images : List [List [Tensor ]] # assume different resolution images already grouped correctly
248
+ batched_images : Union [List [Tensor ], List [List [Tensor ]]], # assume different resolution images already grouped correctly
249
+ group_images = False ,
250
+ group_max_seq_len = 2048
203
251
):
204
252
p , c , device , has_token_dropout = self .patch_size , self .channels , self .device , exists (self .calc_token_dropout )
205
253
206
254
arange = partial (torch .arange , device = device )
207
255
pad_sequence = partial (orig_pad_sequence , batch_first = True )
208
256
257
+ # auto pack if specified
258
+
259
+ if group_images :
260
+ batched_images = group_images_by_max_seq_len (
261
+ batched_images ,
262
+ patch_size = self .patch_size ,
263
+ calc_token_dropout = self .calc_token_dropout ,
264
+ max_seq_len = group_max_seq_len
265
+ )
266
+
209
267
# process images into variable lengthed sequences with attention mask
210
268
211
269
num_images = []
0 commit comments