Nested NaViT (lucidrains#325)
add a variant of NaViT using nested tensors
lucidrains committed Aug 20, 2024
1 parent 4f22eae commit 73199ab
Showing 4 changed files with 367 additions and 3 deletions.
32 changes: 32 additions & 0 deletions README.md
@@ -198,6 +198,38 @@ preds = v(
) # (5, 1000)
```

Finally, if you would like to use a flavor of NaViT built on <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which omits much of the masking and padding altogether), make sure you are on PyTorch `2.4` and import as follows:

```python
import torch
from vit_pytorch.na_vit_nested_tensor import NaViT

v = NaViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.,
emb_dropout = 0.,
token_dropout_prob = 0.1
)

# 5 images of different resolutions - List[Tensor]

images = [
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]

preds = v(images)

assert preds.shape == (5, 1000)
```
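
For intuition, here is a minimal sketch of what the nested-tensor representation buys; the snippet is illustrative only and not part of the library. Variable-length patch sequences are stored jagged, so no pad tokens or attention masks ever get materialized.

```python
import torch

# three patch sequences of different lengths, all with embedding dim 1024
seqs = [torch.randn(64, 1024), torch.randn(16, 1024), torch.randn(32, 1024)]

# pack them into a jagged nested tensor - no padding positions exist
nt = torch.nested.nested_tensor(seqs, layout = torch.jagged)

assert nt.is_nested

# each original sequence is recoverable exactly, with no pad rows
assert [t.shape[0] for t in nt.unbind()] == [64, 16, 32]
```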

## Distillation

<img src="./images/distill.png" width="300px"></img>
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.7.5',
version = '1.7.7',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
11 changes: 9 additions & 2 deletions vit_pytorch/na_vit.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import List, Union
from typing import List

import torch
import torch.nn.functional as F
@@ -245,7 +247,7 @@ def device(self):

def forward(
self,
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
group_images = False,
group_max_seq_len = 2048
):
@@ -264,6 +266,11 @@ def forward(
max_seq_len = group_max_seq_len
)

# if List[Tensor] is not grouped -> List[List[Tensor]]

if torch.is_tensor(batched_images[0]):
batched_images = [batched_images]

# process images into variable lengthed sequences with attention mask

num_images = []
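
For context on the signature change above, here is a minimal usage sketch of the two input shapes `forward` now accepts. The constructor kwargs mirror the README example earlier on this page and should be read as illustrative, not canonical.

```python
import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    token_dropout_prob = 0.1
)

# List[List[Tensor]] - images already grouped by the caller
grouped = [
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 64, 256)]
]
preds = v(grouped)  # (3, 1000)

# List[Tensor] - a flat list; with the wrapping added in this hunk it is
# treated as a single group, or packed automatically when group_images = True
flat = [torch.randn(3, 256, 256), torch.randn(3, 128, 128), torch.randn(3, 64, 256)]
preds = v(flat, group_images = True, group_max_seq_len = 2048)  # (3, 1000)
```
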
325 changes: 325 additions & 0 deletions vit_pytorch/na_vit_nested_tensor.py (diff not shown)
