Skip to content

Commit

Permalink
allow for mean pool with efficient version too
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 24, 2020
1 parent 2433964 commit 59787a6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.5.0',
version = '0.5.1',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
10 changes: 7 additions & 3 deletions vit_pytorch/efficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from torch import nn

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, channels = 3):
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2

Expand All @@ -16,7 +17,8 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, cha
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.transformer = transformer

self.to_cls_token = nn.Identity()
self.pool = pool
self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
Expand All @@ -35,5 +37,7 @@ def forward(self, img):
x += self.pos_embedding[:, :(n + 1)]
x = self.transformer(x)

x = self.to_cls_token(x[:, 0])
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

x = self.to_latent(x)
return self.mlp_head(x)

0 comments on commit 59787a6

Please sign in to comment.