
Commit

make sure global average pool can be used for vivit in place of cls token
lucidrains committed Oct 25, 2022
1 parent 13fabf9 commit 6ec8fda
Showing 4 changed files with 35 additions and 21 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.37.0',
+  version = '0.37.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
6 changes: 3 additions & 3 deletions vit_pytorch/simple_vit_3d.py
@@ -114,10 +114,10 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
             nn.Linear(dim, num_classes)
         )
 
-    def forward(self, img):
-        *_, h, w, dtype = *img.shape, img.dtype
+    def forward(self, video):
+        *_, h, w, dtype = *video.shape, video.dtype
 
-        x = self.to_patch_embedding(img)
+        x = self.to_patch_embedding(video)
         pe = posemb_sincos_3d(x)
         x = rearrange(x, 'b ... d -> b (...) d') + pe
 
4 changes: 2 additions & 2 deletions vit_pytorch/vit_3d.py
@@ -112,8 +112,8 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
             nn.Linear(dim, num_classes)
         )
 
-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, n, _ = x.shape
 
         cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
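The rename makes the expected input explicit: the 3-D models consume a 5-D video tensor rather than a single image. A minimal call sketch, assuming the channels-first (batch, channels, frames, height, width) layout implied by the patch Rearrange in vivit.py below, that vit_pytorch.vit_3d exports its model class as ViT, and that the remaining constructor arguments (e.g. depth) follow the other modules; all values here are hypothetical:

import torch
from vit_pytorch.vit_3d import ViT  # assumed export name

v = ViT(
    image_size = 128,        # frame height and width
    image_patch_size = 16,   # spatial patch size
    frames = 16,             # number of frames
    frame_patch_size = 2,    # temporal patch size
    num_classes = 1000,
    dim = 512,
    depth = 6,               # assumed transformer depth argument
    heads = 8,
    mlp_dim = 1024
)

video = torch.randn(2, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = v(video)                          # (2, 1000)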
44 changes: 29 additions & 15 deletions vit_pytorch/vivit.py
@@ -1,11 +1,14 @@
 import torch
 from torch import nn
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
 # helpers
 
+def exists(val):
+    return val is not None
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
@@ -106,20 +109,25 @@ def __init__(
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
         assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
 
-        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
+        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
+        num_frame_patches = (frames // frame_patch_size)
+
         patch_dim = channels * patch_height * patch_width * frame_patch_size
 
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
 
+        self.global_average_pool = pool == 'mean'
+
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.Linear(patch_dim, dim),
         )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
         self.dropout = nn.Dropout(emb_dropout)
-        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim))
+
+        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
+        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
 
         self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
         self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
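With the factorized positional embedding, the parameter now covers frame patches and image patches separately instead of one flat sequence with a cls slot, and in forward it is simply broadcast over the batch when added to the patch tokens. A minimal shape check with hypothetical sizes:

import torch

# patch tokens after to_patch_embedding: (batch, num_frame_patches, num_image_patches, dim)
x = torch.randn(2, 8, 64, 256)
pos_embedding = torch.randn(1, 8, 64, 256)   # matches every axis except batch
assert (x + pos_embedding).shape == x.shape  # broadcasts over the batch dimension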
@@ -132,13 +140,16 @@ def __init__(
             nn.Linear(dim, num_classes)
         )
 
-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape
 
-        spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
-        x = torch.cat((spatial_cls_tokens, x), dim = 2)
-        x += self.pos_embedding[:, :(n + 1)]
+        x = x + self.pos_embedding
+
+        if exists(self.spatial_cls_token):
+            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
+            x = torch.cat((spatial_cls_tokens, x), dim = 2)
+
         x = self.dropout(x)
 
         x = rearrange(x, 'b f n d -> (b f) n d')
@@ -149,21 +160,24 @@ def forward(self, img):
 
         x = rearrange(x, '(b f) n d -> b f n d', b = b)
 
-        # excise out the spatial cls tokens for temporal attention
+        # excise out the spatial cls tokens or average pool for temporal attention
 
-        x = x[:, :, 0]
+        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
 
         # append temporal CLS tokens
 
-        temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+        if exists(self.temporal_cls_token):
+            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
 
-        x = torch.cat((temporal_cls_tokens, x), dim = 1)
+            x = torch.cat((temporal_cls_tokens, x), dim = 1)
 
         # attend across time
 
         x = self.temporal_transformer(x)
 
-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        # excise out temporal cls token or average pool
+
+        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
 
         x = self.to_latent(x)
         return self.mlp_head(x)
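In short, when the model is constructed with pool = 'mean', the spatial and temporal cls tokens are never created, the spatial tokens are mean-pooled per frame before temporal attention, and the temporal tokens are mean-pooled before the classification head. A minimal usage sketch exercising the new path, assuming the class is exported as ViT from vit_pytorch.vivit as in the other modules; hyperparameter values are hypothetical:

import torch
from vit_pytorch.vivit import ViT  # assumed export name

v = ViT(
    image_size = 128,
    image_patch_size = 16,
    frames = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 512,
    spatial_depth = 6,        # depth of the spatial transformer
    temporal_depth = 6,       # depth of the temporal transformer
    heads = 8,
    mlp_dim = 1024,
    pool = 'mean'             # global average pooling in place of cls tokens
)

video = torch.randn(2, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
preds = v(video)                          # (2, 1000)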
