begin work on NaViT (lucidrains#273)
finish core idea of NaViT
lucidrains authored Jul 24, 2023
1 parent e9ca1f4 commit 23820bc
Showing 3 changed files with 352 additions and 0 deletions.
47 changes: 47 additions & 0 deletions README.md
@@ -7,6 +7,7 @@
- [Usage](#usage)
- [Parameters](#parameters)
- [Simple ViT](#simple-vit)
- [NaViT](#navit)
- [Distillation](#distillation)
- [Deep ViT](#deep-vit)
- [CaiT](#cait)
@@ -139,6 +140,44 @@ img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```

## NaViT

<img src="./images/na_vit.png" width="450px"></img>

<a href="https://arxiv.org/abs/2307.06304">This paper</a> proposes to leverage the flexibility of attention and masking for variable-length sequences in order to train on images of multiple resolutions, packed into a single batch. They demonstrate much faster training and improved accuracy, with the only cost being extra complexity in the architecture and dataloading. They use factorized 2d positional embeddings, token dropping, as well as query-key normalization.
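
The factorized positional embedding amounts to two small learned tables, one indexed by a patch's row and one by its column, summed into each patch token. Here is a rough sketch of just that idea (illustrative only, not the module's API):

```python
import torch

dim = 1024
pos_height = torch.randn(8, dim)   # one learned embedding per possible patch row (image_size // patch_size = 8)
pos_width  = torch.randn(8, dim)   # one learned embedding per possible patch column

# a 256 x 128 image with patch size 32 gives an 8 x 4 grid of patches

rows, cols = torch.meshgrid(torch.arange(8), torch.arange(4), indexing = 'ij')
rows, cols = rows.flatten(), cols.flatten()   # (32,) each - one entry per patch

pos_emb = pos_height[rows] + pos_width[cols]  # (32, 1024), added to the patch embeddings
```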

You can use it as follows

```python
import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# 5 images of different resolutions - List[List[Tensor]]

# for now, you'll have to correctly place images in the same batch element so as not to exceed the maximum allowed sequence length for self-attention w/ masking (see the grouping sketch below)

images = [
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
    [torch.randn(3, 64, 256)]
]

preds = v(images) # (5, 1000) - 5, because there are 5 images of different resolutions above
```
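
Since the model does not pack the images for you yet, one simple way to build the nested list is to greedily group images by their patch counts under a chosen sequence budget. The sketch below assumes exactly that; `group_images_by_max_seq_len` is a hypothetical helper for illustration, not part of this commit:

```python
import torch
from typing import List
from torch import Tensor

def group_images_by_max_seq_len(images: List[Tensor], patch_size: int, max_seq_len: int) -> List[List[Tensor]]:
    # hypothetical helper - greedily packs images so each batch element stays within max_seq_len patches
    grouped, current, current_len = [], [], 0

    for image in images:
        height, width = image.shape[-2:]
        num_patches = (height // patch_size) * (width // patch_size)
        assert num_patches <= max_seq_len, 'a single image exceeds the sequence budget'

        if current_len + num_patches > max_seq_len:
            grouped.append(current)
            current, current_len = [], 0

        current.append(image)
        current_len += num_patches

    if len(current) > 0:
        grouped.append(current)

    return grouped

images = group_images_by_max_seq_len(
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128), torch.randn(3, 64, 256)],
    patch_size = 32,
    max_seq_len = 256
)

preds = v(images)  # (3, 1000)
```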

## Distillation

<img src="./images/distill.png" width="300px"></img>
@@ -1934,6 +1973,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@inproceedings{Dehghani2023PatchNP,
    title = {Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution},
    author = {Mostafa Dehghani and Basil Mustafa and Josip Djolonga and Jonathan Heek and Matthias Minderer and Mathilde Caron and Andreas Steiner and Joan Puigcerver and Robert Geirhos and Ibrahim M. Alabdulmohsin and Avital Oliver and Piotr Padlewski and Alexey A. Gritsenko and Mario Lučić and Neil Houlsby},
    year = {2023}
}
```

```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
Binary file added images/navit.png
305 changes: 305 additions & 0 deletions vit_pytorch/na_vit.py
@@ -0,0 +1,305 @@
from functools import partial
from typing import List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
    return (numer % denom) == 0

# normalization
# they use layernorm without bias, something that pytorch does not offer

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper

class RMSNorm(nn.Module):
    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# feedforward

def FeedForward(dim, hidden_dim, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = LayerNorm(dim)

        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_mask = None
    ):
        x = self.norm(x)
        kv_input = default(context, x)

        qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = self.q_norm(q)
        k = self.k_norm(k)

        dots = torch.matmul(q, k.transpose(-1, -2))

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

        if exists(attn_mask):
            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        self.norm = LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        attn_mask = None
    ):
        for attn, ff in self.layers:
            x = attn(x, mask = mask, attn_mask = attn_mask) + x
            x = ff(x) + x

        return self.norm(x)

class NaViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        patch_dim = channels * (patch_size ** 2)

        self.channels = channels
        self.patch_size = patch_size

        self.to_patch_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )

        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
    ):
        p, c, device = self.patch_size, self.channels, self.device

        arange = partial(torch.arange, device = device)
        pad_sequence = partial(orig_pad_sequence, batch_first = True)

        # process images into variable-length sequences with attention mask

        num_images = []
        batched_sequences = []
        batched_positions = []
        batched_image_ids = []

        for images in batched_images:
            num_images.append(len(images))

            sequences = []
            positions = []
            image_ids = torch.empty((0,), device = device, dtype = torch.long)

            for image_id, image in enumerate(images):
                assert image.ndim == 3 and image.shape[0] == c
                image_dims = image.shape[-2:]
                assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'

                ph, pw = map(lambda dim: dim // p, image_dims)

                pos = torch.stack(torch.meshgrid((
                    arange(ph),
                    arange(pw)
                ), indexing = 'ij'), dim = -1)

                pos = rearrange(pos, 'h w c -> (h w) c')
                seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)
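                # seq: (ph * pw, channels * patch_size ** 2) - one flattened patch per row
                # pos: (ph * pw, 2) - the (row, column) grid coordinate of each patch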

                image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
                sequences.append(seq)
                positions.append(pos)

            batched_image_ids.append(image_ids)
            batched_sequences.append(torch.cat(sequences, dim = 0))
            batched_positions.append(torch.cat(positions, dim = 0))

        # derive key padding mask

        lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
        seq_arange = arange(lengths.amax().item())

        # mask is True for valid (non padded) token positions, matching the masking convention in Attention

        key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')

        # derive attention mask, and combine with key padding mask from above

        batched_image_ids = pad_sequence(batched_image_ids)
        attn_mask = rearrange(batched_image_ids, 'b i -> b 1 i 1') == rearrange(batched_image_ids, 'b j -> b 1 1 j')
        attn_mask = attn_mask & rearrange(key_pad_mask, 'b j -> b 1 1 j')
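        # attn_mask: (b, 1, n, n) - True where query i and key j belong to the same image and key j is not padding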

        # combine patched images as well as the patched width / height positions for 2d positional embedding

        patches = pad_sequence(batched_sequences)
        patch_positions = pad_sequence(batched_positions)

        # need to know how many images for final attention pooling

        num_images = torch.tensor(num_images, device = device, dtype = torch.long)

        # to patches

        x = self.to_patch_embedding(patches)

        # factorized 2d absolute positional embedding

        h_indices, w_indices = patch_positions.unbind(dim = -1)

        h_pos = self.pos_embed_height[h_indices]
        w_pos = self.pos_embed_width[w_indices]
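        # h_pos, w_pos: (b, n, dim) - the row and column embeddings gathered for every patch, summed into the tokens below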

        x = x + h_pos + w_pos

        # embed dropout

        x = self.dropout(x)

        # attention

        x = self.transformer(x, attn_mask = attn_mask)

        # do attention pooling at the end

        max_queries = num_images.amax().item()

        queries = repeat(self.attn_pool_queries, 'd -> b n d', n = max_queries, b = x.shape[0])

        # attention pool mask

        image_id_arange = arange(max_queries)

        attn_pool_mask = rearrange(image_id_arange, 'i -> i 1') == rearrange(batched_image_ids, 'b j -> b 1 j')

        attn_pool_mask = attn_pool_mask & rearrange(key_pad_mask, 'b j -> b 1 j')

        attn_pool_mask = rearrange(attn_pool_mask, 'b i j -> b 1 i j')
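        # attn_pool_mask: (b, 1, max_queries, n) - pooling query i attends only to the non padded tokens of image i within its batch element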

        # attention pool

        x = self.attn_pool(queries, context = x, attn_mask = attn_pool_mask) + queries

        x = rearrange(x, 'b n d -> (b n) d')

        # each batch element may not have the same number of images

        is_images = image_id_arange < rearrange(num_images, 'b -> b 1')
        is_images = rearrange(is_images, 'b n -> (b n)')

        x = x[is_images]

        # project out to logits

        x = self.to_latent(x)

        return self.mlp_head(x)
