
add RegionViT paper
lucidrains committed Nov 7, 2021
1 parent f196d1e commit 06d3753
Showing 5 changed files with 307 additions and 1 deletion.
38 changes: 38 additions & 0 deletions README.md
@@ -435,6 +435,33 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```

## RegionViT

<img src="./images/regionvit.png" width="400px"></img>

<img src="./images/regionvit2.png" width="400px"></img>

<a href="https://arxiv.org/abs/2106.02689">This paper</a> proposes dividing the feature map into local regions, within which the local tokens attend to one another. Each local region has its own regional token, which attends to all of its local tokens as well as to the other regional tokens.

You can use it as follows

```python
import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
dim = (64, 128, 256, 512), # tuple of size 4, indicating dimension at each stage
depth = (2, 2, 8, 2), # depth of the region to local transformer at each stage
window_size = 7, # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
tokenize_local_3_conv = False, # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use the positional encoding generator (PEG) module. they used this for object detection for a boost in performance
)

x = torch.randn(1, 3, 224, 224)
logits = model(x) # (1, 1000)
```
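
To make the regional-to-local attention pattern concrete, here is a rough, self-contained sketch of a single attention step on dummy tensors. It is not the library's implementation; the sizes are illustrative, and projections, multiple heads, and the relative positional bias are omitted.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

dim, window = 64, 7
local_map  = torch.randn(1, dim, 56, 56)    # local token feature map (b c h w)
region_map = torch.randn(1, dim, 8, 8)      # one regional token per 7x7 window

# 1. regional tokens attend to one another (global exchange across regions)
r = rearrange(region_map, 'b c h w -> b (h w) c')
r_attn = F.softmax(r @ r.transpose(-1, -2) / dim ** 0.5, dim = -1)
r = r_attn @ r

# 2. partition local tokens into windows and prepend each window's regional token
local = rearrange(local_map, 'b c (h p1) (w p2) -> (b h w) (p1 p2) c', p1 = window, p2 = window)
r = rearrange(r, 'b n c -> (b n) 1 c')
tokens = torch.cat((r, local), dim = 1)     # (num windows, 1 + 49, dim)

# 3. plain softmax self-attention within each window, regional token included
attn = F.softmax(tokens @ tokens.transpose(-1, -2) / dim ** 0.5, dim = -1)
out = attn @ tokens                         # (num windows, 50, dim)
```

The actual module additionally learns a relative positional bias within each window and repeats this regional-to-local block at every stage of the network.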

## NesT

<img src="./images/nest.png" width="400px"></img>
@@ -892,6 +919,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{chen2021regionvit,
title = {RegionViT: Regional-to-Local Attention for Vision Transformers},
author = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
year = {2021},
eprint = {2106.02689},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{caron2021emerging,
title = {Emerging Properties in Self-Supervised Vision Transformers},
Binary file added images/regionvit.png
Binary file added images/regionvit2.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.20.8',
version = '0.21.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
268 changes: 268 additions & 0 deletions vit_pytorch/regionvit.py
@@ -0,0 +1,268 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
import torch.nn.functional as F

# helpers

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)

def divisible_by(val, d):
return (val % d) == 0

# helper classes

class Downsample(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.conv = nn.Conv2d(dim_in, dim_out, 3, stride = 2, padding = 1)

def forward(self, x):
return self.conv(x)

class PEG(nn.Module):
def __init__(self, dim, kernel_size = 3):
super().__init__()
self.proj = nn.Conv2d(dim, dim, kernel_size = kernel_size, padding = kernel_size // 2, groups = dim, stride = 1)

def forward(self, x):
return self.proj(x) + x

# transformer classes

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

class Attention(nn.Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 32,
dropout = 0.
):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads

self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim)

def forward(self, x, rel_pos_bias = None):
h = self.heads

# prenorm

x = self.norm(x)

q, k, v = self.to_qkv(x).chunk(3, dim = -1)

# split heads

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
q = q * self.scale

sim = einsum('b h i d, b h j d -> b h i j', q, k)

# add relative positional bias for local tokens

if exists(rel_pos_bias):
sim = sim + rel_pos_bias

attn = sim.softmax(dim = -1)

# merge heads

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

class R2LTransformer(nn.Module):
def __init__(
self,
dim,
*,
window_size,
depth = 4,
heads = 4,
dim_head = 32,
attn_dropout = 0.,
ff_dropout = 0.,
):
super().__init__()
self.layers = nn.ModuleList([])

self.window_size = window_size
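        # (2 * window_size - 1) possible relative offsets per axis, so the bias table has (2 * window_size - 1) ** 2 entries per head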
rel_positions = 2 * window_size - 1
self.local_rel_pos_bias = nn.Embedding(rel_positions ** 2, heads)

for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
FeedForward(dim, dropout = ff_dropout)
]))

def forward(self, local_tokens, region_tokens):
device = local_tokens.device
lh, lw = local_tokens.shape[-2:]
rh, rw = region_tokens.shape[-2:]
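        # window size is recovered from the ratio of the local grid to the regional grid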
window_size_h, window_size_w = lh // rh, lw // rw

local_tokens = rearrange(local_tokens, 'b c h w -> b (h w) c')
region_tokens = rearrange(region_tokens, 'b c h w -> b (h w) c')

# calculate local relative positional bias

h_range = torch.arange(window_size_h, device = device)
w_range = torch.arange(window_size_w, device = device)

grid_x, grid_y = torch.meshgrid(h_range, w_range)
grid = torch.stack((grid_x, grid_y))
grid = rearrange(grid, 'c h w -> c (h w)')
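        # pairwise relative offsets within a window, shifted to be non-negative, then folded into a single index into the bias table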
grid = (grid[:, :, None] - grid[:, None, :]) + (self.window_size - 1)
bias_indices = (grid * torch.tensor([1, self.window_size * 2 - 1], device = device)[:, None, None]).sum(dim = 0)
rel_pos_bias = self.local_rel_pos_bias(bias_indices)
rel_pos_bias = rearrange(rel_pos_bias, 'i j h -> () h i j')
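        # pad a zero-bias row and column for the regional token that gets prepended to each window below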
rel_pos_bias = F.pad(rel_pos_bias, (1, 0, 1, 0), value = 0)

# go through r2l transformer layers

for attn, ff in self.layers:
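            # regional tokens first exchange information with one another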
region_tokens = attn(region_tokens) + region_tokens

# concat region tokens to local tokens

local_tokens = rearrange(local_tokens, 'b (h w) d -> b h w d', h = lh)
local_tokens = rearrange(local_tokens, 'b (h p1) (w p2) d -> (b h w) (p1 p2) d', p1 = window_size_h, p2 = window_size_w)
region_tokens = rearrange(region_tokens, 'b n d -> (b n) () d')

# do self attention on local tokens, along with its regional token

region_and_local_tokens = torch.cat((region_tokens, local_tokens), dim = 1)
region_and_local_tokens = attn(region_and_local_tokens, rel_pos_bias = rel_pos_bias) + region_and_local_tokens

# split back local and regional tokens

region_tokens, local_tokens = region_and_local_tokens[:, :1], region_and_local_tokens[:, 1:]
local_tokens = rearrange(local_tokens, '(b h w) (p1 p2) d -> b (h p1 w p2) d', h = lh // window_size_h, w = lw // window_size_w, p1 = window_size_h)
region_tokens = rearrange(region_tokens, '(b n) () d -> b n d', n = rh * rw)

# feedforwards

local_tokens = ff(local_tokens) + local_tokens
region_tokens = ff(region_tokens) + region_tokens

local_tokens = rearrange(local_tokens, 'b (h w) c -> b c h w', h = lh, w = lw)
region_tokens = rearrange(region_tokens, 'b (h w) c -> b c h w', h = rh, w = rw)
return local_tokens, region_tokens

# classes

class RegionViT(nn.Module):
def __init__(
self,
*,
dim = (64, 128, 256, 512),
depth = (2, 2, 8, 2),
window_size = 7,
num_classes = 1000,
tokenize_local_3_conv = False,
local_patch_size = 4,
use_peg = False,
attn_dropout = 0.,
ff_dropout = 0.,
channels = 3,
):
super().__init__()
dim = cast_tuple(dim, 4)
depth = cast_tuple(depth, 4)
assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'

self.local_patch_size = local_patch_size

region_patch_size = local_patch_size * window_size
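        # i.e. each region covers a window_size x window_size grid of local patches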
self.region_patch_size = local_patch_size * window_size

init_dim, *_, last_dim = dim

# local and region encoders

        if tokenize_local_3_conv:
            self.local_encoder = nn.Sequential(
                nn.Conv2d(channels, init_dim, 3, 2, 1),
                Rearrange('b c h w -> b h w c'),  # channels-last so LayerNorm normalizes over the channel dimension
                nn.LayerNorm(init_dim),
                Rearrange('b h w c -> b c h w'),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 2, 1),
                Rearrange('b c h w -> b h w c'),
                nn.LayerNorm(init_dim),
                Rearrange('b h w c -> b c h w'),
                nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 1, 1)
            )
        else:
            self.local_encoder = nn.Conv2d(channels, init_dim, 8, 4, 3)

self.region_encoder = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = region_patch_size, p2 = region_patch_size),
nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1)
)

# layers

current_dim = init_dim
self.layers = nn.ModuleList([])

for ind, dim, num_layers in zip(range(4), dim, depth):
not_first = ind != 0
need_downsample = not_first
need_peg = not_first and use_peg

self.layers.append(nn.ModuleList([
Downsample(current_dim, dim) if need_downsample else nn.Identity(),
PEG(dim) if need_peg else nn.Identity(),
R2LTransformer(dim, depth = num_layers, window_size = window_size, attn_dropout = attn_dropout, ff_dropout = ff_dropout)
]))

current_dim = dim

# final logits

self.to_logits = nn.Sequential(
Reduce('b c h w -> b c', 'mean'),
nn.LayerNorm(last_dim),
nn.Linear(last_dim, num_classes)
)

def forward(
self,
x,
return_local_tokens = False
):
*_, h, w = x.shape
assert divisible_by(h, self.region_patch_size) and divisible_by(w, self.region_patch_size), 'height and width must be divisible by region patch size'
assert divisible_by(h, self.local_patch_size) and divisible_by(w, self.local_patch_size), 'height and width must be divisible by local patch size'

local_tokens = self.local_encoder(x)
region_tokens = self.region_encoder(x)

for down, peg, transformer in self.layers:
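            # downsample both token maps (identity at the first stage), optionally add positional encodings to the local tokens, then run the R2L transformer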
local_tokens, region_tokens = down(local_tokens), down(region_tokens)
local_tokens = peg(local_tokens)
local_tokens, region_tokens = transformer(local_tokens, region_tokens)

        # optionally return the final local token feature map alongside the logits
        if return_local_tokens:
            return self.to_logits(region_tokens), local_tokens

        return self.to_logits(region_tokens)
