
Commit

release masked autoencoder
lucidrains committed Nov 13, 2021
1 parent cb1729a commit e8f6d72
Showing 5 changed files with 145 additions and 1 deletion.
50 changes: 50 additions & 0 deletions README.md
@@ -490,6 +490,45 @@ img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
```

## Masked Autoencoder

<img src="./images/mae.png" width="400px"/>

A new <a href="https://arxiv.org/abs/2111.06377">Kaiming He paper</a> proposes a simple autoencoder scheme in which the vision transformer encoder attends only to a random subset of unmasked patches, while a smaller decoder reconstructs the pixel values of the masked patches.

You can use it with the following code:

```python
import torch
from vit_pytorch import ViT, MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the paper recommends masking out 75% of the patches
    decoder_dim = 1024,     # the decoder can use a smaller dimension than the encoder
    decoder_depth = 6,
    decoder_heads = 8
)

images = torch.randn(8, 3, 256, 256)

loss = mae(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
```
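
For concreteness, here is a minimal sketch of that pretraining loop. The optimizer, learning rate, number of steps, and the random tensors standing in for a real unlabelled dataloader are illustrative assumptions, not recommendations from the paper or this repository.

```python
import torch
from vit_pytorch import ViT, MAE

v = ViT(image_size = 256, patch_size = 32, num_classes = 1000, dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)
mae = MAE(encoder = v, masking_ratio = 0.75, decoder_dim = 1024, decoder_depth = 6, decoder_heads = 8)

opt = torch.optim.AdamW(mae.parameters(), lr = 1.5e-4)

for _ in range(100):                       # many more steps in practice
    images = torch.randn(8, 3, 256, 256)   # stand-in for a batch of unlabelled images
    loss = mae(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# the encoder `v` now carries the pretrained weights and can be fine-tuned downstream
torch.save(v.state_dict(), './pretrained-vit.pt')
```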

## Masked Patch Prediction

Thanks to <a href="https://github.com/zankner">Zach</a>, you can train using the original masked patch prediction task presented in the paper, with the following code.
@@ -943,6 +982,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
}
```

```bibtex
@misc{he2021masked,
title = {Masked Autoencoders Are Scalable Vision Learners},
author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
year = {2021},
eprint = {2111.06377},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

```bibtex
@misc{vaswani2017attention,
title = {Attention Is All You Need},
Binary file added images/mae.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'vit-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.21.1',
  version = '0.22.0',
  license='MIT',
  description = 'Vision Transformer (ViT) - Pytorch',
  author = 'Phil Wang',
1 change: 1 addition & 0 deletions vit_pytorch/__init__.py
@@ -1,2 +1,3 @@
from vit_pytorch.vit import ViT
from vit_pytorch.mae import MAE
from vit_pytorch.dino import Dino
93 changes: 93 additions & 0 deletions vit_pytorch/mae.py
@@ -0,0 +1,93 @@
import torch
from math import ceil
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat

from vit_pytorch.vit import Transformer

class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # decoder parameters

        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # calculate the number of patches to be masked, get random indices, and divide them into masked vs unmasked

        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # get the unmasked tokens to be encoded

        batch_range = torch.arange(batch, device = device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # get the patches to be masked for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # attend with vision transformer

        encoded_tokens = self.encoder.transformer(tokens)

        # project encoder tokens to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for the decoder

        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # repeat the mask token for the number of masked patches, and add the positions using the masked indices derived above

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # concat the masked tokens to the decoder tokens and attend with decoder

        decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim = 1)
        decoded_tokens = self.decoder(decoder_tokens)

        # splice out the mask tokens and project to pixel values

        mask_tokens = decoded_tokens[:, -num_masked:]
        pred_pixel_values = self.to_pixels(mask_tokens)

        # calculate reconstruction loss

        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss
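
As a sanity check on the shapes flowing through `forward` above, the sketch below repeats the patch flattening that `self.to_patch` performs, assuming the standard ViT rearrangement and a 256x256 image with 32x32 patches (both sizes are illustrative).

```python
import torch
from einops import rearrange

# a 256x256 image cut into 32x32 patches gives (256 // 32) ** 2 = 64 patches,
# each flattened to 32 * 32 * 3 = 3072 pixel values
img = torch.randn(8, 3, 256, 256)
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = 32, p2 = 32)
print(patches.shape)   # torch.Size([8, 64, 3072])

# with masking_ratio = 0.75, 48 of the 64 patches are masked, so pred_pixel_values and
# masked_patches above both have shape (8, 48, 3072) going into the MSE reconstruction loss
```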
