forked from lucidrains/vit-pytorch
Commit e8f6d72 (1 parent: cb1729a)
Showing 5 changed files with 145 additions and 1 deletion.
@@ -1,2 +1,3 @@
 from vit_pytorch.vit import ViT
+from vit_pytorch.mae import MAE
 from vit_pytorch.dino import Dino
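This hunk appears to be the package __init__.py; if so, after this commit the new module can be imported directly from the top level:

    from vit_pytorch import MAE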
@@ -0,0 +1,93 @@
import torch
from math import ceil
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat

from vit_pytorch.vit import Transformer

class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)

        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]

        # decoder parameters

        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

    def forward(self, img):
        device = img.device

        # get patches

        patches = self.to_patch(img)
        batch, num_patches, *_ = patches.shape

        # patch to encoder tokens and add positions

        tokens = self.patch_to_emb(patches)
        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]

        # calculate the number of patches to be masked, and get random indices, dividing them up into masked and unmasked

        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # get the unmasked tokens to be encoded

        batch_range = torch.arange(batch, device = device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # get the patches to be masked for the final reconstruction loss

        masked_patches = patches[batch_range, masked_indices]

        # attend with vision transformer

        encoded_tokens = self.encoder.transformer(tokens)

        # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder

        decoder_tokens = self.enc_to_dec(encoded_tokens)

        # repeat mask tokens for number of masked, and add the positions using the masked indices derived above

        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # concat the masked tokens to the decoder tokens and attend with decoder

        decoder_tokens = torch.cat((decoder_tokens, mask_tokens), dim = 1)
        decoded_tokens = self.decoder(decoder_tokens)

        # splice out the mask tokens and project to pixel values

        mask_tokens = decoded_tokens[:, -num_masked:]
        pred_pixel_values = self.to_pixels(mask_tokens)

        # calculate reconstruction loss

        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss