add query-key normalization attention
lucidrains committed Dec 2, 2021
1 parent a9de3a8 commit 3228fff
Showing 4 changed files with 105 additions and 18 deletions.
57 changes: 57 additions & 0 deletions README.md
@@ -810,6 +810,9 @@ model = TransformerWrapper(
sandwich_norm = True # set this to True
)
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```

### Normformer
@@ -883,6 +886,38 @@ model(x)

The last change is a layernorm right after the outward projection in attention. This is actually identical to the sandwich norm proposed by the CogView paper, so you can get it by simply setting `sandwich_norm = True`, although this also adds the extra norm to the feedforward layer.
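
For reference, here is a minimal sketch of the sandwich-norm pattern (illustrative class name, not the library's internals): the branch is normalized on both its input and its output before the residual add.

```python
import torch.nn as nn

# minimal sketch of sandwich norm; `block` would be an attention or feedforward module
class SandwichNormBlock(nn.Module):
    def __init__(self, dim, block):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)   # usual pre-layernorm
        self.post_norm = nn.LayerNorm(dim)  # extra norm on the branch output
        self.block = block

    def forward(self, x):
        out = self.block(self.pre_norm(x))
        return self.post_norm(out) + x      # normalize the branch output, then add the residual
```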

## Query-Key Normalization

<img src="./images/cosine-sim-attention.png" width="400px"></img>

This <a href="https://arxiv.org/abs/2010.04245">paper</a> proposes to L2-normalize the queries and keys along the head dimension before the dot product (cosine-similarity attention), with the additional change that the scale is learned rather than fixed. The normalization prevents the attention logits from overflowing, a perennial problem when training transformers.
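
To make the mechanism concrete, here is a minimal single-head sketch of cosine-similarity attention (illustrative function, not the library's API), assuming `q`, `k`, `v` have shape `(batch, seq, dim_head)`:

```python
import torch
import torch.nn.functional as F

def cosine_sim_attention(q, k, v, scale):
    # L2-normalize queries and keys along the head dimension,
    # so each dot product is a cosine similarity bounded in [-1, 1]
    q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)

    # `scale` is a learned temperature rather than the usual static dim_head ** -0.5
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale

    attn = sim.softmax(dim = -1)
    return torch.einsum('b i j, b j d -> b i d', attn, v)
```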

This was recently validated at scale with the training of <a href="https://arxiv.org/abs/2111.09883">a 3B-parameter vision transformer</a>. The SwinV2 paper also proposes changing the pre-layernorm to a post-layernorm for further stability.

I have validated that this works just as well as dot-product attention in an autoregressive setting, provided the temperature is initialized as proposed in the QK-norm paper (as a function of the sequence length).
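
Concretely, this commit initializes the learned scale parameter to `-log(log2(n^2 - n))` for sequence length `n`, so the effective multiplier `1 / exp(scale)` starts out at roughly `log2(n^2 - n)`. A small sanity check:

```python
import math

# temperature initialization as a function of sequence length,
# mirroring the formula used in this commit
def qk_norm_scale_init(seq_len):
    return -math.log(math.log2(seq_len ** 2 - seq_len))

print(qk_norm_scale_init(1024))  # ≈ -3.0, matching the default of -3 used when no seq len is given
```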

You can use it as follows:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 6,
heads = 8,
use_qk_norm_attn = True, # set this to True
qk_norm_attn_seq_len = 1024 # set this to max_seq_len from above
)
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```

## Miscellaneous

Cross Attention
@@ -1291,4 +1326,26 @@ model(x, mask = mask) # (1, 1024, 100)
}
```

```bibtex
@misc{henry2020querykey,
title = {Query-Key Normalization for Transformers},
author = {Alex Henry and Prudhvi Raj Dachapally and Shubham Pawar and Yuxuan Chen},
year = {2020},
eprint = {2010.04245},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```

```bibtex
@misc{liu2021swin,
title = {Swin Transformer V2: Scaling Up Capacity and Resolution},
author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year = {2021},
eprint = {2111.09883},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
Binary file added images/cosine-sim-attention.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '0.20.9',
version = '0.21.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
64 changes: 47 additions & 17 deletions x_transformers/x_transformers.py
@@ -61,6 +61,9 @@ def __call__(self, x, *args, **kwargs):
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max

def l2norm(t):
return F.normalize(t, p = 2, dim = -1)

# init helpers

def init_zero_(layer):
@@ -425,10 +428,13 @@ def __init__(
on_attn = False,
gate_values = False,
zero_init_output = False,
max_attend_past = None
max_attend_past = None,
qk_norm = False,
scale_init_value = None
):
super().__init__()
self.scale = dim_head ** -0.5

self.heads = heads
self.causal = causal
self.max_attend_past = max_attend_past
@@ -454,6 +460,12 @@ def __init__(
nn.init.constant_(self.to_v_gate.weight, 0)
nn.init.constant_(self.to_v_gate.bias, 1)

# cosine sim attention
self.qk_norm = qk_norm
if qk_norm:
scale_init_value = default(scale_init_value, -3) # if not provided, initialize as though it were sequence length of 1024
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)

# talking heads
self.talking_heads = talking_heads
if talking_heads:
@@ -498,7 +510,7 @@ def forward(
prev_attn = None,
mem = None
):
b, n, _, h, talking_heads, collab_heads, head_scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, x.device, exists(context)
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(context)
kv_input = default(context, x)

q_input = x
@@ -551,7 +563,11 @@ def forward(
if collab_heads:
k = k.expand(-1, h, -1, -1)

dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
if self.qk_norm:
q, k = map(l2norm, (q, k))
scale = 1 / (self.scale.exp().clamp(min = 1e-2))

dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
mask_value = max_neg_value(dots)

if exists(prev_attn):
@@ -659,6 +675,8 @@ def __init__(
scale_residual = False,
shift_tokens = 0,
sandwich_norm = False,
use_qk_norm_attn = False,
qk_norm_attn_seq_len = None,
zero_init_branch_output = False,
**kwargs
):
@@ -717,6 +735,12 @@ def __init__(
if macaron:
default_block = ('f',) + default_block

# qk normalization

if use_qk_norm_attn:
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}

# zero init

if zero_init_branch_output:
@@ -753,7 +777,9 @@ def __init__(

# iterate and construct layers

for layer_type, layer_shift_tokens in zip(self.layer_types, shift_tokens):
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
is_last_layer = ind == (len(self.layer_types) - 1)

if layer_type == 'a':
layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
elif layer_type == 'c':
@@ -775,13 +801,18 @@ def __init__(
residual_fn = GRUGating if gate_residual else Residual
residual = residual_fn(dim, scale_residual = scale_residual)

if sandwich_norm:
norm = nn.ModuleList([norm_fn(), norm_fn()])
else:
norm = norm_fn()
pre_branch_norm = norm_fn() if sandwich_norm and not use_qk_norm_attn else None
post_branch_norm = norm_fn() if sandwich_norm or use_qk_norm_attn else None
post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None

norms = nn.ModuleList([
pre_branch_norm,
post_branch_norm,
post_main_norm
])

self.layers.append(nn.ModuleList([
norm,
norms,
layer,
residual
]))
@@ -819,11 +850,10 @@ def forward(

residual = x

if self.sandwich_norm:
norm, postnorm = norm
pre_branch_norm, post_branch_norm, post_main_norm = norm

if self.pre_norm:
x = norm(x)
if exists(pre_branch_norm):
x = pre_branch_norm(x)

if layer_type == 'a':
out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
@@ -832,8 +862,8 @@ def forward(
elif layer_type == 'f':
out = block(x)

if self.sandwich_norm:
out = postnorm(out)
if exists(post_branch_norm):
out = post_branch_norm(out)

x = residual_fn(out, residual)

@@ -845,8 +875,8 @@ def forward(
elif layer_type == 'c' and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn

if not self.pre_norm and not is_last:
x = norm(x)
if exists(post_main_norm):
x = post_main_norm(x)

if return_hiddens:
intermediates = LayerIntermediates(