add SimMIM

lucidrains committed Nov 21, 2021
1 parent c5a4616 commit 5ae5557

Showing 4 changed files with 137 additions and 1 deletion.
52 changes: 52 additions & 0 deletions README.md
@@ -19,6 +19,7 @@
- [RegionViT](#regionvit)
- [NesT](#nest)
- [Masked Autoencoder](#masked-autoencoder)
- [Simple Masked Image Modeling](#simple-masked-image-modeling)
- [Masked Patch Prediction](#masked-patch-prediction)
- [Dino](#dino)
- [Accessing Attention](#accessing-attention)
@@ -519,6 +520,46 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```

## Simple Masked Image Modeling

<img src="./images/simmim.png" width="400px"/>

This <a href="https://arxiv.org/abs/2111.09886">paper</a> proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection from the masked tokens to pixel space, followed by an L1 loss against the pixel values of the masked patches. Results are competitive with other, more complicated approaches.

You can use this as follows:

```python
import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5  # they found 50% to yield the best results
)

images = torch.randn(8, 3, 256, 256)

loss = mim(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

torch.save(v.state_dict(), './trained-vit.pt')
```
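
The "for loop" mentioned in the comment above is just an ordinary PyTorch training loop. Below is a minimal sketch of one, continuing from the snippet above; `train_loader`, the `AdamW` optimizer and the learning rate are illustrative assumptions, not part of this commit.

```python
import torch
from torch.optim import AdamW

# assumed: a DataLoader named `train_loader` that yields batches of (3, 256, 256) images
opt = AdamW(mim.parameters(), lr = 3e-4)

for epoch in range(100):
    for images in train_loader:
        loss = mim(images)

        opt.zero_grad()
        loss.backward()
        opt.step()

# the encoder `v` is updated in place during training, so saving it afterwards
# gives you a self-supervised, pretrained vision transformer
torch.save(v.state_dict(), './trained-vit.pt')
```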


## Masked Autoencoder

<img src="./images/mae.png" width="400px"/>
@@ -1026,6 +1067,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{xie2021simmim,
title = {SimMIM: A Simple Framework for Masked Image Modeling},
author = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
year = {2021},
eprint = {2111.09886},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
Binary file added images/simmim.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
-  version = '0.22.0',
+  version = '0.23.2',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
84 changes: 84 additions & 0 deletions vit_pytorch/simmim.py
@@ -0,0 +1,84 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

class SimMIM(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        masking_ratio = 0.5
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]

        # nn.Linear stores its weight as (out_features, in_features), so the last
        # dimension is the flattened patch size (patch height * patch width * channels)
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # simple linear head

        self.mask_token = nn.Parameter(torch.randn(encoder_dim))
        self.to_pixels = nn.Linear(encoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # for indexing purposes

        batch_range = torch.arange(batch, device = device)[:, None]

        # get positional embeddings (index 0 belongs to the CLS token, so skip it)

        pos_emb = self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + pos_emb

        # prepare mask tokens

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_patches)
        mask_tokens = mask_tokens + pos_emb

        # calculate the number of patches to be masked, and get the positions (indices) to be masked

        num_masked = int(self.masking_ratio * num_patches)
        masked_indices = torch.rand(batch, num_patches, device = device).topk(k = num_masked, dim = -1).indices
        masked_bool_mask = torch.zeros((batch, num_patches), device = device).scatter_(-1, masked_indices, 1).bool()

        # mask tokens

        tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

        # attend with vision transformer

        encoded = self.encoder.transformer(tokens)

        # get the masked tokens

        encoded_mask_tokens = encoded[batch_range, masked_indices]

        # small linear projection for predicted pixel values

        pred_pixel_values = self.to_pixels(encoded_mask_tokens)

        # get the masked patches for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # calculate reconstruction loss

        recon_loss = F.l1_loss(pred_pixel_values, masked_patches) / num_masked
        return recon_loss
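
A note on the masking logic in `forward` above: `torch.rand(...).topk(...)` draws `num_masked` random patch indices per sample, `scatter_` turns those indices into a boolean mask, and the `(batch, 1)`-shaped `batch_range` is used for advanced indexing to gather only the masked positions. A small standalone sketch of that mechanism, with toy sizes that are not part of the commit:

```python
import torch

batch, num_patches, dim = 2, 8, 4
num_masked = 4  # e.g. masking_ratio = 0.5

# random scores per patch; the top-k indices are the randomly selected masked positions
masked_indices = torch.rand(batch, num_patches).topk(k = num_masked, dim = -1).indices    # (2, 4)

# boolean mask that is True wherever a patch is masked
masked_bool_mask = torch.zeros(batch, num_patches).scatter_(-1, masked_indices, 1).bool() # (2, 8)

# replace masked positions with (stand-in) mask tokens via broadcasting
tokens = torch.randn(batch, num_patches, dim)
mask_tokens = torch.randn(batch, num_patches, dim)
tokens = torch.where(masked_bool_mask[..., None], mask_tokens, tokens)

# gather only the masked positions with advanced indexing
batch_range = torch.arange(batch)[:, None]                                                # (2, 1)
masked_only = tokens[batch_range, masked_indices]                                         # (2, 4, 4)
```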
