Skip to content

Commit

Permalink
add ViT for small datasets https://arxiv.org/abs/2112.13492
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 28, 2021
1 parent e52ac41 commit 70ba532
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 3 deletions.
58 changes: 58 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
- [Research Ideas](#research-ideas)
Expand Down Expand Up @@ -739,6 +740,52 @@ preds = v(img) # (1, 1000)
preds, token_ids = v(img, return_sampled_token_ids = True) # (1, 1000), (1, <=8)
```

## Vision Transformer for Small Datasets

<img src="./images/vit_for_small_datasets.png" width="400px"></img>

This paper proposes a new image to patch function that incorporates shifts of the image, before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformers work, so decided to include this for further explorations. It also includes the `LRA` with the learned temperature and masking out of token attention to itself.

You can use as follows:

```python
import torch
from vit_pytorch.vit_for_small_dataset import ViT

v = ViT(
image_size = 256,
patch_size = 16,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (1, 1000)
```

You can also use the `SPT` from this paper as a standalone module

```python
import torch
from vit_pytorch.vit_for_small_dataset import SPT

spt = SPT(
dim = 1024,
patch_size = 16,
channels = 3
)

img = torch.randn(4, 3, 256, 256)

tokens = spt(img) # (4, 256, 1024)
```

## Dino

<img src="./images/dino.png" width="350px"></img>
Expand Down Expand Up @@ -1236,6 +1283,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{lee2021vision,
title = {Vision Transformer for Small-Size Datasets},
author = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
year = {2021},
eprint = {2112.13492},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
Expand Down
Binary file added images/vit_for_small_datasets.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.25.6',
version = '0.26.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
7 changes: 5 additions & 2 deletions vit_pytorch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(
vit,
device = None,
layer_name = 'transformer',
layer_save_input = False,
return_embeddings_only = False
):
super().__init__()
Expand All @@ -23,10 +24,12 @@ def __init__(
self.device = device

self.layer_name = layer_name
self.layer_save_input = layer_save_input # whether to save input or output of layer
self.return_embeddings_only = return_embeddings_only

def _hook(self, _, input, output):
self.latents = output.clone().detach()
def _hook(self, _, inputs, output):
tensor_to_save = inputs if self.layer_save_input else output
self.latents = tensor_to_save.clone().detach()

def _register_hook(self):
assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
Expand Down
142 changes: 142 additions & 0 deletions vit_pytorch/vit_for_small_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

class LSA(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))

self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.temperature.exp()

mask = torch.eye(dots.shape[-1], device = dots.device, dtype = torch.bool)
mask_value = -torch.finfo(dots.dtype).max
dots = dots.masked_fill(mask, mask_value)

attn = self.attend(dots)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, LSA(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x

class SPT(nn.Module):
def __init__(self, *, dim, patch_size, channels = 3):
super().__init__()
patch_dim = patch_size * patch_size * 5 * channels

self.to_patch_tokens = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim)
)

def forward(self, x):
shifts = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
shifted_x = list(map(lambda shift: F.pad(x, shift), shifts))
x_with_shifts = torch.cat((x, *shifted_x), dim = 1)
return self.to_patch_tokens(x_with_shifts)

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

self.to_patch_embedding = SPT(dim = dim, patch_size = patch_size, channels = channels)

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

self.pool = pool
self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

def forward(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape

cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)

x = self.transformer(x)

x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

x = self.to_latent(x)
return self.mlp_head(x)

0 comments on commit 70ba532

Please sign in to comment.