[MUGEN] Add caching to MultiheadAttention #147

Closed · wants to merge 6 commits

Conversation


@langong347 langong347 commented Jul 13, 2022

Summary:
Add a `use_cache` flag to `MultiheadAttention.forward`. When `use_cache` is `True`, the module caches the key/value states at each step of decoding so that the query can attend to all key/value states up to the present step.

  • Retire `decoding_step` and `decoding_index` from the MUGEN codebase in favor of the `use_cache` directive
  • Concatenate the present key/value state with the past state instead of keeping track of the decoding indices (see the sketch after this list)
  • Port `test_attention.py` from `unittest` to `pytest` and rearrange the existing tests
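
For reference, a minimal standalone sketch of the caching behavior described above (not the actual `MultiheadAttention` code; a `(b, n_head, seq_len, head_dim)` layout is assumed):

```python
from typing import Dict, Tuple

import torch
from torch import Tensor


class KVCacheSketch:
    """Illustrative only: new key/value states are concatenated onto the
    cached ones along the sequence dimension, so the current query can
    attend to every position decoded so far."""

    def __init__(self) -> None:
        self.cache: Dict[str, Tensor] = {}

    def update(self, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        # k, v: (b, n_head, seq_len, head_dim) for the current decoding step(s)
        if not self.cache:
            self.cache = {"k": k, "v": v}
        else:
            self.cache["k"] = torch.cat([self.cache["k"], k], dim=2)
            self.cache["v"] = torch.cat([self.cache["v"], v], dim=2)
        # The cached tensors replace the per-step k, v in the attention call.
        return self.cache["k"], self.cache["v"]
```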

Test Plan:

$ (torchmm) langong-mbp:gpt_attention langong$ python -m pytest --cov=torchmultimodal/modules/layers/ test/modules/layers/test_attention.py  -vv -rP
============================================================ test session starts ============================================================
platform darwin -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0 -- /Users/langong/local/miniconda3/envs/torchmm/bin/python
cachedir: .pytest_cache
rootdir: /Users/langong/gpt_attention, configfile: pyproject.toml
plugins: mock-3.8.2, cov-3.0.0
collected 10 items

test/modules/layers/test_attention.py::TestMultiheadAttention::test_split_multihead PASSED                                            [ 10%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_combine_multihead PASSED                                          [ 20%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention PASSED                                       [ 30%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_use_cache PASSED                             [ 40%]
test/modules/layers/test_attention.py::TestMultiheadAttention::test_multi_head_attention_causal_use_cache PASSED                      [ 50%]
test/modules/layers/test_attention.py::test_scaled_dot_product_attention PASSED                                                       [ 60%]
test/modules/layers/test_attention.py::test_full_attention PASSED                                                                     [ 70%]
test/modules/layers/test_attention.py::test_axial_attention PASSED                                                                    [ 80%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_forward PASSED                                                [ 90%]
test/modules/layers/test_attention.py::TestAxialBlock::test_axial_block_channel_dim PASSED                                            [100%]

================================================================== PASSES ===================================================================

---------- coverage: platform darwin, python 3.8.13-final-0 ----------
Name                                                    Stmts   Miss  Cover
---------------------------------------------------------------------------
torchmultimodal/modules/layers/__init__.py                  0      0   100%
torchmultimodal/modules/layers/attention.py               110      6    95%
torchmultimodal/modules/layers/codebook.py                 83     83     0%
torchmultimodal/modules/layers/conv.py                     74     74     0%
torchmultimodal/modules/layers/mlp.py                      23     23     0%
torchmultimodal/modules/layers/normalizations.py            8      8     0%
torchmultimodal/modules/layers/position_embedding.py       32     32     0%
torchmultimodal/modules/layers/transformer.py             130    130     0%
torchmultimodal/modules/layers/transformer_decoder.py      55     55     0%
---------------------------------------------------------------------------
TOTAL                                                     515    411    20%

============================== 10 passed in 2.96s ===============================

Stack from ghstack (oldest at bottom):

Differential Revision: D37865352

@facebook-github-bot added the CLA Signed label Jul 13, 2022 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).

codecov-commenter commented Jul 13, 2022

Codecov Report

❗ No coverage uploaded for pull request base (gh/langong347/9/base@cf272c1).
The diff coverage is n/a.

@@                   Coverage Diff                   @@
##             gh/langong347/9/base     #147   +/-   ##
=======================================================
  Coverage                        ?   92.56%           
=======================================================
  Files                           ?       45           
  Lines                           ?     2769           
  Branches                        ?        0           
=======================================================
  Hits                            ?     2563           
  Misses                          ?      206           
  Partials                        ?        0           


@pytest.fixture
def full_attn(input_shape):
# TODO: retire causal once mask generation is moved out of FullAttention
# causal inside FullAttention does not affect caching of k, v
Contributor:

nit: move this TODO comment out of the test and into FullAttention for easier tracking?

def test_split_multihead(self, input_shape, multihead_attn, full_attn):
mha = multihead_attn(False, full_attn)
x = torch.randn(1, *input_shape, 6) # (b, d1, ..., dn, c)
mha.n_head = 2
Contributor:

You could make n_head a parameter in create_multihead_attn so you can remove the fixture and don't have to modify it here.
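
For illustration, a minimal sketch of that suggestion, using torch.nn.MultiheadAttention as a stand-in for the repo's module (the `create_multihead_attn` interface here is hypothetical, not the actual test code):

```python
import pytest
import torch
from torch import nn


@pytest.fixture
def create_multihead_attn():
    # Hypothetical factory fixture: n_head is an argument, so tests don't
    # need to patch the attribute on an already-constructed module.
    def create(n_head: int, embed_dim: int = 6) -> nn.MultiheadAttention:
        return nn.MultiheadAttention(embed_dim=embed_dim, num_heads=n_head, batch_first=True)

    return create


def test_two_heads(create_multihead_attn):
    mha = create_multihead_attn(n_head=2)
    x = torch.randn(1, 4, 6)  # (batch, seq_len, embed_dim)
    out, _ = mha(x, x, x)
    assert out.shape == x.shape
```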

self.cache["k"] = torch.cat([k_, k], dim=2)
self.cache["v"] = torch.cat([v_, v], dim=2)
# override the present k, v with the cache
k, v = self.cache["k"], self.cache["v"]
Contributor:

So if causal is False, will we continuously use the first keys and values?

Contributor Author:

In the non-causal case, k and v are the entire sequence from the start. We should compute their projections only once (this logic is missing in MUGEN) and cache them once to be reused later.

In the causal case, k and v arrive step by step, so the cache grows incrementally.
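
A small sketch of the non-causal branch described here (`project_kv_once`, `w_k`, and `w_v` are hypothetical names, not the PR's code; the causal branch instead appends per step, as in the concat sketch in the summary):

```python
from typing import Dict, Tuple

from torch import Tensor, nn


def project_kv_once(
    cache: Dict[str, Tensor], x: Tensor, w_k: nn.Linear, w_v: nn.Linear
) -> Tuple[Tensor, Tensor]:
    # Non-causal: x already covers the whole sequence, so project and cache
    # the key/value states a single time and reuse them at every later step.
    if not cache:
        cache["k"], cache["v"] = w_k(x), w_v(x)
    return cache["k"], cache["v"]
```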

@@ -224,17 +217,14 @@ def forward(
q: Tensor,
k: Tensor,
v: Tensor,
decode_step: Optional[int] = None,
decode_idx: Optional[Iterable[int]] = None,
) -> Tensor:
mask = torch.Tensor(self.mask) if self.causal else None
Contributor:

The mask is already created for you when you're decoding. What about the case where a user specifies an attention mask during encoding? I am thinking of making FullAttention the base self-attention class for all models, but in FLAVA and ALBEF attention masking is used in the encoders. How do we reconcile this?

@langong347 langong347 (Contributor Author) commented Jul 14, 2022

By "mask created while decoding" Are you referring to:

if self.causal:
    seq_len = int(torch.prod(torch.tensor(shape)).item())
    self.register_buffer("mask", torch.tril(torch.ones(seq_len, seq_len)))

And slicing the masks for decoding:

if decode_step is not None and mask is not None:
    mask = mask[[decode_step]]
elif mask is not None and q.size(2) < mask.size(0):
    mask = mask[range(q.size(2)), :][:, range(q.size(2))]

?

I think the mask (either the full mask or the slice used for decoding) should be created by a top-level module (either a transformer encoder or decoder) and passed in to the attention modules, instead of being created inside the attention module, which is not only inflexible but also incurs the additional `shape` argument.

Is this consistent with your thoughts?

@RdoubleA RdoubleA (Contributor) commented Jul 14, 2022

Yes, I agree with that; I think it's better design and it's consistent with other models. FLAVA, for example, passes in the attention mask at the highest level, in FLAVAModel's forward. I can move the mask out of the attention modules in a separate PR for unifying SelfAttention and MultiHeadAttention.

Actually, if the mask were moved out, would it be created in the GPT decoder class? What would be your plan there?

Contributor Author:

Feel free to move the mask out of the attention module. I was thinking about doing that previously, as the transformer decoder layers will depend on it.
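
A minimal sketch of that design, with hypothetical class names and an assumed `attention_mask` keyword (not the TorchMultimodal implementation): the top-level decoder owns the causal mask and passes it down, so the attention module no longer needs a `shape` argument to build one itself.

```python
import torch
from torch import Tensor, nn


def make_causal_mask(seq_len: int) -> Tensor:
    # Lower-triangular mask: position i may attend to positions <= i.
    return torch.tril(torch.ones(seq_len, seq_len))


class DecoderSketch(nn.Module):
    """Hypothetical top-level module: it owns the mask and passes it down,
    so the attention module stays mask-agnostic."""

    def __init__(self, attn: nn.Module) -> None:
        super().__init__()
        self.attn = attn

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        mask = make_causal_mask(q.size(-2))  # built at the top level
        return self.attn(q, k, v, attention_mask=mask)  # and passed in
```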


@langong347 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot facebook-github-bot deleted the gh/langong347/9/head branch July 19, 2022 14:15
sophiazhi pushed a commit that referenced this pull request Jul 19, 2022
Summary:
Pull Request resolved: #147


Reviewed By: ebsmothers

Differential Revision: D37865352

Pulled By: langong347

fbshipit-source-id: 8dea47616cf5054db7cea9689b6c1e6348354e33