make value residual learned
lucidrains committed Nov 24, 2024
1 parent 24196a3 commit 56373c0
Showing 2 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -6,10 +6,10 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.8.8',
+  version = '1.8.9',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
-  long_description=long_description,
+  long_description = long_description,
   long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = '[email protected]',
16 changes: 12 additions & 4 deletions vit_pytorch/simple_vit_with_value_residual.py
@@ -38,7 +38,7 @@ def FeedForward(dim, hidden_dim):
     )

 class Attention(Module):
-    def __init__(self, dim, heads = 8, dim_head = 64):
+    def __init__(self, dim, heads = 8, dim_head = 64, learned_value_residual_mix = False):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -50,14 +50,21 @@ def __init__(self, dim, heads = 8, dim_head = 64):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)

+        self.to_residual_mix = nn.Sequential(
+            nn.Linear(dim, heads),
+            nn.Sigmoid(),
+            Rearrange('b n h -> b h n 1')
+        ) if learned_value_residual_mix else (lambda _: 0.5)
+
     def forward(self, x, value_residual = None):
         x = self.norm(x)

         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

         if exists(value_residual):
-            v = 0.5 * (v + value_residual)
+            mix = self.to_residual_mix(x)
+            v = v * mix + value_residual * (1. - mix)

         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

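The hunk above replaces the fixed 0.5 average between the current layer's values and the residual values with a learned convex mix: a linear projection of the normalized tokens yields one sigmoid gate per head and per token, broadcast over the head dimension. The snippet below is a minimal standalone sketch of just that mixing step, using made-up tensor shapes for illustration; it mirrors the to_residual_mix module added in this commit but is not itself part of the diff.

import torch
from torch import nn
from einops.layers.torch import Rearrange

# illustrative shapes only (hypothetical values, not from the commit)
batch, seq, dim, heads, dim_head = 2, 16, 512, 8, 64

# one sigmoid gate per head and per token, broadcast over dim_head
to_residual_mix = nn.Sequential(
    nn.Linear(dim, heads),
    nn.Sigmoid(),
    Rearrange('b n h -> b h n 1')
)

x = torch.randn(batch, seq, dim)                           # normalized tokens entering attention
v = torch.randn(batch, heads, seq, dim_head)               # values of the current layer
value_residual = torch.randn(batch, heads, seq, dim_head)  # values saved from the first layer

mix = to_residual_mix(x)                   # (batch, heads, seq, 1), each entry in (0, 1)
v = v * mix + value_residual * (1. - mix)  # learned mix; mix = 0.5 reproduces the old behavior

When learned_value_residual_mix is False, to_residual_mix falls back to the constant 0.5 (the lambda in the diff), which is exactly the previous v = 0.5 * (v + value_residual).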
@@ -73,9 +80,10 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.layers = ModuleList([])
-        for _ in range(depth):
+        for i in range(depth):
+            is_first = i == 0
             self.layers.append(ModuleList([
-                Attention(dim, heads = heads, dim_head = dim_head),
+                Attention(dim, heads = heads, dim_head = dim_head, learned_value_residual_mix = not is_first),
                 FeedForward(dim, mlp_dim)
             ]))
     def forward(self, x):
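With learned_value_residual_mix = not is_first, every layer after the first learns its own gate, while the first layer simply contributes the values that later layers mix against. The Transformer.forward that threads the value residual through the layers is unchanged and therefore not shown in this diff; the sketch below is a hypothetical reconstruction of that wiring, and it assumes the attention block also returns its values, which the diff does not show.

# hypothetical sketch, not taken from the commit; assumes each attention block
# returns (output, values) so the first layer's values can be reused downstream
def transformer_forward(layers, norm, x):
    value_residual = None
    for attn, ff in layers:
        out, values = attn(x, value_residual = value_residual)
        x = out + x
        if value_residual is None:
            value_residual = values  # captured once, from the first layer
        x = ff(x) + x
    return norm(x)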
