Skip to content

Commit

Permalink
Add Masked Position Prediction (lucidrains#260)
Browse files Browse the repository at this point in the history
* Create mp3.py

* Implementation: Position Prediction as an Effective Pretraining Strategy

* Added description for Masked Position Prediction

* MP3 image added
  • Loading branch information
Vishu26 authored Mar 7, 2023
1 parent f621c2b commit 4218556
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 0 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Masked Position Prediction](#masked-position-prediction)
- [Adaptive Token Sampling](#adaptive-token-sampling)
- [Patch Merger](#patch-merger)
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
Expand Down Expand Up @@ -844,6 +845,39 @@ for _ in range(100):
torch.save(model.state_dict(), './pretrained-net.pt')
```

## Masked Position Prediction

<img src="./images/mp3.png" width="400px"></img>

New <a href="https://arxiv.org/abs/2207.07611">paper</a> that introduces masked position prediction pre-training criteria. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.

```python
import torch
from vit_pytorch.mp3 import MP3

model = MP3(
image_size=256,
patch_size=8,
masking_ratio=0.75
dim=1024,
depth=6,
heads=8,
mlp_dim=2048,
dropout=0.1,
)

images = torch.randn(8, 3, 256, 256)

loss = model(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')
```

## Adaptive Token Sampling

<img src="./images/ats.png" width="400px"></img>
Expand Down
Binary file added images/mp3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
137 changes: 137 additions & 0 deletions vit_pytorch/mp3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
return t if isinstance(t, tuple) else (t, t)

# pre-layernorm

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)

# cross attention

class CrossAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x, context):
b, n, _, h = *x.shape, self.heads

qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x, context):
for attn, ff in self.layers:
x = attn(x, context=context) + x
x = ff(x) + x
return x

# Masked Position Prediction Pre-Training

class MP3(nn.Module):
def __init__(self, *, image_size, patch_size, masking_ratio, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)

assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
self.masking_ratio = masking_ratio

num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width

self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_patches)
)
self.out = nn.Softmax(dim = -1)

def forward(self, img):
device = img.device
tokens = self.to_patch_embedding(img)
batch, num_patches, *_ = tokens.shape

# Masking
num_masked = int(self.masking_ratio * num_patches)
rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

batch_range = torch.arange(batch, device = device)[:, None]
tokens_unmasked = tokens[batch_range, unmasked_indices]

x = rearrange(self.mlp_head(self.transformer(tokens, tokens_unmasked)), 'b n d -> (b n) d')
x = self.out(x)

# Define labels
labels = repeat(torch.arange(num_patches, device = device), 'n -> b n', b = batch).flatten()
loss = F.cross_entropy(x, labels)

return loss

0 comments on commit 4218556

Please sign in to comment.