Commit

add scalable vit, from bytedance AI
lucidrains committed Mar 23, 2022
1 parent cb22cbb commit d27721a
Showing 5 changed files with 347 additions and 1 deletion.
44 changes: 44 additions & 0 deletions README.md
@@ -18,6 +18,7 @@
- [Twins SVT](#twins-svt)
- [CrossFormer](#crossformer)
- [RegionViT](#regionvit)
- [ScalableViT](#scalablevit)
- [NesT](#nest)
- [MobileViT](#mobilevit)
- [Masked Autoencoder](#masked-autoencoder)
@@ -525,6 +526,38 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```

## ScalableViT

<img src="./images/scalable-vit-1.png" width="400px"></img>

<img src="./images/scalable-vit-2.png" width="400px"></img>

This Bytedance AI <a href="https://arxiv.org/abs/2203.10790">paper</a> proposes the Scalable Self Attention (SSA) and Interactive Windowed Self Attention (IWSA) modules. SSA reduces the computation needed at the earlier stages by downsampling the key / value feature map by some factor (`reduction_factor`), while modulating the dimension of the queries and keys (`ssa_dim_key`). IWSA performs self attention within local windows, similar to other vision transformer papers. However, it also adds a residual of the values, passed through a convolution of kernel size 3, which the authors name the Local Interactive Module (LIM).
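
To make the key / value reduction concrete, here is a minimal sketch mirroring the `to_q` / `to_k` / `to_v` convolutions inside the SSA module of `vit_pytorch/scalable_vit.py` further down in this commit (the dimensions are illustrative, taken from the first stage of the usage example below):

```python
import torch
from torch import nn

# illustrative first-stage settings: dim = 64, ssa_dim_key = 40, heads = 2, reduction_factor = 8
dim, dim_key, dim_value, heads, r = 64, 40, 64, 2, 8

to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)               # queries keep the full resolution
to_k = nn.Conv2d(dim, dim_key * heads, r, stride = r, bias = False)   # keys are downsampled r-fold per side
to_v = nn.Conv2d(dim, dim_value * heads, r, stride = r, bias = False) # values likewise

x = torch.randn(1, dim, 64, 64)  # feature map after the stride-4 patch embedding of a 256 x 256 image
q, k, v = to_q(x), to_k(x), to_v(x)

print(q.shape)  # torch.Size([1, 80, 64, 64])  - 4096 query positions
print(k.shape)  # torch.Size([1, 80, 8, 8])    - only 64 key / value positions, so attention is 4096 x 64 rather than 4096 x 4096
```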

The paper claims that this scheme outperforms Swin Transformer, and also demonstrates competitive performance against CrossFormer.

You can use it as follows (e.g. ScalableViT-S):

```python
import torch
from vit_pytorch.scalable_vit import ScalableViT

model = ScalableViT(
num_classes = 1000,
dim = 64, # starting model dimension. at every stage, dimension is doubled
heads = (2, 4, 8, 16), # number of attention heads at each stage
depth = (2, 2, 20, 2), # number of transformer blocks at each stage
ssa_dim_key = (40, 40, 40, 32), # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
reduction_factor = (8, 4, 2, 1), # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
window_size = (64, 32, None, None), # window size of the IWSA at each stage. None means no windowing needed
dropout = 0.1, # attention and feedforward dropout
).cuda()

img = torch.randn(1, 3, 256, 256).cuda()

preds = model(img) # (1, 1000)
```
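
As a quick geometry check on the window sizes above (a sketch assuming the 256 x 256 input from the example): the 7 x 7, stride-4 patch embedding produces a 64 x 64 feature map, each later stage halves the resolution, and a window size of `None` simply falls back to the full feature map inside the IWSA module.

```python
image_size, patch_stride, num_stages = 256, 4, 4

# side length of the feature map at each of the four stages
resolutions = [image_size // patch_stride // (2 ** i) for i in range(num_stages)]
print(resolutions)  # [64, 32, 16, 8]

window_sizes = (64, 32, None, None)
for res, wsz in zip(resolutions, window_sizes):
    wsz = wsz if wsz is not None else res  # None defaults to the full feature map, as in the IWSA forward pass
    assert res % wsz == 0                  # the divisibility requirement asserted inside IWSA
```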

## NesT

<img src="./images/nest.png" width="400px"></img>
@@ -1352,6 +1385,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{yang2022scalablevit,
title = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer},
author = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
year = {2022},
eprint = {2203.10790},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
Binary file added images/scalable-vit-1.png
Binary file added images/scalable-vit-2.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
-  version = '0.27.1',
+  version = '0.28.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
302 changes: 302 additions & 0 deletions vit_pytorch/scalable_vit.py
@@ -0,0 +1,302 @@
from functools import partial
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

# helpers

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def pair(t):
return t if isinstance(t, tuple) else (t, t)

def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)

# helper classes

class ChanLayerNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

def forward(self, x):
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanLayerNorm(dim)
self.fn = fn

def forward(self, x):
return self.fn(self.norm(x))

class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

def forward(self, x):
return self.conv(x)

class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

def forward(self, x):
return self.proj(x) + x

# feedforward

class FeedForward(nn.Module):
def __init__(self, dim, expansion_factor = 4, dropout = 0.):
super().__init__()
inner_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

# attention

class ScalableSelfAttention(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_key = 64,
dim_value = 64,
dropout = 0.,
reduction_factor = 1
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.attend = nn.Softmax(dim = -1)

self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, reduction_factor, stride = reduction_factor, bias = False)

self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)

def forward(self, x):
height, width, heads = *x.shape[-2:], self.heads

q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

# split out heads

q, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))

# similarity

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

# attention

attn = self.attend(dots)

# aggregate values

out = torch.matmul(attn, v)

# merge back heads

out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = height, y = width)
return self.to_out(out)

class InteractiveWindowedSelfAttention(nn.Module):
def __init__(
self,
dim,
window_size,
heads = 8,
dim_key = 64,
dim_value = 64,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_key ** -0.5
self.window_size = window_size
self.attend = nn.Softmax(dim = -1)

self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)

self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_v = nn.Conv2d(dim, dim_value * heads, 1, bias = False)

self.to_out = nn.Sequential(
nn.Conv2d(dim_value * heads, dim, 1),
nn.Dropout(dropout)
)

def forward(self, x):
height, width, heads, wsz = *x.shape[-2:], self.heads, self.window_size

wsz = default(wsz, height) # take height as window size if not given
assert (height % wsz) == 0 and (width % wsz) == 0, f'height ({height}) or width ({width}) of feature map is not divisible by the window size ({wsz})'

q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

# get output of LIM

local_out = self.local_interactive_module(v)

# divide into window (and split out heads) for efficient self attention

q, k, v = map(lambda t: rearrange(t, 'b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d', h = heads, w1 = wsz, w2 = wsz), (q, k, v))

# similarity

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

# attention

attn = self.attend(dots)

# aggregate values

out = torch.matmul(attn, v)

# reshape the windows back to full feature map (and merge heads)

out = rearrange(out, '(b x y) h (w1 w2) d -> b (h d) (x w1) (y w2)', x = height // wsz, y = width // wsz, w1 = wsz, w2 = wsz)

# add LIM output

out = out + local_out

return self.to_out(out)

class Transformer(nn.Module):
def __init__(
self,
dim,
depth,
heads = 8,
ff_expansion_factor = 4,
dropout = 0.,
ssa_dim_key = 64,
ssa_dim_value = 64,
ssa_reduction_factor = 1,
iwsa_dim_key = 64,
iwsa_dim_value = 64,
iwsa_window_size = 64,
norm_output = True
):
super().__init__()
self.layers = nn.ModuleList([])
for ind in range(depth):
is_first = ind == 0

self.layers.append(nn.ModuleList([
PreNorm(dim, ScalableSelfAttention(dim, heads = heads, dim_key = ssa_dim_key, dim_value = ssa_dim_value, reduction_factor = ssa_reduction_factor, dropout = dropout)),
PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout)),
PEG(dim) if is_first else None,
                PreNorm(dim, InteractiveWindowedSelfAttention(dim, heads = heads, dim_key = iwsa_dim_key, dim_value = iwsa_dim_value, window_size = iwsa_window_size, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, expansion_factor = ff_expansion_factor, dropout = dropout))
]))

self.norm = ChanLayerNorm(dim) if norm_output else nn.Identity()

def forward(self, x):
for ssa, ff1, peg, iwsa, ff2 in self.layers:
x = ssa(x) + x
x = ff1(x) + x

if exists(peg):
x = peg(x)

x = iwsa(x) + x
x = ff2(x) + x

return self.norm(x)

class ScalableViT(nn.Module):
def __init__(
self,
*,
num_classes,
dim,
depth,
heads,
reduction_factor,
ff_expansion_factor = 4,
iwsa_dim_key = 64,
iwsa_dim_value = 64,
window_size = 64,
ssa_dim_key = 64,
ssa_dim_value = 64,
channels = 3,
dropout = 0.
):
super().__init__()
self.to_patches = nn.Conv2d(channels, dim, 7, stride = 4, padding = 3)

        assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'

num_stages = len(depth)
dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))

hyperparams_per_stage = [
heads,
ssa_dim_key,
ssa_dim_value,
reduction_factor,
iwsa_dim_key,
iwsa_dim_value,
window_size,
]

hyperparams_per_stage = list(map(partial(cast_tuple, length = num_stages), hyperparams_per_stage))
assert all(tuple(map(lambda arr: len(arr) == num_stages, hyperparams_per_stage)))

self.layers = nn.ModuleList([])

for ind, (layer_dim, layer_depth, layer_heads, layer_ssa_dim_key, layer_ssa_dim_value, layer_ssa_reduction_factor, layer_iwsa_dim_key, layer_iwsa_dim_value, layer_window_size) in enumerate(zip(dims, depth, *hyperparams_per_stage)):
is_last = ind == (num_stages - 1)

self.layers.append(nn.ModuleList([
Transformer(dim = layer_dim, depth = layer_depth, heads = layer_heads, ff_expansion_factor = ff_expansion_factor, dropout = dropout, ssa_dim_key = layer_ssa_dim_key, ssa_dim_value = layer_ssa_dim_value, ssa_reduction_factor = layer_ssa_reduction_factor, iwsa_dim_key = layer_iwsa_dim_key, iwsa_dim_value = layer_iwsa_dim_value, iwsa_window_size = layer_window_size),
Downsample(layer_dim, layer_dim * 2) if not is_last else None
]))

self.mlp_head = nn.Sequential(
Reduce('b d h w -> b d', 'mean'),
nn.LayerNorm(dims[-1]),
nn.Linear(dims[-1], num_classes)
)

def forward(self, img):
x = self.to_patches(img)

for transformer, downsample in self.layers:
x = transformer(x)

if exists(downsample):
x = downsample(x)

return self.mlp_head(x)
