
Commit

address lucidrains#274
lucidrains committed Aug 9, 2023
1 parent 6e2393d commit 3e5d1be
Showing 8 changed files with 38 additions and 54 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.9',
version = '1.4.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
11 changes: 5 additions & 6 deletions vit_pytorch/simple_vit.py
@@ -64,6 +64,7 @@ def forward(self, x):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -74,7 +75,7 @@ def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)

class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -101,12 +102,10 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.pool = "mean"
self.to_latent = nn.Identity()

self.linear_head = nn.LayerNorm(dim)

def forward(self, img):
device = img.device
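The change that repeats in every file below is structural: each Transformer now owns a single nn.LayerNorm and applies it once after the whole residual stack (return self.norm(x) instead of return x), while the per-block norms live inside Attention and FeedForward. The following is a condensed, illustrative sketch of that layout, not the code in this commit; it substitutes torch.nn.MultiheadAttention for the repository's own Attention module and omits dropout.

import torch
from torch import nn

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # final post-stack norm introduced by this commit
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                nn.LayerNorm(dim),                                     # pre-norm for attention
                nn.MultiheadAttention(dim, heads, batch_first = True), # stand-in for the repo's Attention
                nn.Sequential(                                         # feed-forward with its own pre-norm
                    nn.LayerNorm(dim),
                    nn.Linear(dim, mlp_dim),
                    nn.GELU(),
                    nn.Linear(mlp_dim, dim)
                )
            ]))

    def forward(self, x):
        for attn_norm, attn, ff in self.layers:
            h = attn_norm(x)
            x = attn(h, h, h, need_weights = False)[0] + x
            x = ff(x) + x
        return self.norm(x)  # previously `return x`

tokens = torch.randn(1, 64, 256)                                   # (batch, tokens, dim)
out = Transformer(dim = 256, depth = 2, heads = 8, mlp_dim = 512)(tokens)
print(out.shape)                                                   # torch.Size([1, 64, 256])

With the norm applied inside the transformer, downstream heads receive already-normalized features, which is why the Sequential(LayerNorm, Linear) heads elsewhere in this diff collapse to a single layer.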
8 changes: 3 additions & 5 deletions vit_pytorch/simple_vit_1d.py
@@ -62,6 +62,7 @@ def forward(self, x):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -72,7 +73,7 @@ def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)

class SimpleViT(nn.Module):
def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -93,10 +94,7 @@ def __init__(self, *, seq_len, patch_size, num_classes, dim, depth, heads, mlp_d
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.linear_head = nn.Linear(dim, num_classes)

def forward(self, series):
*_, n, dtype = *series.shape, series.dtype
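For context, a minimal usage sketch of the 1-D variant after this change. The hyperparameter values are arbitrary, and the (batch, channels, seq_len) input layout is an assumption based on the rest of the file rather than something shown in this hunk.

import torch
from vit_pytorch.simple_vit_1d import SimpleViT

v = SimpleViT(
    seq_len = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024
)

series = torch.randn(4, 3, 256)   # (batch, channels, seq_len); channels defaults to 3
logits = v(series)                # (4, 1000), from the new bare nn.Linear head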
8 changes: 3 additions & 5 deletions vit_pytorch/simple_vit_3d.py
@@ -77,6 +77,7 @@ def forward(self, x):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -87,7 +88,7 @@ def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)

class SimpleViT(nn.Module):
def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
@@ -111,10 +112,7 @@ def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, nu
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.linear_head = nn.Linear(dim, num_classes)

def forward(self, video):
*_, h, w, dtype = *video.shape, video.dtype
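Likewise, a hedged usage sketch for the 3-D (video) variant; the values are arbitrary and the (batch, channels, frames, height, width) layout is assumed from the constructor arguments shown above.

import torch
from vit_pytorch.simple_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,
    image_patch_size = 16,
    frames = 16,
    frame_patch_size = 2,
    num_classes = 1000,
    dim = 768,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(2, 3, 16, 128, 128)   # (batch, channels, frames, height, width)
preds = v(video)                          # (2, 1000)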
8 changes: 3 additions & 5 deletions vit_pytorch/simple_vit_with_patch_dropout.py
@@ -87,6 +87,7 @@ def forward(self, x):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -97,7 +98,7 @@ def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)

class SimpleViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, patch_dropout = 0.5):
@@ -122,10 +123,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

self.to_latent = nn.Identity()
self.linear_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.linear_head = nn.Linear(dim, num_classes)

def forward(self, img):
*_, h, w, dtype = *img.shape, img.dtype
28 changes: 13 additions & 15 deletions vit_pytorch/vit.py
@@ -11,24 +11,18 @@ def pair(t):

# classes

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)

def forward(self, x):
return self.net(x)

@@ -41,6 +35,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

@@ -52,6 +48,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
) if project_out else nn.Identity()

def forward(self, x):
x = self.norm(x)

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

@@ -67,17 +65,20 @@ def forward(self, x):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))

def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x

return self.norm(x)

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -107,10 +108,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
self.pool = pool
self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.mlp_head = nn.Linear(dim, num_classes)

def forward(self, img):
x = self.to_patch_embedding(img)
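The public constructor of ViT is unchanged by this commit; only the internals move (per-block norms folded into Attention and FeedForward, one post-stack norm in Transformer, and a bare nn.Linear classification head). A minimal usage sketch with arbitrary hyperparameters:

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)   # (1, 1000)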
19 changes: 7 additions & 12 deletions vit_pytorch/vit_with_patch_merger.py
@@ -32,18 +32,11 @@ def forward(self, x):

# classes

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -62,6 +55,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

@@ -73,6 +67,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
) if project_out else nn.Identity()

def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

@@ -88,15 +83,16 @@ def forward(self, x):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., patch_merge_layer = None, patch_merge_num_tokens = 8):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])

self.patch_merge_layer_index = default(patch_merge_layer, depth // 2) - 1 # default to mid-way through transformer, as shown in paper
self.patch_merger = PatchMerger(dim = dim, num_tokens_out = patch_merge_num_tokens)

for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for index, (attn, ff) in enumerate(self.layers):
@@ -106,7 +102,7 @@ def forward(self, x):
if index == self.patch_merge_layer_index:
x = self.patch_merger(x)

return x
return self.norm(x)

class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, patch_merge_layer = None, patch_merge_num_tokens = 8, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
@@ -133,7 +129,6 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml

self.mlp_head = nn.Sequential(
Reduce('b n d -> b d', 'mean'),
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)

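The patch-merger variant keeps its mean-pooling head (Reduce('b n d -> b d', 'mean') followed by nn.Linear); the LayerNorm it used to carry now sits at the end of Transformer. A usage sketch, again with arbitrary values:

import torch
from vit_pytorch.vit_with_patch_merger import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 12,
    heads = 8,
    mlp_dim = 2048,
    patch_merge_layer = 6,        # merge patches midway through the network
    patch_merge_num_tokens = 8    # number of tokens remaining after the merge
)

img = torch.randn(4, 3, 256, 256)
preds = v(img)   # (4, 1000)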
8 changes: 3 additions & 5 deletions vit_pytorch/vivit.py
@@ -70,6 +70,7 @@ def forward(self, x):
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
@@ -80,7 +81,7 @@ def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
return self.norm(x)

class ViT(nn.Module):
def __init__(
@@ -137,10 +138,7 @@ def __init__(
self.pool = pool
self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
self.mlp_head = nn.Linear(dim, num_classes)

def forward(self, video):
x = self.to_patch_embedding(video)
