Skip to content

Commit

Permalink
Fix multi-headed QK RMSNorm scaling in nViT (apply per-head q/k scale after L2 normalization)
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 11, 2024
1 parent 36ddc7a commit e300cdd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.8.4',
version = '1.8.5',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
Expand Down
10 changes: 5 additions & 5 deletions vit_pytorch/normalized_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def __init__(

self.dropout = dropout

self.q_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
self.k_scale = nn.Parameter(torch.ones(dim_inner) * (dim_head ** 0.25))
self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))

self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
Expand All @@ -90,15 +90,15 @@ def forward(
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

q = q * self.q_scale
k = k * self.k_scale

q, k, v = map(self.split_heads, (q, k, v))

# query key rmsnorm

q, k = map(l2norm, (q, k))

q = q * self.q_scale
k = k * self.k_scale

# scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16

out = F.scaled_dot_product_attention(
Expand Down

0 comments on commit e300cdd

Please sign in to comment.