add enhanced recurrence, from Ernie-doc paper
lucidrains committed Oct 3, 2021
1 parent f43d7db commit 989c2ca
Showing 4 changed files with 50 additions and 35 deletions.
78 changes: 45 additions & 33 deletions README.md
@@ -641,6 +641,40 @@ logits2, mems2 = model_xl(seg2, mems = mems1, return_mems = True)
logits3, mems3 = model_xl(seg3, mems = mems2, return_mems = True)
```

### Enhanced recurrence

<img src="./images/enhanced-recurrence.png" width="400px"/>

<a href="https://arxiv.org/abs/2012.15688">This paper</a> proposes a simple technique for enhancing the range of Transformer-XL: the memory of each layer is routed to the layer below it on the next recurrent step. Enable this by setting `shift_mem_down = 1`; values greater than 1 shift the memories down by that many layers. The example below, followed by a short sketch, shows the effect on the returned memories.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model_xl = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 512,
max_mem_len = 2048,
shift_mem_down = 1,
attn_layers = Decoder(
dim = 512,
depth = 6,
heads = 8,
rotary_pos_emb = True
)
)

seg1 = torch.randint(0, 20000, (1, 512))
seg2 = torch.randint(0, 20000, (1, 512))
seg3 = torch.randint(0, 20000, (1, 512))

logits1, mems1 = model_xl(seg1, return_mems = True)
logits2, mems2 = model_xl(seg2, mems = mems1, return_mems = True)
logits3, mems3 = model_xl(seg3, mems = mems2, return_mems = True)

len(mems1), len(mems2), len(mems3) # (5, 5, 5) instead of (6, 6, 6)
```
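The memory lists have length 5 rather than 6 because the bottom `shift_mem_down` memories are dropped before being returned, which is what routes each remaining memory to the layer below it on the next segment. A minimal, illustrative sketch of that bookkeeping (not the library's internals):

```python
# Illustrative sketch only -- how the returned memory list shrinks and how the
# layers consume it on the next step. Not the library's internal code.

num_layers = 6
shift_mem_down = 1

# pretend these are the cached hidden states of each attention layer after a segment
mems = [f'layer_{i}_memory' for i in range(num_layers)]

# dropping the bottom `shift_mem_down` entries before returning the memories
# means layer i will pop the memory produced by layer i + shift_mem_down,
# while the topmost layer(s) receive no memory at all
shifted = mems[shift_mem_down:]  # 5 entries instead of 6

for layer_index in range(num_layers):
    layer_mem = shifted[layer_index] if layer_index < len(shifted) else None
    print(layer_index, layer_mem)

# 0 layer_1_memory
# 1 layer_2_memory
# 2 layer_3_memory
# 3 layer_4_memory
# 4 layer_5_memory
# 5 None
```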

### Gated residual

<img src="./images/gating.png" width="500px"></img>
Expand Down Expand Up @@ -759,29 +793,6 @@ model = TransformerWrapper(
)
```

## Todo

To be explained and documented

- [x] ~~memory key / values - All-attention paper~~
- [x] ~~memory tokens - Memory Transformers~~
- [x] ~~scale normalization - Transformers Without Tears~~
- [x] ~~feedforward gated linear variant - Noam's GLU Variants~~
- [x] ~~rezero - Rezero is all you need~~
- [x] ~~topk attention - Explicit Sparse Attention~~
- [x] ~~entmax15 instead of softmax - Adaptively Sparse Transformers~~
- [x] ~~mixing head information - Noam's Talking Heads~~
- [x] ~~gating multi-head attention output - Attention on Attention~~
- [x] ~~simplified relative positional encoding bias - T5~~
- [x] ~~sandwich transformer - Reordering Sublayers~~
- [x] ~~wrapper for processing images - Vision Transformer~~
- [x] ~~macaron layers - 'Multi-particle Dynamic System' paper~~
- [x] ~~residual attention - Realformer paper~~
- [x] ~~position infused attention - Shortformer paper~~
- [x] ~~recurrence - Transformer-XL~~
- [x] ~~gated transformer-xl - Stabilizing Transformers for RL~~
- [ ] reversibility - Reformer

## Miscellaneous

Cross Attention
@@ -840,16 +851,6 @@ model(x, mask = mask) # (1, 1024, 100)
}
```

```bibtex
@inproceedings{kitaev2020reformer,
title = {Reformer: The Efficient Transformer},
author = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
booktitle = {International Conference on Learning Representations},
year = {2020},
url = {https://openreview.net/forum?id=rkgNKkHtvB}
}
```

```bibtex
@article{DBLP:journals/corr/abs-1907-01470,
author = {Sainbayar Sukhbaatar and
@@ -1167,4 +1168,15 @@ model(x, mask = mask) # (1, 1024, 100)
}
```

```bibtex
@misc{ding2021erniedoc,
title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
year = {2021},
eprint = {2012.15688},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
Binary file added images/enhanced-recurrence.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '0.18.0',
version = '0.19.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
5 changes: 4 additions & 1 deletion x_transformers/x_transformers.py
@@ -763,7 +763,7 @@ def forward(

if layer_type == 'a':
hiddens.append(x)
layer_mem = mems.pop(0)
layer_mem = mems.pop(0) if mems else None

residual = x

@@ -876,6 +876,7 @@ def __init__(
attn_layers,
emb_dim = None,
max_mem_len = 0.,
shift_mem_down = 0,
emb_dropout = 0.,
num_memory_tokens = None,
tie_embedding = False,
@@ -889,6 +890,7 @@

self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.shift_mem_down = shift_mem_down

self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
@@ -947,6 +949,7 @@ def forward(
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
new_mems = new_mems[self.shift_mem_down:]
return out, new_mems

if return_attn:
