Commit 6e2393d

wrap up NaViT

1 parent 32974c3 commit 6e2393d

3 files changed: +79 additions, −3 deletions

README.md

Lines changed: 18 additions & 0 deletions

````diff
@@ -179,6 +179,24 @@ preds = v(images) # (5, 1000) - 5, because 5 images of different resolution above
 
 ```
 
+Or, if you would rather have the framework automatically group the images into variable-length sequences that do not exceed a certain maximum length:
+
+```python
+images = [
+    torch.randn(3, 256, 256),
+    torch.randn(3, 128, 128),
+    torch.randn(3, 128, 256),
+    torch.randn(3, 256, 128),
+    torch.randn(3, 64, 256)
+]
+
+preds = v(
+    images,
+    group_images = True,
+    group_max_seq_len = 64
+) # (5, 1000)
+```
+
 ## Distillation
 
 <img src="./images/distill.png" width="300px"></img>
````
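For reference, the new option slots into the existing README example like this — a minimal end-to-end sketch, assuming the `NaViT` constructor arguments shown earlier in the README (the hyperparameters here are illustrative, not prescribed by this commit):

```python
import torch
from vit_pytorch.na_vit import NaViT

# constructor arguments follow the earlier README example
v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    token_dropout_prob = 0.1  # drop ~10% of tokens per image
)

# a flat list of variable-resolution images; each side must be divisible by patch_size
images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 128, 256),
    torch.randn(3, 256, 128),
    torch.randn(3, 64, 256)
]

# group_images = True lets the model pack the flat list into sequences
# of at most group_max_seq_len tokens before the usual forward pass
preds = v(
    images,
    group_images = True,
    group_max_seq_len = 64
)

print(preds.shape)  # (5, 1000) - one prediction per image, however the images were packed
```

Note that the output shape is independent of the grouping; packing only changes how the variable-length sequences are batched internally.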

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.2.8',
+  version = '1.2.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
```

vit_pytorch/na_vit.py

Lines changed: 60 additions & 2 deletions

```diff
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import List
+from typing import List, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,12 +17,58 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def always(val):
+    return lambda *args: val
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
 def divisible_by(numer, denom):
     return (numer % denom) == 0
 
+# auto grouping images
+
+def group_images_by_max_seq_len(
+    images: List[Tensor],
+    patch_size: int,
+    calc_token_dropout = None,
+    max_seq_len = 2048
+
+) -> List[List[Tensor]]:
+
+    calc_token_dropout = default(calc_token_dropout, always(0.))
+
+    groups = []
+    group = []
+    seq_len = 0
+
+    if isinstance(calc_token_dropout, (float, int)):
+        calc_token_dropout = always(calc_token_dropout)
+
+    for image in images:
+        assert isinstance(image, Tensor)
+
+        image_dims = image.shape[-2:]
+        ph, pw = map(lambda t: t // patch_size, image_dims)
+
+        image_seq_len = (ph * pw)
+        image_seq_len = int(image_seq_len * (1 - calc_token_dropout(*image_dims)))
+
+        assert image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
+
+        if (seq_len + image_seq_len) > max_seq_len:
+            groups.append(group)
+            group = []
+            seq_len = 0
+
+        group.append(image)
+        seq_len += image_seq_len
+
+    if len(group) > 0:
+        groups.append(group)
+
+    return groups
+
 # normalization
 # they use layernorm without bias, something that pytorch does not offer
 
```
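To make the greedy packing above concrete, here is a small sketch that calls the new helper directly (assuming it is imported from `vit_pytorch.na_vit`, where this commit defines it at module level); the token counts in the comments follow the arithmetic in the function:

```python
import torch
from vit_pytorch.na_vit import group_images_by_max_seq_len

images = [
    torch.randn(3, 256, 256),  # (256/32) * (256/32) = 64 tokens
    torch.randn(3, 128, 128),  # 4 * 4 = 16 tokens
    torch.randn(3, 128, 256),  # 4 * 8 = 32 tokens
    torch.randn(3, 256, 128),  # 8 * 4 = 32 tokens
    torch.randn(3, 64, 256)    # 2 * 8 = 16 tokens
]

groups = group_images_by_max_seq_len(
    images,
    patch_size = 32,
    calc_token_dropout = None,  # defaults to no token dropout
    max_seq_len = 64
)

# a new group starts whenever adding the next image would push the
# running token count past max_seq_len:
# [64], [16, 32], [32, 16]
print([len(g) for g in groups])  # [1, 2, 2]
```

The packing is greedy and order-preserving: it never splits an image across groups or reorders images to pack more tightly, it simply starts a new group whenever the running token count would overflow.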

```diff
@@ -199,13 +245,25 @@ def device(self):
 
     def forward(
         self,
-        batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
+        batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
+        group_images = False,
+        group_max_seq_len = 2048
    ):
        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
 
        arange = partial(torch.arange, device = device)
        pad_sequence = partial(orig_pad_sequence, batch_first = True)
 
+        # auto pack if specified
+
+        if group_images:
+            batched_images = group_images_by_max_seq_len(
+                batched_images,
+                patch_size = self.patch_size,
+                calc_token_dropout = self.calc_token_dropout,
+                max_seq_len = group_max_seq_len
+            )
+
        # process images into variable lengthed sequences with attention mask
 
        num_images = []
```
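With this change, `forward` accepts either pre-grouped nested lists or a flat list with auto grouping enabled. A brief sketch of the two calling conventions, assuming the model `v` as constructed in the README example:

```python
import torch

# pre-grouped: the caller controls the packing (the original calling convention)
preds = v([
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 128, 256)]
])  # (3, 1000)

# flat list + auto grouping: the model packs internally via
# group_images_by_max_seq_len, reusing its own calc_token_dropout
preds = v(
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128), torch.randn(3, 128, 256)],
    group_images = True,
    group_max_seq_len = 2048
)  # (3, 1000)
```

In both cases the output holds one prediction per image, concatenated across all groups.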
