forked from lucidrains/vit-pytorch
Commit
begin work on NaViT (lucidrains#273)
finish core idea of NaViT
1 parent e9ca1f4 · commit 23820bc
Showing 3 changed files with 352 additions and 0 deletions.
@@ -0,0 +1,305 @@
from functools import partial
from typing import List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.utils.rnn import pad_sequence as orig_pad_sequence

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
    return (numer % denom) == 0

# normalization
# they use layernorm without bias, something that pytorch does not offer

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# they use a query-key normalization that is equivalent to rms norm (no mean-centering, learned gamma), from vit 22B paper

class RMSNorm(nn.Module):
    def __init__(self, heads, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        normed = F.normalize(x, dim = -1)
        return normed * self.scale * self.gamma

# feedforward

def FeedForward(dim, hidden_dim, dropout = 0.):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout)
    )

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.norm = LayerNorm(dim)

        self.q_norm = RMSNorm(heads, dim_head)
        self.k_norm = RMSNorm(heads, dim_head)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_mask = None
    ):
        x = self.norm(x)
        kv_input = default(context, x)

        qkv = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        q = self.q_norm(q)
        k = self.k_norm(k)

        dots = torch.matmul(q, k.transpose(-1, -2))

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

        if exists(attn_mask):
            dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

        self.norm = LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        attn_mask = None
    ):
        for attn, ff in self.layers:
            x = attn(x, mask = mask, attn_mask = attn_mask) + x
            x = ff(x) + x

        return self.norm(x)

class NaViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        patch_dim = channels * (patch_size ** 2)

        self.channels = channels
        self.patch_size = patch_size

        self.to_patch_embedding = nn.Sequential(
            LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            LayerNorm(dim),
        )

        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # final attention pooling queries

        self.attn_pool_queries = nn.Parameter(torch.randn(dim))
        self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

        # output to logits

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_classes, bias = False)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
    ):
        p, c, device = self.patch_size, self.channels, self.device

        arange = partial(torch.arange, device = device)
        pad_sequence = partial(orig_pad_sequence, batch_first = True)

        # process images into variable lengthed sequences with attention mask

        num_images = []
        batched_sequences = []
        batched_positions = []
        batched_image_ids = []

        for images in batched_images:
            num_images.append(len(images))

            sequences = []
            positions = []
            image_ids = torch.empty((0,), device = device, dtype = torch.long)

            for image_id, image in enumerate(images):
                assert image.ndim == 3 and image.shape[0] == c
                image_dims = image.shape[-2:]
                assert all([divisible_by(dim, p) for dim in image_dims]), f'height and width {image_dims} of images must be divisible by patch size {p}'

                ph, pw = map(lambda dim: dim // p, image_dims)

                pos = torch.stack(torch.meshgrid((
                    arange(ph),
                    arange(pw)
                ), indexing = 'ij'), dim = -1)

                pos = rearrange(pos, 'h w c -> (h w) c')
                seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)

                image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
                sequences.append(seq)
                positions.append(pos)

            batched_image_ids.append(image_ids)
            batched_sequences.append(torch.cat(sequences, dim = 0))
            batched_positions.append(torch.cat(positions, dim = 0))

        # derive key padding mask

        lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
        max_length = arange(lengths.amax().item())
        key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
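        # key_pad_mask is True at padded positions (sequence index >= that sample's true length)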

        # derive attention mask, and combine with key padding mask from above

        batched_image_ids = pad_sequence(batched_image_ids)
        attn_mask = rearrange(batched_image_ids, 'b i -> b 1 i 1') == rearrange(batched_image_ids, 'b j -> b 1 1 j')
        attn_mask = attn_mask & ~rearrange(key_pad_mask, 'b j -> b 1 1 j')
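        # the combined mask is block diagonal: a token attends only to tokens of its own image, never to padding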

        # combine patched images as well as the patched width / height positions for 2d positional embedding

        patches = pad_sequence(batched_sequences)
        patch_positions = pad_sequence(batched_positions)

        # need to know how many images for final attention pooling

        num_images = torch.tensor(num_images, device = device, dtype = torch.long)

        # to patches

        x = self.to_patch_embedding(patches)

        # factorized 2d absolute positional embedding

        h_indices, w_indices = patch_positions.unbind(dim = -1)

        h_pos = self.pos_embed_height[h_indices]
        w_pos = self.pos_embed_width[w_indices]
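        # each token receives the sum of a learned row embedding and a learned column embedding, indexed by its (height, width) position within its own image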

        x = x + h_pos + w_pos

        # embed dropout

        x = self.dropout(x)

        # attention

        x = self.transformer(x, attn_mask = attn_mask)

        # do attention pooling at the end

        max_queries = num_images.amax().item()

        queries = repeat(self.attn_pool_queries, 'd -> b n d', n = max_queries, b = x.shape[0])

        # attention pool mask

        image_id_arange = arange(max_queries)

        attn_pool_mask = rearrange(image_id_arange, 'i -> i 1') == rearrange(batched_image_ids, 'b j -> b 1 j')

        attn_pool_mask = attn_pool_mask & ~rearrange(key_pad_mask, 'b j -> b 1 j')
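        # query i in each batch element now pools only the non-padded tokens belonging to image i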

        attn_pool_mask = rearrange(attn_pool_mask, 'b i j -> b 1 i j')

        # attention pool

        x = self.attn_pool(queries, context = x, attn_mask = attn_pool_mask) + queries

        x = rearrange(x, 'b n d -> (b n) d')

        # each batch element may not have same amount of images

        is_images = image_id_arange < rearrange(num_images, 'b -> b 1')
        is_images = rearrange(is_images, 'b n -> (b n)')
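        # drop pooled outputs for query slots beyond a batch element's actual number of images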

        x = x[is_images]

        # project out to logits

        x = self.to_latent(x)

        return self.mlp_head(x)
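
A minimal usage sketch of the module above, assuming the new file is importable as vit_pytorch.na_vit (the diff does not show the file path, so the import is hypothetical). The forward pass takes a list of lists of image tensors: each inner list is packed into one sequence, and the images may have different resolutions as long as height and width are divisible by patch_size and no larger than image_size.

import torch
from vit_pytorch.na_vit import NaViT  # hypothetical import path, not shown in this diff

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024
)

# two packed sequences holding 2 and 3 images respectively, at mixed resolutions
# every height / width is divisible by the patch size (32) and at most image_size (256)
batched_images = [
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 64, 256), torch.randn(3, 32, 32), torch.randn(3, 128, 64)]
]

preds = v(batched_images)  # (5, 1000) - one row of logits per image across the whole batch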