Unofficial PyTorch implementation of the AFT-Full layer from "An Attention Free Transformer" by Zhai et al. [abs, pdf] from Apple Inc.
You can install aft-pytorch via pip:
pip install aft-pytorch
You can import the "AFT-Full" layer (as described in the paper) from the package like so:
import torch
from aft_pytorch import AFTFullAttention

layer = AFTFullAttention(
    dim=512,
    hidden_dim=64,
    heads=8
)

# a batch of 32 sequences, each with 10 timesteps of dimension 512
x = torch.rand(32, 10, 512)
y = layer(x) # [32, 10, 512]
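For readers curious about what an AFT-Full layer computes, here is a minimal single-head sketch based on the formula in the paper: each output position is a sigmoid-gated weighted average of the values, with weights `exp(K[t'] + w[t, t'])`, where `w` is a learned pairwise position bias. This is not the source of this package. The class name `AFTFullSketch`, the `max_seqlen` argument, and the final projection back to `dim` are assumptions made for illustration.

```python
import torch
import torch.nn as nn


class AFTFullSketch(nn.Module):
    """Illustrative single-head AFT-Full layer (hypothetical, not the aft-pytorch source)."""

    def __init__(self, dim, hidden_dim, max_seqlen):
        super().__init__()
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        # Assumption: project back to the input width so shapes match a residual stream.
        self.project = nn.Linear(hidden_dim, dim)
        # Learned pairwise position biases w[t, t'].
        self.pos_bias = nn.Parameter(torch.zeros(max_seqlen, max_seqlen))

    def forward(self, x):
        _, T, _ = x.shape
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)  # each [B, T, H]
        w = self.pos_bias[:T, :T]                           # [T, T]

        # exp(K[t'] + w[t, t']) factorises as exp(w[t, t']) * exp(K[t']).
        # Subtracting maxima keeps the exponentials bounded; the shifts
        # cancel between numerator and denominator.
        exp_w = torch.exp(w - w.max(dim=-1, keepdim=True).values)  # [T, T]
        exp_k = torch.exp(k - k.max(dim=1, keepdim=True).values)   # [B, T, H]

        num = torch.einsum("ts,bsh->bth", exp_w, exp_k * v)  # weighted sum of values
        den = torch.einsum("ts,bsh->bth", exp_w, exp_k)      # normaliser
        y = torch.sigmoid(q) * num / den                     # gate with sigmoid(Q)
        return self.project(y)


sketch = AFTFullSketch(dim=512, hidden_dim=64, max_seqlen=10)
out = sketch(torch.rand(32, 10, 512))  # [32, 10, 512]
```

Note that no T x T attention matrix per feature is materialised here: the position bias is shared across features, so the sums reduce to two matrix products.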
The AFTFullAttention layer wrapper is plug-and-play with your existing networks / Transformers. You can swap out the Self-Attention layer for the AFTFullAttention layer with minimal changes.
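One way to make such a swap easy is to have the Transformer block accept its token-mixing layer as an argument. The sketch below is an assumption about how a host network might be organised, not part of this package: `Block` and `SelfAttentionMixer` are hypothetical names, and the only requirement on the mixer is that it maps `[batch, seq, dim]` to `[batch, seq, dim]`.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Pre-norm Transformer block with a swappable token-mixing layer (hypothetical)."""

    def __init__(self, dim, mixer, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.mixer = mixer  # any module mapping [B, T, dim] -> [B, T, dim]
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, x):
        x = x + self.mixer(self.norm1(x))   # token mixing with residual
        return x + self.mlp(self.norm2(x))  # feed-forward with residual


class SelfAttentionMixer(nn.Module):
    """Standard multi-head self-attention, wrapped to take a single input."""

    def __init__(self, dim, heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x, need_weights=False)
        return out


block = Block(dim=512, mixer=SelfAttentionMixer(dim=512, heads=8))
y = block(torch.rand(32, 10, 512))  # [32, 10, 512]
```

With the package installed, the same block should accept `mixer=AFTFullAttention(dim=512, hidden_dim=64, heads=8)` in place of `SelfAttentionMixer`, since the layer takes and returns tensors of shape `[32, 10, 512]` in the usage example above.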
- Add full AFT architecture
- Add variants like AFT-Simple, AFT-Conv, AFT-Local
If you like this repo, please leave a star! If you have any amendments or suggestions, feel free to raise a PR/issue.
@misc{zhai2021an,
title={An Attention Free Transformer},
author={Shuangfei Zhai and Walter Talbott and Nitish Srivastava and Chen Huang and Hanlin Goh and Joshua M. Susskind},
year={2021},
url={https://openreview.net/forum?id=pW--cu2FCHY}
}