
Commit

simplify cosine sim attention based on my own explorations and experimentations
lucidrains committed Jul 11, 2022
1 parent 7b0f48b commit e1a9bfc
Showing 3 changed files with 9 additions and 10 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -996,6 +996,8 @@ I have validated that this works just as well as dot product attention in an aut

This flavor of attention also has <a href="https://arxiv.org/abs/2111.05498">a connection</a> to sparse distributed memory. <a href="https://www.youtube.com/watch?v=THIIk7LR9_8">[youtube talk]</a>

+ Update: In my own experiments, simply bounding the scale to a range of 0 to 20 using a sigmoid performed better.
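
For intuition, here is a minimal sketch of the idea (illustrative only, not this repository's exact code; the class name `CosineSimAttention` and the `max_scale` argument are made up for the example, mirroring the `qk_norm_max_scale = 20` default introduced in this commit):

```python
# Illustrative sketch: cosine-sim (QK-normalized) attention with a learned
# per-head scale bounded to (0, max_scale) by a sigmoid.
import torch
import torch.nn.functional as F
from torch import nn, einsum

def l2norm(t):
    return F.normalize(t, dim = -1)

class CosineSimAttention(nn.Module):
    def __init__(self, heads = 8, max_scale = 20.):
        super().__init__()
        self.max_scale = max_scale
        # one learned scale per head, squashed through a sigmoid in the forward pass
        self.scale = nn.Parameter(torch.ones(1, heads, 1, 1))

    def forward(self, q, k, v):
        # q, k, v: (batch, heads, seq, dim_head)
        q, k = map(l2norm, (q, k))                      # unit-norm queries and keys
        scale = self.scale.sigmoid() * self.max_scale   # scale lies in (0, max_scale)
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale
        attn = sim.softmax(dim = -1)
        return einsum('b h i j, b h j d -> b h i d', attn, v)
```

Compared to the previous scheme, where the initial scale was derived from an expected sequence length, the sigmoid bound keeps the learned temperature in a fixed range and removes the need to pass `qk_norm_attn_seq_len`.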

You can use it as follows

```python
@@ -1009,8 +1011,7 @@ model = TransformerWrapper(
dim = 512,
depth = 6,
heads = 8,
- use_qk_norm_attn = True, # set this to True
- qk_norm_attn_seq_len = 1024 # set this to max_seq_len from above
+ use_qk_norm_attn = True # set this to True
)
)

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
- version = '0.29.2',
+ version = '0.30.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
12 changes: 5 additions & 7 deletions x_transformers/x_transformers.py
@@ -504,7 +504,7 @@ def __init__(
zero_init_output = False,
max_attend_past = None,
qk_norm = False,
- scale_init_value = None,
+ qk_norm_max_scale = 20,
one_kv_head = False,
shared_kv = False,
value_dim_head = None
@@ -546,8 +546,8 @@ def __init__(
# cosine sim attention
self.qk_norm = qk_norm
if qk_norm:
- scale_init_value = default(scale_init_value, -3) # if not provided, initialize as though it were sequence length of 1024
- self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
+ self.scale = nn.Parameter(torch.ones(1, heads, 1, 1))
+ self.qk_norm_max_scale = qk_norm_max_scale

# talking heads
self.talking_heads = talking_heads
@@ -643,7 +643,7 @@ def forward(

if self.qk_norm:
q, k = map(l2norm, (q, k))
- scale = 1 / (self.scale.exp().clamp(min = 1e-2))
+ scale = self.scale.sigmoid() * self.qk_norm_max_scale

kv_einsum_eq = 'b h j d' if not self.one_kv_head else 'b j d'

@@ -760,7 +760,6 @@ def __init__(
shift_tokens = 0,
sandwich_norm = False,
use_qk_norm_attn = False,
- qk_norm_attn_seq_len = None,
zero_init_branch_output = False,
**kwargs
):
@@ -821,8 +820,7 @@ def __init__(
# qk normalization

if use_qk_norm_attn:
- attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
- attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
+ attn_kwargs = {**attn_kwargs, 'qk_norm': True}

# zero init

