
Commit

add noam shazeers "one write head is all you need" one-headed key/values, used in AlphaCode as well as PaLM
lucidrains committed Apr 11, 2022
1 parent 069fa3e commit 40b1ba7
Showing 4 changed files with 41 additions and 50 deletions.
38 changes: 17 additions & 21 deletions README.md
@@ -443,13 +443,11 @@ model = TransformerWrapper(
)
```

### Collaborative Attention
### One Write-Head Is All You Need

<img src="./images/collaborative-attention.png" width="500px"></img>
https://arxiv.org/abs/1911.02150

https://arxiv.org/abs/2006.16362

Share redundant learned key/query projections across heads. Collaborative attention reduces the number of parameters but requires slightly more memory and computation. A good compression factor to match the performance of the vanilla multi-head attention is between 0.25 and 0.5.
Yet another Noam Shazeer paper (he's a legend) that proposes to only have one head for the key / values, but multi-headed queries. This paper was largely ignored for a while, but recently validated at scale in <a href="https://arxiv.org/abs/2203.07814">AlphaCode</a> as well as <a href="https://arxiv.org/abs/2204.02311">PaLM</a>. It has the property of being memory efficient when decoding extremely large language models. You can use it with one keyword argument as shown below.

```python
import torch
@@ -462,8 +460,7 @@ model = TransformerWrapper(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_collab_heads = True,
        attn_collab_compression = .3,
        attn_one_kv_head = True
    )
)
```
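For intuition, here is a minimal, self-contained sketch of the idea (toy shapes only, not the library's internal code): queries keep their `heads` separate projections, while keys and values are projected once per token and broadcast across every query head.

```python
import torch
from torch import nn, einsum

# toy sizes, purely illustrative
b, n, heads, dim_head, dim = 2, 16, 8, 64, 512
x = torch.randn(b, n, dim)

to_q = nn.Linear(dim, heads * dim_head, bias = False)  # multi-headed queries
to_k = nn.Linear(dim, dim_head, bias = False)          # single shared key head
to_v = nn.Linear(dim, dim_head, bias = False)          # single shared value head

q = to_q(x).reshape(b, n, heads, dim_head).transpose(1, 2)  # (b, h, n, d)
k = to_k(x)                                                 # (b, n, d)
v = to_v(x)                                                 # (b, n, d)

# the lone key / value head broadcasts over the head dimension of the queries
dots = einsum('b h i d, b j d -> b h i j', q, k) * dim_head ** -0.5
attn = dots.softmax(dim = -1)
out = einsum('b h i j, b j d -> b h i d', attn, v)

print(out.shape)  # torch.Size([2, 8, 16, 64])
```

The payoff shows up at inference: the key / value cache per layer shrinks by a factor of `heads`, which is what makes incremental decoding of very large models so much cheaper.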
@@ -1244,17 +1241,6 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

```bibtex
@misc{cordonnier2020multihead,
title = {Multi-Head Attention: Collaborate Instead of Concatenate},
author = {Jean-Baptiste Cordonnier and Andreas Loukas and Martin Jaggi},
year = {2020},
eprint = {2006.16362},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```

```bibtex
@misc{press2020improving,
title = {Improving Transformer Models by Reordering their Sublayers},
@@ -1532,9 +1518,19 @@ generated = model.generate(start_emb, 17) # (17, 777)

```bibtex
@article{chowdhery2022PaLM,
title = {PaLM: Scaling Language Modeling with Pathways},
author = {Chowdhery, Aakanksha et al},
year = {2022}
}
```

```bibtex
@article{Shazeer2019FastTD,
title = {Fast Transformer Decoding: One Write-Head is All You Need},
author = {Noam M. Shazeer},
journal = {ArXiv},
year = {2019},
volume = {abs/1911.02150}
}
```

Binary file removed images/collaborative-attention.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
  version = '0.25.9',
  version = '0.26.0',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
51 changes: 23 additions & 28 deletions x_transformers/x_transformers.py
@@ -316,8 +316,7 @@ def __init__(self, dim):
    def forward(self, max_seq_len, device):
        t = torch.arange(max_seq_len, device = device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return rearrange(emb, 'n d -> () () n d')
        return torch.cat((freqs, freqs), dim=-1)

def rotate_half(x):
    x = rearrange(x, '... (j d) -> ... j d', j = 2)
@@ -326,7 +325,7 @@ def rotate_half(x):

def apply_rotary_pos_emb(t, freqs):
    seq_len = t.shape[-2]
    freqs = freqs[:, :, -seq_len:]
    freqs = freqs[-seq_len:, :]
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())

# norms
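The rotary change above pairs with the single key / value head: returning the frequencies as a plain `(seq, dim)` tensor lets them broadcast against both the 4-dimensional queries and the now 3-dimensional keys. A quick broadcasting check with assumed toy shapes (the `rotate_half` here mirrors the library helper):

```python
import torch

def rotate_half(x):
    # split features into two halves and rotate them, as in the library helper
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(t, freqs):
    freqs = freqs[-t.shape[-2]:, :]  # trim to the sequence length of t
    return (t * freqs.cos()) + (rotate_half(t) * freqs.sin())

n, d = 16, 64
freqs = torch.randn(n, d)        # (seq, dim_head) - no batch or head dims
q = torch.randn(2, 8, n, d)      # (batch, heads, seq, dim_head)
k = torch.randn(2, n, d)         # (batch, seq, dim_head) - one kv head

print(apply_rotary(q, freqs).shape)  # torch.Size([2, 8, 16, 64])
print(apply_rotary(k, freqs).shape)  # torch.Size([2, 16, 64])
```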
@@ -511,8 +510,6 @@ def __init__(
        causal = False,
        talking_heads = False,
        head_scale = False,
        collab_heads = False,
        collab_compression = .3,
        sparse_topk = None,
        use_entmax15 = False,
        num_mem_kv = 0,
@@ -522,7 +519,8 @@ def __init__(
        zero_init_output = False,
        max_attend_past = None,
        qk_norm = False,
        scale_init_value = None
        scale_init_value = None,
        one_kv_head = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
@@ -531,24 +529,24 @@ def __init__(
        self.causal = causal
        self.max_attend_past = max_attend_past

        qk_dim = v_dim = dim_head * heads
        q_dim = k_dim = v_dim = out_dim = dim_head * heads

        # collaborative heads
        self.collab_heads = collab_heads
        if self.collab_heads:
            qk_dim = int(collab_compression * qk_dim)
            self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))
        self.one_kv_head = one_kv_head
        if one_kv_head:
            k_dim = v_dim = dim_head
            out_dim = v_dim * heads

        self.to_q = nn.Linear(dim, qk_dim, bias = False)
        self.to_k = nn.Linear(dim, qk_dim, bias = False)
        self.to_q = nn.Linear(dim, q_dim, bias = False)
        self.to_k = nn.Linear(dim, k_dim, bias = False)
        self.to_v = nn.Linear(dim, v_dim, bias = False)

        self.dropout = nn.Dropout(dropout)


        # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, v_dim)
            self.to_v_gate = nn.Linear(dim, out_dim)
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 1)

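This dimension bookkeeping is where the saving comes from: with `one_kv_head = True`, `to_k` and `to_v` project down to a single `dim_head` instead of `dim_head * heads`, so the key / value projections and, more importantly, the key / value cache kept around during decoding shrink by a factor of `heads`. A rough back-of-the-envelope check with assumed example sizes:

```python
dim, dim_head, heads, seq_len, batch = 512, 64, 8, 2048, 1

# key / value projection parameters (to_k + to_v, no bias)
kv_params_multi  = 2 * dim * dim_head * heads
kv_params_single = 2 * dim * dim_head
print(kv_params_multi, kv_params_single)  # 524288 vs 65536 -> 8x fewer

# cached key / value floats per layer during incremental decoding
cache_multi  = 2 * batch * heads * seq_len * dim_head
cache_single = 2 * batch * seq_len * dim_head
print(cache_multi, cache_single)          # 2097152 vs 262144 -> 8x smaller
```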
@@ -583,7 +581,7 @@ def __init__(

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)
        self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)

        # init output projection 0
        if zero_init_output:
@@ -602,7 +600,7 @@ def forward(
        prev_attn = None,
        mem = None
    ):
        b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(context)
        b, n, _, h, talking_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.head_scale, self.scale, x.device, exists(context)
        kv_input = default(context, x)

        q_input = x
@@ -623,12 +621,10 @@ def forward(
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        if not collab_heads:
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        else:
            q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
            k = rearrange(k, 'b n d -> b () n d')
            v = rearrange(v, 'b n (h d) -> b h n d', h = h)
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        if not self.one_kv_head:
            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (k, v))

        if exists(rotary_pos_emb) and not has_context:
            l = rotary_pos_emb.shape[-1]
@@ -652,14 +648,13 @@ def forward(
            if exists(input_mask):
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value = True)

        if collab_heads:
            k = k.expand(-1, h, -1, -1)

        if self.qk_norm:
            q, k = map(l2norm, (q, k))
            scale = 1 / (self.scale.exp().clamp(min = 1e-2))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
        kv_einsum_eq = 'b h j d' if not self.one_kv_head else 'b j d'

        dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
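The `kv_einsum_eq` switch is the whole trick at compute time: with a single key / value head, `k` and `v` stay 3-dimensional and the einsum broadcasts them across the query heads, so no expansion or repeat is needed. A standalone check of the two equations with assumed toy shapes:

```python
import torch
from torch import einsum

b, h, i, j, d = 2, 8, 16, 16, 64
q = torch.randn(b, h, i, d)

# standard multi-head attention: per-head keys
k_multi = torch.randn(b, h, j, d)
dots_multi = einsum('b h i d, b h j d -> b h i j', q, k_multi)

# one kv head: keys carry no head dimension and broadcast over h
k_single = torch.randn(b, j, d)
dots_single = einsum('b h i d, b j d -> b h i j', q, k_single)

print(dots_multi.shape, dots_single.shape)  # both torch.Size([2, 8, 16, 16])
```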
@@ -717,7 +712,7 @@ def forward(
        if talking_heads:
            attn = self.post_softmax_talking_heads(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        if head_scale:
            out = out * self.head_scale_params
