diff --git a/setup.py b/setup.py
index 31f11e4..f3dc68f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.37.0',
+  version = '0.37.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/vit_pytorch/simple_vit_3d.py b/vit_pytorch/simple_vit_3d.py
index c189fe6..edbb67f 100644
--- a/vit_pytorch/simple_vit_3d.py
+++ b/vit_pytorch/simple_vit_3d.py
@@ -114,10 +114,10 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
             nn.Linear(dim, num_classes)
         )
 
-    def forward(self, img):
-        *_, h, w, dtype = *img.shape, img.dtype
+    def forward(self, video):
+        *_, h, w, dtype = *video.shape, video.dtype
 
-        x = self.to_patch_embedding(img)
+        x = self.to_patch_embedding(video)
         pe = posemb_sincos_3d(x)
         x = rearrange(x, 'b ... d -> b (...) d') + pe
 
diff --git a/vit_pytorch/vit_3d.py b/vit_pytorch/vit_3d.py
index 7fd8cdd..fd269f9 100644
--- a/vit_pytorch/vit_3d.py
+++ b/vit_pytorch/vit_3d.py
@@ -112,8 +112,8 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
             nn.Linear(dim, num_classes)
         )
 
-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, n, _ = x.shape
 
         cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
diff --git a/vit_pytorch/vivit.py b/vit_pytorch/vivit.py
index 67e8c52..1acb6f7 100644
--- a/vit_pytorch/vivit.py
+++ b/vit_pytorch/vivit.py
@@ -1,11 +1,14 @@
 import torch
 from torch import nn
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
 # helpers
 
+def exists(val):
+    return val is not None
+
 def pair(t):
     return t if isinstance(t, tuple) else (t, t)
 
@@ -106,20 +109,25 @@ def __init__(
         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
         assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
 
-        num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
+        num_image_patches = (image_height // patch_height) * (image_width // patch_width)
+        num_frame_patches = (frames // frame_patch_size)
+
         patch_dim = channels * patch_height * patch_width * frame_patch_size
 
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
 
+        self.global_average_pool = pool == 'mean'
+
         self.to_patch_embedding = nn.Sequential(
             Rearrange('b c (f pf) (h p1) (w p2) -> b f (h w) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
             nn.Linear(patch_dim, dim),
         )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_frame_patches, num_image_patches, dim))
         self.dropout = nn.Dropout(emb_dropout)
-        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim))
+
+        self.spatial_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
+        self.temporal_cls_token = nn.Parameter(torch.randn(1, 1, dim)) if not self.global_average_pool else None
 
         self.spatial_transformer = Transformer(dim, spatial_depth, heads, dim_head, mlp_dim, dropout)
         self.temporal_transformer = Transformer(dim, temporal_depth, heads, dim_head, mlp_dim, dropout)
@@ -132,13 +140,16 @@ def __init__(
             nn.Linear(dim, num_classes)
         )
 
-    def forward(self, img):
-        x = self.to_patch_embedding(img)
+    def forward(self, video):
+        x = self.to_patch_embedding(video)
         b, f, n, _ = x.shape
 
-        spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
-        x = torch.cat((spatial_cls_tokens, x), dim = 2)
-        x += self.pos_embedding[:, :(n + 1)]
+        x = x + self.pos_embedding
+
+        if exists(self.spatial_cls_token):
+            spatial_cls_tokens = repeat(self.spatial_cls_token, '1 1 d -> b f 1 d', b = b, f = f)
+            x = torch.cat((spatial_cls_tokens, x), dim = 2)
+
         x = self.dropout(x)
 
         x = rearrange(x, 'b f n d -> (b f) n d')
@@ -149,21 +160,24 @@ def forward(self, img):
 
         x = rearrange(x, '(b f) n d -> b f n d', b = b)
 
-        # excise out the spatial cls tokens for temporal attention
+        # excise out the spatial cls tokens or average pool for temporal attention
 
-        x = x[:, :, 0]
+        x = x[:, :, 0] if not self.global_average_pool else reduce(x, 'b f n d -> b f d', 'mean')
 
         # append temporal CLS tokens
 
-        temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
+        if exists(self.temporal_cls_token):
+            temporal_cls_tokens = repeat(self.temporal_cls_token, '1 1 d-> b 1 d', b = b)
 
-        x = torch.cat((temporal_cls_tokens, x), dim = 1)
+            x = torch.cat((temporal_cls_tokens, x), dim = 1)
 
         # attend across time
 
         x = self.temporal_transformer(x)
 
-        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
+        # excise out temporal cls token or average pool
+
+        x = x[:, 0] if not self.global_average_pool else reduce(x, 'b f d -> b d', 'mean')
 
         x = self.to_latent(x)
         return self.mlp_head(x)
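
Below is a minimal usage sketch of the new `pool = 'mean'` path this patch adds to `vit_pytorch/vivit.py`. The keyword arguments are the ones visible in the constructor hunks above; the exported class name (`ViViT`) and the specific hyperparameter values are assumptions for illustration, not something the diff itself confirms.

```python
import torch
from vit_pytorch.vivit import ViViT  # class name assumed; this patch only shows vivit.py hunks

# with pool = 'mean', the changes above disable both CLS tokens and instead
# average-pool over image patches (per frame) and then over frame patches
v = ViViT(
    image_size = 128,          # frame height / width
    image_patch_size = 16,     # must divide image_size
    frames = 16,               # frames per clip
    frame_patch_size = 2,      # must divide frames
    num_classes = 1000,
    dim = 1024,
    spatial_depth = 6,         # depth of the per-frame spatial transformer
    temporal_depth = 6,        # depth of the across-frame temporal transformer
    heads = 8,
    mlp_dim = 2048,
    pool = 'mean'              # 'cls' keeps the CLS-token path
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)

preds = v(video)  # forward now takes `video`, per the rename in this patch
assert preds.shape == (4, 1000)
```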