This is an unofficial implementation of [Self-attention Does Not Need O(n^2) Memory](https://arxiv.org/abs/2112.05682) for JAX and PyTorch.
The implementation is almost the same as the one proposed in the paper, with added support for masking and bias, batch dimensions, and a PyTorch implementation. The proposed method computes attention using only O(sqrt(n)) memory, and the provided functions can be used as a drop-in replacement for standard attention.
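The core idea can be illustrated in plain NumPy. The block below is a minimal single-head sketch, not this library's code: it has no batch dimensions, mask, or bias, and `naive_attention` and `chunked_attention` are hypothetical names. It only ever materializes a `query_chunk_size` by `key_chunk_size` block of scores and merges key chunks with the max-rescaling trick that keeps the softmax numerically stable. The paper computes per-chunk summaries and then combines them, whereas this sketch folds key chunks in one at a time; both rely on the same rescaling identity.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference single-head attention that builds the full (n_q, n_k) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v


def chunked_attention(q, k, v, query_chunk_size=4, key_chunk_size=4):
    """Same result, but only a (query_chunk_size, key_chunk_size) score block exists at a time."""
    n_q, d = q.shape
    out = np.empty((n_q, v.shape[-1]))
    for qs in range(0, n_q, query_chunk_size):
        q_blk = q[qs:qs + query_chunk_size] / np.sqrt(d)
        rows = q_blk.shape[0]
        m = np.full((rows, 1), -np.inf)       # running max of scores per query row
        s = np.zeros((rows, 1))               # running softmax denominator
        acc = np.zeros((rows, v.shape[-1]))   # running softmax-weighted sum of values
        for ks in range(0, k.shape[0], key_chunk_size):
            scores = q_blk @ k[ks:ks + key_chunk_size].T
            m_new = np.maximum(m, scores.max(axis=-1, keepdims=True))
            correction = np.exp(m - m_new)    # rescales earlier chunks to the new max
            p = np.exp(scores - m_new)
            s = s * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v[ks:ks + key_chunk_size]
            m = m_new
        out[qs:qs + query_chunk_size] = acc / s
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k = rng.standard_normal((10, 8)), rng.standard_normal((12, 8))
    v = rng.standard_normal((12, 5))
    print(np.allclose(chunked_attention(q, k, v, 3, 5), naive_attention(q, k, v)))  # True
```

Chunk sizes that do not divide the sequence length are fine: the last chunk is simply shorter.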
Important Note: This implementation trades runtime for memory, so you should tune the `key_chunk_size` and `query_chunk_size` parameters to find the best configuration for your use case. Here is a note from the paper's authors:
> While a constant chunk size for the queries and a chunk size of sqrt(n) for the keys and values is optimal for memory consumption, the runtime is also affected by the choice of chunk size in practice, which is heavily affected by the choice of hardware. Ultimately, we have to leave this trade-off to the programmer, and expose the chunk sizes as arguments `query_chunk_size` and `key_chunk_size`. In Figure 1 we provide default values for the chunk sizes that lead to minimal runtime impact (on TPUv2), while still providing significant memory savings.
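To make the trade-off concrete, here is a rough back-of-the-envelope model. `chunk_cost` is a hypothetical helper, not part of the library; it assumes a single head with float32 scores and reports the size of one score block held at a time versus how many blocks must be processed. It ignores the query/key/value tensors, per-chunk summaries, and framework overhead.

```python
import math


def chunk_cost(n, query_chunk_size, key_chunk_size, bytes_per_elem=4):
    """Rough per-head cost model for chunked self-attention over n tokens.

    Returns (bytes of one score block, number of score blocks computed).
    Ignores q/k/v storage, per-chunk summaries, and framework overhead.
    """
    qc = min(query_chunk_size, n)   # chunks larger than the sequence are clipped
    kc = min(key_chunk_size, n)
    block_bytes = qc * kc * bytes_per_elem
    num_blocks = math.ceil(n / qc) * math.ceil(n / kc)
    return block_bytes, num_blocks


if __name__ == "__main__":
    n = 16384
    for qc, kc in [(n, n), (1024, 4096), (1024, math.isqrt(n))]:
        block_bytes, blocks = chunk_cost(n, qc, kc)
        print(f"query_chunk={qc:>5} key_chunk={kc:>5}: "
              f"{block_bytes / 2**20:7.1f} MiB per block, {blocks:>4} blocks")
```

For n = 16384, the full score matrix is a single 1024 MiB block; chunks of (1024, 4096) need 16 MiB per block over 64 blocks; and key chunks of sqrt(n) = 128 need only 0.5 MiB per block but 2048 blocks. Smaller chunks save memory at the cost of more, smaller steps, which is where the hardware-dependent runtime cost comes from.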
- Install the library
```bash
# for Jax
pip install memory-efficient-attention[jax]
# for PyTorch
pip install memory-efficient-attention[torch]
# for Running Tests
pip install memory-efficient-attention[testing]
```
- Compute attention with the function for your framework

```python
import time

import numpy as np
import torch

# for PyTorch
from memory_efficient_attention import efficient_dot_product_attention_pt
# or, for JAX (requires the [jax] extra):
# from memory_efficient_attention import efficient_dot_product_attention_jax

# Random data of shape (1, b, seq_len=128, num_heads=16, head_dim=8).
# The leading batch dimensions are optional.
b = 8
query = torch.tensor(np.random.rand(1, b, 128, 16, 8).astype("float32"))
key = torch.tensor(np.random.rand(1, b, 128, 16, 8).astype("float32"))
value = torch.tensor(np.random.rand(1, b, 128, 16, 8).astype("float32"))
# Optional mask and bias of shape (1, b, num_heads, q_len, k_len).
# A causal task would use a lower-triangular mask instead of this random one.
mask = torch.tensor(np.random.rand(1, b, 16, 128, 128) > 0.5)
bias = torch.tensor(np.random.rand(1, b, 16, 128, 128).astype("float32") / 100)


# Time chunked attention for one chunk-size configuration
def time_attention(repeats=10, key_chunk_size=128, query_chunk_size=128):
    total_time = 0.0
    for _ in range(repeats):
        start = time.perf_counter()
        efficient_dot_product_attention_pt(
            query, key, value, mask, bias,
            key_chunk_size=key_chunk_size, query_chunk_size=query_chunk_size,
        )
        total_time += time.perf_counter() - start
    print("attention took (mean seconds):", total_time / repeats)


time_attention(repeats=100, key_chunk_size=128, query_chunk_size=32)
```
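For checking results, or for building a real causal mask in place of the random one in the usage example, a plain NumPy reference in the same layout may help. This is a sketch under assumptions: `causal_mask` and `reference_attention` are hypothetical helpers, `True` in the mask is assumed to mean "may attend", and scores are assumed to be scaled by `1/sqrt(head_dim)` as in the paper's reference code. Verify these conventions against the library before relying on a comparison.

```python
import numpy as np


def causal_mask(batch_shape, num_heads, seq_len):
    """Boolean mask of shape (*batch_shape, num_heads, seq_len, seq_len).

    Assumption: True means "may attend", so query i sees keys 0..i.
    """
    tri = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # .copy() turns the read-only broadcast view into a normal contiguous array
    return np.broadcast_to(tri, (*batch_shape, num_heads, seq_len, seq_len)).copy()


def reference_attention(query, key, value, mask=None, bias=None):
    """Full O(n^2)-memory attention in the README's layout.

    query/key/value: (..., seq_len, num_heads, head_dim)
    mask/bias:       (..., num_heads, q_len, k_len)
    """
    scores = np.einsum("...qhd,...khd->...hqk", query, key) / query.shape[-1] ** 0.5
    if bias is not None:
        scores = scores + bias
    if mask is not None:
        # masked-out positions get the most negative finite value, so exp() gives 0
        scores = np.where(mask, scores, np.finfo(scores.dtype).min)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("...hqk,...khd->...qhd", weights, value)
```

For example, `torch.tensor(causal_mask((1, b), 16, 128))` has the same shape as the random mask in the usage example, and `reference_attention(query.numpy(), key.numpy(), value.numpy(), mask.numpy(), bias.numpy())` could then be compared with the library's output using `np.allclose` and a small tolerance.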
## Citation
Please cite if this implementation helps your research. You can use the following BibTeX entry:
```bibtex
@misc{memory_efficient_attention,
title = {Memory Efficient Attention},
author = {Rezaei, Amin},
howpublished = {\url{github.com/AminRezaei0x443/memory-efficient-attention}},
year = {2021}
}
```

Also, for the paper:

```bibtex
@misc{rabe2021selfattention,
title={Self-attention Does Not Need $O(n^2)$ Memory},
author={Markus N. Rabe and Charles Staats},
year={2021},
eprint={2112.05682},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```