Skip to content

Commit

Permalink
correct need for post-attention dropout
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Mar 30, 2022
1 parent 6d7298d commit 4e6a42a
Show file tree
Hide file tree
Showing 20 changed files with 61 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.29.1',
+ version = '0.30.0',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
author = 'Phil Wang',
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/ats_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.output_num_tokens = output_num_tokens
Expand All @@ -163,6 +165,7 @@ def forward(self, x, *, mask):
dots = dots.masked_fill(~dots_mask, mask_value)

attn = self.attend(dots)
attn = self.dropout(attn)

sampled_token_ids = None

Expand Down
4 changes: 4 additions & 0 deletions vit_pytorch/cait.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.mix_heads_pre_attn = nn.Parameter(torch.randn(heads, heads))
self.mix_heads_post_attn = nn.Parameter(torch.randn(heads, heads))
Expand All @@ -96,7 +97,10 @@ def forward(self, x, context = None):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

dots = einsum('b h i j, h g -> b g i j', dots, self.mix_heads_pre_attn) # talking heads, pre-softmax

attn = self.attend(dots)
attn = self.dropout(attn)

attn = einsum('b h i j, h g -> b g i j', attn, self.mix_heads_post_attn) # talking heads, post-softmax

out = einsum('b h i j, b h j d -> b h i d', attn, v)
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/cross_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

Expand All @@ -69,6 +71,7 @@ def forward(self, x, context = None, kv_include_self = False):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down
4 changes: 4 additions & 0 deletions vit_pytorch/crossformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def __init__(
self.window_size = window_size

self.norm = LayerNorm(dim)

self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(inner_dim, dim, 1)

Expand Down Expand Up @@ -151,6 +154,7 @@ def forward(self, x):
# attend

attn = sim.softmax(dim = -1)
attn = self.dropout(attn)

# merge heads

Expand Down
2 changes: 2 additions & 0 deletions vit_pytorch/cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)
Expand All @@ -94,6 +95,7 @@ def forward(self, x):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/deepvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.dropout = nn.Dropout(dropout)

self.reattn_weights = nn.Parameter(torch.randn(heads, heads))

self.reattn_norm = nn.Sequential(
Expand All @@ -64,6 +66,7 @@ def forward(self, x):

dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
attn = self.dropout(attn)

# re-attention

Expand Down
2 changes: 2 additions & 0 deletions vit_pytorch/levit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, drop
self.to_v = nn.Sequential(nn.Conv2d(dim, inner_dim_value, 1, bias = False), nn.BatchNorm2d(inner_dim_value))

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

out_batch_norm = nn.BatchNorm2d(dim_out)
nn.init.zeros_(out_batch_norm.weight)
Expand Down Expand Up @@ -100,6 +101,7 @@ def forward(self, x):
dots = self.apply_pos_bias(dots)

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
Expand Down
2 changes: 2 additions & 0 deletions vit_pytorch/local_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
Expand All @@ -93,6 +94,7 @@ def forward(self, x):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down
5 changes: 5 additions & 0 deletions vit_pytorch/mobile_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

self.to_out = nn.Sequential(
Expand All @@ -67,7 +69,10 @@ def forward(self, x):
t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b p h n d -> b p n (h d)')
return self.to_out(out)
Expand Down
2 changes: 2 additions & 0 deletions vit_pytorch/nest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, dim, heads = 8, dropout = 0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

self.to_out = nn.Sequential(
Expand All @@ -71,6 +72,7 @@ def forward(self, x):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/parallel_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
Expand All @@ -64,6 +66,7 @@ def forward(self, x):
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down
2 changes: 2 additions & 0 deletions vit_pytorch/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
Expand All @@ -63,6 +64,7 @@ def forward(self, x):
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down
8 changes: 7 additions & 1 deletion vit_pytorch/regionvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ def __init__(
inner_dim = dim_head * heads

self.norm = nn.LayerNorm(dim)
+ self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
- self.to_out = nn.Linear(inner_dim, dim)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim),
+ nn.Dropout(dropout)
+ )

def forward(self, x, rel_pos_bias = None):
h = self.heads
Expand All @@ -86,6 +91,7 @@ def forward(self, x, rel_pos_bias = None):
sim = sim + rel_pos_bias

attn = sim.softmax(dim = -1)
attn = self.dropout(attn)

# merge heads

Expand Down
2 changes: 2 additions & 0 deletions vit_pytorch/rvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., use_rotary = Tru
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.use_ds_conv = use_ds_conv

Expand Down Expand Up @@ -148,6 +149,7 @@ def forward(self, x, pos_emb, fmap_dims):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
Expand Down
4 changes: 4 additions & 0 deletions vit_pytorch/scalable_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
self.heads = heads
self.scale = dim_key ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias = False)
self.to_k = nn.Conv2d(dim, dim_key * heads, reduction_factor, stride = reduction_factor, bias = False)
Expand All @@ -116,6 +117,7 @@ def forward(self, x):
# attention

attn = self.attend(dots)
attn = self.dropout(attn)

# aggregate values

Expand All @@ -141,6 +143,7 @@ def __init__(
self.scale = dim_key ** -0.5
self.window_size = window_size
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.local_interactive_module = nn.Conv2d(dim_value * heads, dim_value * heads, 3, padding = 1)

Expand Down Expand Up @@ -176,6 +179,7 @@ def forward(self, x):
# attention

attn = self.attend(dots)
attn = self.dropout(attn)

# aggregate values

Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/twins_svt.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., k = 7):
self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
self.to_kv = nn.Conv2d(dim, inner_dim * 2, k, stride = k, bias = False)

self.dropout = nn.Dropout(dropout)

self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, dim, 1),
nn.Dropout(dropout)
Expand All @@ -145,6 +147,7 @@ def forward(self, x):
dots = einsum('b i d, b j d -> b i j', q, k) * self.scale

attn = dots.softmax(dim = -1)
attn = self.dropout(attn)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
Expand All @@ -56,6 +58,7 @@ def forward(self, x):
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/vit_for_small_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.temperature = nn.Parameter(torch.log(torch.tensor(dim_head ** -0.5)))

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
Expand All @@ -60,6 +62,7 @@ def forward(self, x):
dots = dots.masked_fill(mask, mask_value)

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down
3 changes: 3 additions & 0 deletions vit_pytorch/vit_with_patch_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.scale = dim_head ** -0.5

self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.to_out = nn.Sequential(
Expand All @@ -77,6 +79,7 @@ def forward(self, x):
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

attn = self.attend(dots)
attn = self.dropout(attn)

out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
Expand Down

0 comments on commit 4e6a42a

Please sign in to comment.