[MUGEN] Add caching to MultiheadAttention #147
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##    gh/langong347/9/base     #147   +/- ##
=============================================
  Coverage           ?    92.56%
=============================================
  Files              ?        45
  Lines              ?      2769
  Branches           ?         0
=============================================
  Hits               ?      2563
  Misses             ?       206
  Partials           ?         0
```

Continue to review the full report at Codecov.
```python
@pytest.fixture
def full_attn(input_shape):
    # TODO: retire causal once mask generation is moved out of FullAttention
    # causal inside FullAttention does not affect caching of k, v
```
nit: move this TODO comment out of the test and into `FullAttention` for easier tracking?
```python
def test_split_multihead(self, input_shape, multihead_attn, full_attn):
    mha = multihead_attn(False, full_attn)
    x = torch.randn(1, *input_shape, 6)  # (b, d1, ..., dn, c)
    mha.n_head = 2
```
You could make `n_head` a parameter in `create_multihead_attn`, so you can remove the fixture and don't have to modify it here.
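For illustration, a minimal sketch of that factory-fixture pattern. The `_StubMHA` class below is a self-contained stand-in, not the real `MultiheadAttention`, and its constructor arguments are assumptions made purely to keep the snippet runnable:

```python
import pytest
from torch import nn


class _StubMHA(nn.Module):
    """Stand-in for MultiheadAttention, used only to illustrate the fixture."""

    def __init__(self, n_head: int, causal: bool):
        super().__init__()
        self.n_head = n_head
        self.causal = causal


@pytest.fixture
def create_multihead_attn():
    # n_head becomes a factory argument, so tests no longer patch the module.
    def _create(n_head: int, causal: bool) -> _StubMHA:
        return _StubMHA(n_head=n_head, causal=causal)

    return _create


def test_split_multihead(create_multihead_attn):
    mha = create_multihead_attn(n_head=2, causal=False)
    assert mha.n_head == 2
```

With the factory in place, the `mha.n_head = 2` assignment inside the test body is no longer needed.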
```python
self.cache["k"] = torch.cat([k_, k], dim=2)
self.cache["v"] = torch.cat([v_, v], dim=2)
# override the present k, v with the cache
k, v = self.cache["k"], self.cache["v"]
```
So if causal is False, will we continuously use the first keys and values?
In the non-causal case, `k` and `v` cover the entire sequence from the start. We should compute their projections just once (this logic is missing in MUGEN) and cache them to be reused later. In the causal case, `k` and `v` arrive step-wise, so the cache grows incrementally.
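A minimal sketch of the caching behavior described above (not the implementation in this PR), assuming `k` and `v` are already projected and split into heads with shape `(b, n_head, seq_len, dim_per_head)`:

```python
import torch
from torch import Tensor
from typing import Dict, Tuple


def update_cache(
    cache: Dict[str, Tensor], k: Tensor, v: Tensor, causal: bool
) -> Tuple[Tensor, Tensor]:
    if not cache:
        # First step: store the projected k, v once.
        cache["k"], cache["v"] = k.clone(), v.clone()
    elif causal:
        # Causal decoding: k, v arrive one step at a time, so grow the cache
        # along the sequence dimension.
        cache["k"] = torch.cat([cache["k"], k], dim=2)
        cache["v"] = torch.cat([cache["v"], v], dim=2)
    # Non-causal decoding: k, v already cover the full sequence, so the tensors
    # cached at the first step are reused as-is on later steps.
    return cache["k"], cache["v"]
```

The `dim=2` concatenation matches the snippet above, where the sequence dimension comes after batch and heads.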
```diff
@@ -224,17 +217,14 @@ def forward(
         q: Tensor,
         k: Tensor,
         v: Tensor,
-        decode_step: Optional[int] = None,
-        decode_idx: Optional[Iterable[int]] = None,
     ) -> Tensor:
         mask = torch.Tensor(self.mask) if self.causal else None
```
The mask is already created for you when you're decoding, but what about the case where a user specifies an attention mask during encoding? I am thinking of making `FullAttention` the base self-attention class for all models, but FLAVA and ALBEF use attention masking in their encoders. How do we reconcile this?
By "mask created while decoding", are you referring to the following?

multimodal/torchmultimodal/modules/layers/attention.py, lines 218 to 220 in e274207:

```python
if self.causal:
    seq_len = int(torch.prod(torch.tensor(shape)).item())
    self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len)))
```

And the slicing of the mask for decoding, lines 231 to 235 in e274207:

```python
if decode_step is not None and mask is not None:
    mask = mask[[decode_step]]
elif mask is not None and q.size(2) < mask.size(0):
    mask = mask[range(q.size(2)), :][:, range(q.size(2))]
```
I think the mask (either the full mask or the slice for decoding) should be created by a top-level module (a transformer encoder or decoder) and passed in to the attention modules, instead of being created inside the attention module, which is not only inflexible but also incurs the additional `shape` argument.

Is this consistent with your thoughts?
Yes, I agree with that. I think it's a better design and it's consistent with other models; FLAVA passes in the attention mask at the highest level, in `FLAVAModel`'s forward, for example. I can move the mask out of the attention modules in a separate PR for unifying `SelfAttention` and `MultiHeadAttention`.

Actually, if the mask were moved out, would it be created in the GPT decoder class? What would be your plan there?
Feel free to move the mask out of the attention module. I was thinking about doing that previously, as the transformer decoder layers will depend on it.
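For concreteness, a minimal sketch of that design, assuming a top-level decoder owns the mask. The helper names below are illustrative; the mask construction and slicing mirror the attention.py snippets quoted above:

```python
import torch
from torch import Tensor
from typing import Optional


def make_causal_mask(seq_len: int) -> Tensor:
    # Same lower-triangular construction as the buffer registered in attention.py.
    return torch.tril(torch.ones(seq_len, seq_len))


def mask_for_step(mask: Tensor, decode_step: Optional[int] = None) -> Tensor:
    # During step-wise decoding, only the row for the current step is needed.
    return mask[[decode_step]] if decode_step is not None else mask


# The decoder builds and slices the mask, then hands it to the attention module
# (hypothetical keyword argument shown in the comment below).
full_mask = make_causal_mask(seq_len=4)               # shape (4, 4)
step_mask = mask_for_step(full_mask, decode_step=2)   # shape (1, 4)
# out = attention(q, k, v, attention_mask=step_mask)
```

This keeps the attention module itself mask-agnostic, so encoder use cases such as FLAVA and ALBEF could pass their own attention masks through the same path.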
Summary: Add `use_cache` flag to `MultiheadAttention.forward`. When `use_cache` is `True`, the module will cache the key/value states at each step of decoding such that the query can attend to all key/value states up to the present step.

- Retire the usage of `decoding_step`, `decoding_index` from the MUGEN codebase in favor of the `use_cache` directive
- Concat the present state with the past state instead of keeping track of the decoding indices
- Port `test_attention.py` from `unittest` to `pytest` and rearrange the existing tests

Test Plan:

```
$ python -m pytest --cov=torchmultimodal/modules/layers/ test/modules/layers/test_attention.py -vv
================================================= test session starts ==================================================
platform darwin -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0 -- /Users/langong/local/miniconda3/envs/torchmm/bin/python
cachedir: .pytest_cache
rootdir: /Users/langong/gpt_attention, configfile: pyproject.toml
plugins: cov-3.0.0
collected 11 items

test/modules/layers/test_attention.py::TestMultiheadAttention::test_split_multihead PASSED                       [  9%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_combine_multihead PASSED                     [ 18%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention PASSED                  [ 27%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_use_cache PASSED        [ 36%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_causal_use_cache PASSED [ 45%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_error_causal_axial_attention PASSED          [ 54%]
test/modules/layers/test_attention.py::test_scaled_dot_product_attention PASSED                                  [ 63%]
test/modules/layers/test_attention.py::test_full_attention PASSED                                                [ 72%]
test/modules/layers/test_attention.py::test_axial_attention PASSED                                               [ 81%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_forward PASSED                           [ 90%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_channel_dim PASSED                       [100%]

---------- coverage: platform darwin, python 3.8.13-final-0 ----------
Name                                                    Stmts   Miss  Cover
---------------------------------------------------------------------------
torchmultimodal/modules/layers/__init__.py                  0      0   100%
torchmultimodal/modules/layers/attention.py               109      5    95%
torchmultimodal/modules/layers/codebook.py                 83     83     0%
torchmultimodal/modules/layers/conv.py                     74     74     0%
torchmultimodal/modules/layers/mlp.py                      23     23     0%
torchmultimodal/modules/layers/normalizations.py            8      8     0%
torchmultimodal/modules/layers/position_embedding.py       32     32     0%
torchmultimodal/modules/layers/transformer.py             130    130     0%
torchmultimodal/modules/layers/transformer_decoder.py      55     55     0%
---------------------------------------------------------------------------
TOTAL                                                     514    410    20%

==================== 11 passed in 1.71s ================
```
@langong347 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Pull Request resolved: #147

Add `use_cache` flag to `MultiheadAttention.forward`. When `use_cache` is `True`, the module will cache the key/value states at each step of decoding such that the query can attend to all key/value states up to the present step.

- Retire the usage of `decoding_step`, `decoding_index` from the MUGEN codebase in favor of the `use_cache` directive
- Concat the present state with the past state instead of keeping track of the decoding indices
- Port `test_attention.py` from `unittest` to `pytest` and rearrange the existing tests

Test Plan:

```
(torchmm) langong-mbp:gpt_attention langong$ python -m pytest --cov=torchmultimodal/modules/layers/ test/modules/layers/test_attention.py -vv -rP
============================================================ test session starts ============================================================
platform darwin -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0 -- /Users/langong/local/miniconda3/envs/torchmm/bin/python
cachedir: .pytest_cache
rootdir: /Users/langong/gpt_attention, configfile: pyproject.toml
plugins: mock-3.8.2, cov-3.0.0
collected 10 items

test/modules/layers/test_attention.py::TestMultiheadAttention::test_split_multihead PASSED                       [ 10%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_combine_multihead PASSED                     [ 20%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention PASSED                  [ 30%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_use_cache PASSED        [ 40%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_causal_use_cache PASSED [ 50%]
test/modules/layers/test_attention.py::test_scaled_dot_product_attention PASSED                                  [ 60%]
test/modules/layers/test_attention.py::test_full_attention PASSED                                                [ 70%]
test/modules/layers/test_attention.py::test_axial_attention PASSED                                               [ 80%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_forward PASSED                           [ 90%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_channel_dim PASSED                       [100%]

================================================================== PASSES ===================================================================

---------- coverage: platform darwin, python 3.8.13-final-0 ----------
Name                                                    Stmts   Miss  Cover
---------------------------------------------------------------------------
torchmultimodal/modules/layers/__init__.py                  0      0   100%
torchmultimodal/modules/layers/attention.py               110      6    95%
torchmultimodal/modules/layers/codebook.py                 83     83     0%
torchmultimodal/modules/layers/conv.py                     74     74     0%
torchmultimodal/modules/layers/mlp.py                      23     23     0%
torchmultimodal/modules/layers/normalizations.py            8      8     0%
torchmultimodal/modules/layers/position_embedding.py       32     32     0%
torchmultimodal/modules/layers/transformer.py             130    130     0%
torchmultimodal/modules/layers/transformer_decoder.py      55     55     0%
---------------------------------------------------------------------------
TOTAL                                                     515    411    20%

============================== 10 passed in 2.96s ===============================
```

Reviewed By: ebsmothers

Differential Revision: D37865352

Pulled By: langong347

fbshipit-source-id: 8dea47616cf5054db7cea9689b6c1e6348354e33
Summary:

Add `use_cache` flag to `MultiheadAttention.forward`. When `use_cache` is `True`, the module will cache the key/value states at each step of decoding such that the query can attend to all key/value states up to the present step.

- Retire the usage of `decoding_step`, `decoding_index` from the MUGEN codebase in favor of the `use_cache` directive
- Concat the present state with the past state instead of keeping track of the decoding indices
- Port `test_attention.py` from `unittest` to `pytest` and rearrange the existing tests

Test Plan:

```
python -m pytest --cov=torchmultimodal/modules/layers/ test/modules/layers/test_attention.py -vv -rP
```

Stack from ghstack (oldest at bottom):

Differential Revision: D37865352
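To illustrate the `use_cache` flag described above, a hedged sketch of a step-wise decoding loop. The positional `(q, k, v)` call and the `use_cache` keyword follow the diff and summary in this PR; the module construction, input shapes, and helper name are assumptions:

```python
from torch import Tensor
from typing import Iterable, List


def decode_with_cache(mha, steps: Iterable[Tensor]) -> List[Tensor]:
    """Run self-attention step by step, reusing cached key/value states.

    `mha` is assumed to be a MultiheadAttention module from
    torchmultimodal/modules/layers/attention.py; each element of `steps` is
    the input for a single decoding step.
    """
    outputs = []
    for x in steps:
        # With use_cache=True, the module concatenates the present k/v with
        # the cached past states, so the query attends to all steps so far.
        outputs.append(mha(x, x, x, use_cache=True))
    return outputs
```

Compared to the retired `decode_step`/`decode_idx` arguments, the caller only passes the flag; the module tracks past states internally by concatenation.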