Spatiotemporal position embedding; tensor slicing utility (facebookresearch#69)

Summary:
Pull Request resolved: facebookresearch#69

- Spatiotemporal position embedding is the first layer of the multimodal GPT attention stack
- Refactor [`AddBroadcastEmbedding`](https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/gpt/attention.py#L256) and the associated utility [`tensor_slice`](https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/gpt/utils.py#L75); see the usage sketch below
- Use `pytest` for parametrized tests, test setup, etc., as the starting point for migrating our testing framework away from Python's `unittest`.
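
A minimal usage sketch (illustrative only, not part of the diff) of the two APIs added here, based on the shapes exercised in the new tests:

```
# Illustrative sketch: mirrors the cases covered by the new unit tests.
import torch

from torchmultimodal.modules.layers.position_embedding import BroadcastedPositionEmbedding
from torchmultimodal.utils.common import tensor_slice

# Embeddings broadcast over a 2-dim grid of shape (1, 2); embedding_dim is
# split evenly across the two dims (3 + 3 = 6).
pos_emb = BroadcastedPositionEmbedding(shape=(1, 2), embedding_dim=6)
print(pos_emb().shape)  # torch.Size([1, 2, 6]) -> (batch, flattened grid, embedding_dim)

# Dimension-wise slicing: start at [0, 1, 0] and take sizes [1, 1, 2] along each dim.
x = torch.tensor([[[0, 1], [2, 3], [5, 6]]])
print(tensor_slice(x, [0, 1, 0], [1, 1, 2]))  # tensor([[[2, 3]]])
```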

Test Plan:
```
(torchmm) langong-mbp:gpt_attention langong$ python -m pytest --cov=torchmultimodal/modules/layers test/modules/layers/test_position_embedding.py
================================================================================ test session starts =================================================================================
platform darwin -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0
rootdir: /Users/langong/gpt_attention
plugins: cov-3.0.0
collected 7 items

test/modules/layers/test_position_embedding.py .......                                                                                                                         [100%]

---------- coverage: platform darwin, python 3.8.13-final-0 ----------
Name                                                   Stmts   Miss  Cover
--------------------------------------------------------------------------
torchmultimodal/modules/layers/attention.py               97     97     0%
torchmultimodal/modules/layers/codebook.py                81     81     0%
torchmultimodal/modules/layers/conv.py                    74     74     0%
torchmultimodal/modules/layers/mlp.py                     22     22     0%
torchmultimodal/modules/layers/normalizations.py           7      7     0%
torchmultimodal/modules/layers/position_embedding.py      38      0   100%
torchmultimodal/modules/layers/transformer.py            133    133     0%
--------------------------------------------------------------------------
TOTAL                                                    452    414     8%

====================== 7 passed in 2.38s ==================
```

```
(torchmm) langong-mbp:gpt_attention langong$ python -m pytest --cov=torchmultimodal/utils/ test/utils/test_common.py -vv
================================================================================ test session starts =================================================================================
platform darwin -- Python 3.8.13, pytest-7.1.2, pluggy-1.0.0 -- /Users/langong/local/miniconda3/envs/torchmm/bin/python
cachedir: .pytest_cache
rootdir: /Users/langong/gpt_attention
plugins: cov-3.0.0
collected 6 items

test/utils/test_common.py::test_shift_dim PASSED                                                                                                                               [ 16%]
test/utils/test_common.py::TestTensorSlice::test_default PASSED                                                                                                                [ 33%]
test/utils/test_common.py::TestTensorSlice::test_size_minus_one PASSED                                                                                                         [ 50%]
test/utils/test_common.py::TestTensorSlice::test_uneven_begin_size PASSED                                                                                                      [ 66%]
test/utils/test_common.py::TestTensorSlice::test_invalid_begin XFAIL (Invalid begin)                                                                                           [ 83%]
test/utils/test_common.py::TestTensorSlice::test_invalid_size XFAIL (Invalid size)                                                                                             [100%]

---------- coverage: platform darwin, python 3.8.13-final-0 ----------
Name                                Stmts   Miss  Cover
-------------------------------------------------------
torchmultimodal/utils/__init__.py       0      0   100%
torchmultimodal/utils/common.py        66     21    68%
torchmultimodal/utils/file_io.py       10      3    70%
-------------------------------------------------------
TOTAL                                  76     24    68%

===================== 4 passed, 2 xfailed in 1.41s ============================
```

Reviewed By: RdoubleA

Differential Revision: D37100730

Pulled By: langong347

fbshipit-source-id: 1b1d99ff924fe88078e4d7563fcf52d334185dca
langong347 authored and facebook-github-bot committed Jun 13, 2022
1 parent 1d48969 commit a57f654
Showing 4 changed files with 280 additions and 2 deletions.
79 changes: 79 additions & 0 deletions test/modules/layers/test_position_embedding.py
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from test.test_utils import assert_expected
from torch import nn
from torchmultimodal.modules.layers.position_embedding import (
    BroadcastedPositionEmbedding,
)


class TestBroadcastedPositionEmbedding:
    @pytest.fixture(scope="class")
    def pos_emb(self):
        return BroadcastedPositionEmbedding(
            shape=(1, 2),
            embedding_dim=6,
        )

    def test_init_sets_embedding(self, pos_emb):
        """Test the embeddings are initialized with the correct dimensions"""
        expected = [(1, 3), (2, 3)]
        for i, (key, _) in enumerate(pos_emb.embedding.items()):
            assert_expected(pos_emb.embedding[key].shape, expected[i])

    def test_init_bad_embedding_dim(self):
        """Test raising error when the embedding dim is not allowed"""
        with pytest.raises(ValueError):
            BroadcastedPositionEmbedding(shape=(1, 2), embedding_dim=5)

    def test_seq_len(self, pos_emb):
        assert_expected(pos_emb.seq_len, 2)

    def test_broadcast(self, pos_emb):
        """Test embedding along each dim is broadcasted correctly"""
        embedding = [
            torch.tensor([[0.0, 1.0, 2.0]]),
            torch.tensor([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]),
        ]
        expected = [
            torch.tensor([[[[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]]]),
            torch.tensor([[[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]]),
        ]
        for i, emb in enumerate(embedding):
            pos_emb.embedding[f"d_{i}"] = nn.Parameter(emb)
            assert_expected(pos_emb._broadcast(i), expected[i])

    def test_decode(self, pos_emb):
        """Test the embedding at a previous location is selected for each decode step"""
        x_shape = (1, 2, 6)
        broadcasted_embedding = torch.tensor(
            [[[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]]]
        )
        expected = [
            torch.tensor([[[[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]]]]),
            torch.tensor([[[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]]]),
        ]

        for decode_step, _ in enumerate(pos_emb.decode_idxs):
            actual = pos_emb._decode(decode_step, broadcasted_embedding, x_shape)
            assert_expected(actual, expected[decode_step])

    def test_forward(self, pos_emb):
        expected = (1, 2, 6)
        assert_expected(pos_emb().shape, expected)

    def test_forward_decode(self, pos_emb):
        """Test the decode statement inside ``forward`` is hit when ``decode_step`` is given"""
        x = torch.zeros(1, *(pos_emb.shape), pos_emb.embedding_dim).flatten(
            start_dim=1, end_dim=-2
        )
        actual = pos_emb(x, decode_step=0).shape
        expected = (1, 1, 6)
        assert_expected(actual, expected)
39 changes: 38 additions & 1 deletion test/utils/test_common.py
@@ -4,10 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from test.test_utils import assert_expected

from torchmultimodal.utils.common import shift_dim
from torchmultimodal.utils.common import shift_dim, tensor_slice


def test_shift_dim():
@@ -19,3 +21,38 @@ def test_shift_dim():
    actual = shift_dim(test_random_tensor, -3, 3)
    expected = test_random_tensor.permute(0, 1, 3, 2, 4).contiguous()
    assert_expected(actual, expected)


class TestTensorSlice:
    @pytest.fixture(scope="class")
    def test_input(self):
        return torch.tensor([[[0, 1], [2, 3], [5, 6]]])

    def test_default(self, test_input):
        actual = tensor_slice(test_input, [0, 1, 0], [1, 1, 2])
        expected = torch.tensor([[[2, 3]]])
        assert_expected(actual, expected)

    def test_size_minus_one(self, test_input):
        """Test size -1"""
        actual = tensor_slice(test_input, [0, 1, 0], [1, -1, 2])
        expected = torch.tensor([[[2, 3], [5, 6]]])
        assert_expected(actual, expected)

    def test_uneven_begin_size(self, test_input):
        """Test uneven begin and size vectors"""
        actual = tensor_slice(test_input, [0, 1, 0], [1, 1])
        expected = torch.tensor([[[2, 3]]])
        assert_expected(actual, expected)

        actual = tensor_slice(test_input, [0, 1], [1, 1, 2])
        expected = torch.tensor([[[2, 3]]])
        assert_expected(actual, expected)

    @pytest.mark.xfail(raises=ValueError, reason="Invalid begin")
    def test_invalid_begin(self, test_input):
        tensor_slice(test_input, [-1, 1, 0], [1, 1, 2])

    @pytest.mark.xfail(raises=ValueError, reason="Invalid size")
    def test_invalid_size(self, test_input):
        tensor_slice(test_input, [0, 1, 0], [-2, 1, 2])
132 changes: 132 additions & 0 deletions torchmultimodal/modules/layers/position_embedding.py
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import List, Optional, Tuple

import torch
from torch import nn, Tensor
from torchmultimodal.utils.common import tensor_slice


# Reference:
# https://github.com/wilson1yan/VideoGPT/blob/c21cc7e2579f820cb2b90097406d72cf69a46474/videogpt/attention.py#L458
class BroadcastedPositionEmbedding(nn.Module):
    r"""Spatiotemporal broadcasted positional embeddings.
    Each embedding vector of the ``i``-th dim is repeated by ``N`` times, where
    :math:`N = \prod_{j>i}\text{dim}[j]`.
    Args:
        shape (Tuple[int, ...]): shape of raw data before batching and embedding
        embedding_dim (int): the size of each embedding vector
    Raises:
        ValueError: if ``embedding_dim`` is not an integer multiple of ``len(shape)``
    Inputs:
        x (Optional[Tensor]): flattened input data, e.g., ``(batch, time * height * width, embedding_dim)``.
        decode_step (Optional[int]): position of the data that requires decoding.
    """

    def __init__(
        self,
        shape: Tuple[int, ...],
        embedding_dim: int,
    ) -> None:
        super().__init__()
        if embedding_dim % len(shape) != 0:
            raise ValueError(
                f"Embedding dim {embedding_dim} modulo len(shape) {len(shape)} is not zero"
            )

        self.shape = shape
        self.n_dim = n_dim = len(shape)
        self.embedding_dim = embedding_dim

        self.embedding = nn.ParameterDict(
            {
                f"d_{i}": nn.Parameter(
                    torch.randn(shape[i], embedding_dim // n_dim) * 0.01
                )
                for i in range(n_dim)
            }
        )

    @property
    def seq_len(self) -> int:
        """Dimension of flattened data, e.g., time * height * width"""
        return int(torch.prod(torch.tensor(self.shape)).item())

    @property
    def decode_idxs(self) -> List:
        """Indices along the dims of data, e.g., ``(time, height, width)``."""
        return list(itertools.product(*[range(s) for s in self.shape]))

    def _broadcast(self, i: int) -> Tensor:
        """Broadcasts the ``i``-th embedding matrix ``(self.shape[i], self.embedding_dim // n_dim)`` along the other
        dims of ``self.shape``. The embedding dim is not touched.
        For example::
            >>> pos_emb = BroadcastedPositionEmbedding(shape=(2, 4), embedding_dim=6)
            >>> print(pos_emb.embedding["d_0"].shape)
            torch.Size([2, 3])
            >>> pos_emb.embedding["d_0"] = nn.Parameter(torch.tensor([[0., 0., 0.], [0., 0., 1.]]))
            >>> out = pos_emb._broadcast(i=0)
            >>> print(out)
            tensor([[[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                     [[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]]]])
            >>> print(out.shape)
            (1, 2, 4, 3)
        The input is broadcasted along the second dim ``4`` since it's the ``0``-th embedding constructed w.r.t the
        first dim ``2``.
        """
        emb = self.embedding[f"d_{i}"]
        # (1, 1, ..., 1, self.shape[i], 1, ..., -1)
        emb = emb.view(
            1,
            *itertools.repeat(1, i),
            self.shape[i],
            *itertools.repeat(1, (self.n_dim - i - 1)),
            -1,
        )
        # (1, *self.shape, -1)
        emb = emb.expand(1, *self.shape, -1)

        return emb

    def _decode(
        self, decode_step: int, embeddings: Tensor, x_shape: Tuple[int, ...]
    ) -> Tensor:
        """Returns the embedding vector immediately before the decoding location."""
        decode_idx = self.decode_idxs[decode_step - 1]
        embeddings = tensor_slice(
            embeddings,
            [0, *decode_idx, 0],
            [x_shape[0], *itertools.repeat(1, self.n_dim), x_shape[-1]],
        )

        return embeddings

    def forward(
        self, x: Optional[Tensor] = None, decode_step: Optional[int] = None
    ) -> Tensor:
        embeddings = []
        for i in range(self.n_dim):
            emb = self._broadcast(i)
            embeddings.append(emb)

        embeddings = torch.cat(
            embeddings, dim=-1
        )  # concatenated embeddings: (1, *(shape), embedding_dim)

        if decode_step is not None:
            embeddings = self._decode(decode_step, embeddings, tuple(x.shape))
            # decoded embedding: (1, *repeat(1, len(shape)), embedding_dim)

        return embeddings.flatten(start_dim=1, end_dim=-2)
32 changes: 31 additions & 1 deletion torchmultimodal/utils/common.py
@@ -8,7 +8,7 @@
import os
from collections import OrderedDict
from dataclasses import fields
from typing import Optional
from typing import List, Optional

import torch
from torch import Tensor
@@ -65,6 +65,36 @@ def shift_dim(
    return x


def tensor_slice(x: Tensor, begin: List[int], size: List[int]) -> Tensor:
    """Slices a tensor dimension-wise.
    The input tensor is sliced along each dimension by specifying the starts and
    the increments.
    Args:
        x (Tensor): tensor to be sliced.
        begin (List[int]): list of starts corresponding to each dimension.
        size (List[int]): list of increments with respect to the starts along each dimension. Specifically,
            ``-1`` means slicing from begin to the last element (inclusive) of that dimension.
    Returns:
        The sliced tensor.
    Raises:
        ValueError: if any of ``begin`` indices is negative
        ValueError: if any of ``size`` is less than ``-1``
    """
    if not all([b >= 0 for b in begin]):
        raise ValueError("All starting indices must be non-negative.")
    if not all([s >= -1 for s in size]):
        raise ValueError("All sizes must be either non-negative or -1.")

    size = [l - b if s == -1 else s for s, b, l in zip(size, begin, x.shape)]

    slices = [slice(b, b + s) for b, s in zip(begin, size)]
    return x[slices]


class PretrainedMixin:
    def get_model_dir(self, url):
        return os.path.join(
