Commit 73199ab

Nested navit (#325)
add a variant of NaViT using nested tensors
1 parent 4f22eae commit 73199ab

4 files changed, +367 -3 lines changed

README.md

Lines changed: 32 additions & 0 deletions
    @@ -198,6 +198,38 @@ preds = v(
     ) # (5, 1000)
     ```

    +Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows
    +
    +```python
    +import torch
    +from vit_pytorch.na_vit_nested_tensor import NaViT
    +
    +v = NaViT(
    +    image_size = 256,
    +    patch_size = 32,
    +    num_classes = 1000,
    +    dim = 1024,
    +    depth = 6,
    +    heads = 16,
    +    mlp_dim = 2048,
    +    dropout = 0.,
    +    emb_dropout = 0.,
    +    token_dropout_prob = 0.1
    +)
    +
    +# 5 images of different resolutions - List[Tensor]
    +
    +images = [
    +    torch.randn(3, 256, 256), torch.randn(3, 128, 128),
    +    torch.randn(3, 128, 256), torch.randn(3, 256, 128),
    +    torch.randn(3, 64, 256)
    +]
    +
    +preds = v(images)
    +
    +assert preds.shape == (5, 1000)
    +```
    +
     ## Distillation

     <img src="./images/distill.png" width="300px"></img>
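
Note on the README addition: the claim that nested tensors "omit a lot of the masking and padding" can be made concrete with a rough token count for the five example images. The sketch below is plain Python and does not call the library. It assumes non-overlapping 32x32 patches, ignores `token_dropout_prob`, and compares against a naive layout that pads every image to the longest sequence (one image per row). The helper name `num_patches` is hypothetical, not part of `vit_pytorch`.

```python
# Image sizes (height, width) taken from the README example above.
PATCH_SIZE = 32
image_sizes = [(256, 256), (128, 128), (128, 256), (256, 128), (64, 256)]


def num_patches(height: int, width: int, patch_size: int = PATCH_SIZE) -> int:
    """Number of non-overlapping patches (tokens) for one image."""
    # Assumption: both sides are divisible by the patch size.
    assert height % patch_size == 0 and width % patch_size == 0
    return (height // patch_size) * (width // patch_size)


tokens_per_image = [num_patches(h, w) for h, w in image_sizes]

# Tokens that carry image content.
real_tokens = sum(tokens_per_image)

# Naive padded layout: every row is padded to the longest sequence.
padded_tokens = len(image_sizes) * max(tokens_per_image)

print(tokens_per_image)            # [64, 16, 32, 32, 16]
print(real_tokens, padded_tokens)  # 160 320
```

Under these assumptions, half of the slots in the padded layout would be padding that also needs masking in attention. A ragged representation only has to store the 160 real tokens.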

setup.py

Lines changed: 1 addition & 1 deletion
    @@ -6,7 +6,7 @@
     setup(
       name = 'vit-pytorch',
       packages = find_packages(exclude=['examples']),
    -  version = '1.7.5',
    +  version = '1.7.7',
       license='MIT',
       description = 'Vision Transformer (ViT) - Pytorch',
       long_description=long_description,

vit_pytorch/na_vit.py

Lines changed: 9 additions & 2 deletions
    @@ -1,5 +1,7 @@
    +from __future__ import annotations
    +
     from functools import partial
    -from typing import List, Union
    +from typing import List

     import torch
     import torch.nn.functional as F
    @@ -245,7 +247,7 @@ def device(self):

         def forward(
             self,
    -        batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
    +        batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
             group_images = False,
             group_max_seq_len = 2048
         ):
    @@ -264,6 +266,11 @@ def forward(
                     max_seq_len = group_max_seq_len
                 )

    +        # if List[Tensor] is not grouped -> List[List[Tensor]]
    +
    +        if torch.is_tensor(batched_images[0]):
    +            batched_images = [batched_images]
    +
             # process images into variable lengthed sequences with attention mask

             num_images = []
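
Note on the last `na_vit.py` hunk: the added check wraps a flat `List[Tensor]` into a single group, so the rest of `forward` only has to handle `List[List[Tensor]]`. The sketch below shows that normalization in isolation. It is framework-free so it runs without PyTorch: `ensure_grouped` and `is_image` are hypothetical names, and strings stand in for image tensors. In the commit itself the predicate is `torch.is_tensor(batched_images[0])`.

```python
from typing import Callable, List


def ensure_grouped(batch: List, is_leaf: Callable[[object], bool]) -> List:
    """Return `batch` as a list of groups.

    A flat list of items becomes one group; an already grouped list is
    returned unchanged. Only the first element is inspected, as in the
    commit, so an empty batch raises IndexError here too.
    """
    if is_leaf(batch[0]):
        return [batch]
    return batch


def is_image(item: object) -> bool:
    # Stand-in for torch.is_tensor: in this sketch, strings are "images".
    return isinstance(item, str)


flat = ["img_a", "img_b", "img_c"]
grouped = [["img_a", "img_b"], ["img_c"]]

print(ensure_grouped(flat, is_image))     # [['img_a', 'img_b', 'img_c']]
print(ensure_grouped(grouped, is_image))  # [['img_a', 'img_b'], ['img_c']]
```

Because the check sits after the `group_images` branch in the diff, output from the grouping step is already nested and passes through untouched. Only a caller that supplies a flat list with `group_images = False` hits the wrapping path, which presumably places all of the images in one group.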
