
Commit

add NesT
lucidrains committed May 28, 2021
1 parent b483b16 commit daf3abb
Showing 4 changed files with 215 additions and 1 deletion.
37 changes: 37 additions & 0 deletions README.md
@@ -378,6 +378,32 @@ img = torch.randn(1, 3, 224, 224)
pred = model(img) # (1, 1000)
```

## NesT

<img src="./images/nest.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2105.12723">paper</a> processes the image in hierarchical stages, with attention applied only among tokens within local blocks, which are aggregated as the network moves up the hierarchy. The aggregation is done in the image plane, and contains a convolution and subsequent max pooling to allow it to pass information across block boundaries.

You can use it with the following code (e.g. NesT-T):

```python
import torch
from vit_pytorch.nest import NesT

nest = NesT(
image_size = 224,
patch_size = 4,
dim = 96,
heads = 3,
    num_heirarchies = 3,        # number of hierarchies
    block_repeats = (8, 4, 1),  # the number of transformer blocks at each hierarchy, starting from the bottom
num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```
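
As a rough illustration of how NesT restricts attention to local blocks (a minimal sketch with made-up shapes, not part of the library API): the feature map is partitioned into non-overlapping blocks that are folded into the batch dimension, the transformer runs within each block, and the blocks are folded back into the image plane before aggregation merges neighboring blocks.

```python
import torch
from einops import rearrange

x = torch.randn(1, 96, 56, 56)   # feature map after patch embedding (illustrative shape)
block_size = 4                   # number of blocks along each spatial axis at this hierarchy

# partition into block_size x block_size local blocks, folded into the batch dimension
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)

# attention + feedforward would run here, so tokens only attend within their own block

# fold the blocks back into the image plane before the aggregation (conv + max pool) step
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
```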

## Masked Patch Prediction

Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -787,6 +813,17 @@ Coming from computer vision and new to transformers? Here are some resources that
}
```

```bibtex
@misc{zhang2021aggregating,
title = {Aggregating Nested Transformers},
author = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
year = {2021},
eprint = {2105.12723},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{caron2021emerging,
title = {Emerging Properties in Self-Supervised Vision Transformers},
Binary file added images/nest.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.18.4',
version = '0.19.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
177 changes: 177 additions & 0 deletions vit_pytorch/nest.py
@@ -0,0 +1,177 @@
import torch
from torch import nn, einsum

from einops import rearrange
from einops.layers.torch import Rearrange, Reduce

# helpers

def cast_tuple(val, depth):
return val if isinstance(val, tuple) else ((val,) * depth)

# classes

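# layer norm applied over the channel dimension, for feature maps in (b, c, h, w) format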
class ChanNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

def forward(self, x):
std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt()
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (std + self.eps) * self.g + self.b

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = ChanNorm(dim)
self.fn = fn

def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, mlp_mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mlp_mult, 1),
nn.GELU(),
nn.Dropout(dropout),
nn.Conv2d(dim * mlp_mult, dim, 1),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

class Attention(nn.Module):
def __init__(self, dim, heads = 8, dropout = 0.):
super().__init__()
assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
dim_head = dim // heads
self.heads = heads
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1, bias = False)

self.to_out = nn.Sequential(
nn.Conv2d(dim, dim, 1),
nn.Dropout(dropout)
)

def forward(self, x):
b, c, h, w, heads = *x.shape, self.heads

qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), qkv)

dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

attn = self.attend(dots)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)

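# aggregation between hierarchies: a 3x3 convolution, which lets information cross block boundaries, followed by channel norm and 2x2 max pooling to halve the spatial resolution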
def Aggregate(dim, dim_out):
return nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding = 1),
ChanNorm(dim_out),
nn.MaxPool2d(2)
)

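# transformer operating on image-plane feature maps, adding a learned positional embedding per spatial position before the attention / feedforward layers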
class Transformer(nn.Module):
def __init__(self, dim, seq_len, depth, heads, mlp_mult, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
self.pos_emb = nn.Parameter(torch.randn(seq_len))

for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
]))
def forward(self, x):
*_, h, w = x.shape

pos_emb = self.pos_emb[:(h * w)]
pos_emb = rearrange(pos_emb, '(h w) -> () () h w', h = h, w = w)
x = x + pos_emb

for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x

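# NesT: a hierarchy of block-local transformers, each level followed by an Aggregate step (except the last), with a channel norm, spatial mean pool and linear layer as the classification head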
class NesT(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
heads,
num_heirarchies,
block_repeats,
mlp_mult = 4,
channels = 3,
dim_head = 64,
dropout = 0.
):
super().__init__()
assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
fmap_size = image_size // patch_size
blocks = 2 ** (num_heirarchies - 1)

        seq_len = (fmap_size // blocks) ** 2   # sequence length is held constant across the hierarchy
mults = [2 ** i for i in reversed(range(num_heirarchies))]

layer_heads = list(map(lambda t: t * heads, mults))
layer_dims = list(map(lambda t: t * dim, mults))

layer_dims = [*layer_dims, layer_dims[-1]]
dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1 = patch_size, p2 = patch_size),
nn.Conv2d(patch_dim, layer_dims[0], 1),
)

block_repeats = cast_tuple(block_repeats, num_heirarchies)

self.layers = nn.ModuleList([])

for level, heads, (dim_in, dim_out), block_repeat in zip(reversed(range(num_heirarchies)), layer_heads, dim_pairs, block_repeats):
is_last = level == 0
depth = block_repeat

self.layers.append(nn.ModuleList([
Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
]))

self.mlp_head = nn.Sequential(
ChanNorm(dim),
Reduce('b c h w -> b c', 'mean'),
nn.Linear(dim, num_classes)
)

def forward(self, img):
x = self.to_patch_embedding(img)
b, c, h, w = x.shape

num_heirarchies = len(self.layers)

for level, (transformer, aggregate) in zip(reversed(range(num_heirarchies)), self.layers):
block_size = 2 ** level
x = rearrange(x, 'b c (b1 h) (b2 w) -> (b b1 b2) c h w', b1 = block_size, b2 = block_size)
x = transformer(x)
x = rearrange(x, '(b b1 b2) c h w -> b c (b1 h) (b2 w)', b1 = block_size, b2 = block_size)
x = aggregate(x)

return self.mlp_head(x)
