remove patch size from T2TViT
lucidrains committed Feb 22, 2021
1 parent 6af7bbc commit 3744ac6
Showing 3 changed files with 9 additions and 8 deletions.
README.md (1 change: 0 additions & 1 deletion)
```diff
@@ -130,7 +130,6 @@ from vit_pytorch.t2t import T2TViT
 v = T2TViT(
     dim = 512,
     image_size = 224,
-    patch_size = 16,
     depth = 5,
     heads = 8,
     mlp_dim = 512,
```
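For context, the README example now instantiates T2TViT with no patch_size argument. A minimal runnable sketch of the updated usage (the num_classes value and the forward pass are assumptions for illustration, not part of this hunk):

```python
import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,  # assumed for illustration; not shown in this hunk
    t2t_layers = ((7, 4), (3, 2), (3, 2))  # default (kernel_size, stride) stages
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # shape (1, 1000)
```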
setup.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.7.1',
+  version = '0.7.2',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
```
vit_pytorch/t2t.py (14 changes: 8 additions & 6 deletions)
```diff
@@ -9,6 +9,9 @@

 # classes

+def conv_output_size(image_size, kernel_size, stride, padding):
+    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)
+
 class RearrangeImage(nn.Module):
     def forward(self, x):
         return rearrange(x, 'b (h w) c -> b c h w', h = int(math.sqrt(x.shape[1])))
@@ -17,19 +20,18 @@ def forward(self, x):

 class T2TViT(nn.Module):
     def __init__(
-        self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
+        self, *, image_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., t2t_layers = ((7, 4), (3, 2), (3, 2))):
         super().__init__()
-        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
-        num_patches = (image_size // patch_size) ** 2
-        patch_dim = channels * patch_size ** 2
         assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

         layers = []
         layer_dim = channels
+        output_image_size = image_size

         for i, (kernel_size, stride) in enumerate(t2t_layers):
             layer_dim *= kernel_size ** 2
             is_first = i == 0
+            output_image_size = conv_output_size(output_image_size, kernel_size, stride, stride // 2)

             layers.extend([
                 RearrangeImage() if not is_first else nn.Identity(),
@@ -41,7 +43,7 @@ def __init__(
         layers.append(nn.Linear(layer_dim, dim))
         self.to_patch_embedding = nn.Sequential(*layers)

-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+        self.pos_embedding = nn.Parameter(torch.randn(1, output_image_size ** 2 + 1, dim))
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.dropout = nn.Dropout(emb_dropout)

@@ -61,7 +63,7 @@ def forward(self, img):

         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
         x = torch.cat((cls_tokens, x), dim=1)
-        x += self.pos_embedding[:, :(n + 1)]
+        x += self.pos_embedding
         x = self.dropout(x)

         x = self.transformer(x)
```
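The reason the slicing in forward can go away: the positional embedding is now sized exactly to the token count that the tokens-to-token stages produce, which conv_output_size tracks stage by stage. A standalone sketch tracing the default t2t_layers at image_size = 224 (it re-implements the helper from this diff for illustration):

```python
def conv_output_size(image_size, kernel_size, stride, padding):
    # same formula as the helper added in this commit
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)

size = 224
for kernel_size, stride in ((7, 4), (3, 2), (3, 2)):  # default t2t_layers
    size = conv_output_size(size, kernel_size, stride, padding = stride // 2)
    print(size)  # prints 56, then 28, then 14

num_tokens = size ** 2  # 196, so pos_embedding gets 196 + 1 rows (cls token included)
```

With the embedding sized to output_image_size ** 2 + 1, `x += self.pos_embedding` adds it without slicing, and patch_size is no longer needed anywhere in the model.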
