Introducing exponential moving average updates to embedding in Quantization layer (facebookresearch#47)

Summary:
VQVAE implements a codebook of embedding vectors to tokenize an input: during training, the encoder outputs are encouraged to stay close to the embedding vectors. The embedding vectors themselves are learned via exponential moving average (EMA) updates that pull each vector toward the running mean of the encoder outputs assigned to it. As described in the original paper, this optimization converges faster than training the codebook with an MSE loss.
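
For reference, the EMA update amounts to a few lines of PyTorch. The sketch below is illustrative only, mirroring the `code_usage` (N) and `code_avg` (m) buffers added in this commit rather than reproducing the committed implementation:

```python
import torch

def ema_codebook_update(embedding, code_usage, code_avg, encoded_flat,
                        codebook_indices, decay=0.99, epsilon=1e-7):
    """Illustrative EMA codebook update; N ~ code_usage, m ~ code_avg (Oord et al. 2017)."""
    num_embeddings = embedding.shape[0]
    onehot = torch.nn.functional.one_hot(codebook_indices, num_embeddings).float()
    # N_i <- decay * N_i + (1 - decay) * (# encoder vectors assigned to code i)
    code_usage = decay * code_usage + (1 - decay) * onehot.sum(dim=0)
    # Laplace smoothing keeps rarely used codes from dividing by zero below
    n = code_usage.sum()
    code_usage = (code_usage + epsilon) / (n + num_embeddings * epsilon) * n
    # m_i <- decay * m_i + (1 - decay) * (sum of encoder vectors assigned to code i)
    code_avg = decay * code_avg + (1 - decay) * (onehot.t() @ encoded_flat)
    # e_i <- m_i / N_i: each code moves toward the mean of the vectors that chose it
    embedding = code_avg / code_usage.unsqueeze(1)
    return embedding, code_usage, code_avg
```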

This PR introduces EMA updates to the Quantization layer, along with some refactoring of the old unit tests. Additionally, it initializes the embedding weights with random latents from the encoder, as done in VideoGPT and Jukebox. A future PR will implement codebook restarts to alleviate codebook collapse.
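
A condensed sketch of that initialization step (again illustrative; the committed `_init_embedding_and_preprocess` also flattens the encoder output and returns the preprocessed tensors):

```python
import torch

def init_codebook_from_latents(encoded_flat, num_embeddings):
    """Seed the codebook with a random subset of flattened encoder outputs,
    tiling them with a little noise if the batch provides too few vectors."""
    num_vectors, dim = encoded_flat.shape
    if num_vectors < num_embeddings:
        repeats = (num_embeddings + num_vectors - 1) // num_vectors
        encoded_flat = encoded_flat.repeat(repeats, 1)
        encoded_flat = encoded_flat + (0.01 / dim ** 0.5) * torch.randn_like(encoded_flat)
    idx = torch.randperm(encoded_flat.shape[0])[:num_embeddings]
    return encoded_flat[idx]
```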

Pull Request resolved: facebookresearch#47

Test Plan: Added unit tests for embedding initialization and EMA updates in `test_quantization.py`

Reviewed By: ebsmothers

Differential Revision: D36611814

Pulled By: RdoubleA

fbshipit-source-id: 3cf3d6575878090934f68fba0456365fc1b9c883
RdoubleA authored and facebook-github-bot committed May 24, 2022
1 parent 054ee8f commit 23a2203
Showing 2 changed files with 205 additions and 26 deletions.
108 changes: 94 additions & 14 deletions test/modules/layers/test_quantization.py
@@ -7,8 +7,7 @@
import unittest

import torch
from test.test_utils import assert_expected
from torch import nn
from test.test_utils import assert_expected, set_rng_seed
from torchmultimodal.modules.layers.quantization import Quantization


@@ -18,6 +17,7 @@ class TestQuantization(unittest.TestCase):
"""

def setUp(self):
set_rng_seed(4)
self.num_embeddings = 4
self.embedding_dim = 5

@@ -28,6 +28,7 @@ def setUp(self):
[[2, 2, -1], [1, -1, -2], [0, 0, 0], [1, 2, 1], [1, 0, 0]],
]
)
self.encoded.requires_grad_()
# This is 4x5
self.embedding_weights = torch.Tensor(
[[1, 0, -1, -1, 2], [2, -2, 0, 0, 1], [2, 1, 0, 1, 1], [-1, -2, 0, 2, 0]]
@@ -38,11 +39,14 @@ def setUp(self):
)

self.vq = Quantization(
num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim
num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
decay=0.3,
)
self.vq.embedding = nn.Embedding.from_pretrained(self.embedding_weights)

def test_quantized_output(self):
self.vq.embedding = self.embedding_weights
self.vq._is_embedding_init = True
output = self.vq(self.encoded)
_, actual_quantized_flat, actual_codebook_indices, actual_quantized = output
# This is shape (2,5,3)
@@ -84,23 +88,99 @@ def test_preprocess(self):
actual_flat_shape = torch.tensor(encoded_flat.shape)
actual_permuted_shape = torch.tensor(permuted_shape)

assert torch.equal(
actual_flat_shape, expected_flat_shape
), f"actual flattened shape: {actual_flat_shape}, expected flattened shape: {expected_flat_shape}"
assert_expected(actual_flat_shape, expected_flat_shape)

assert torch.equal(
actual_permuted_shape, expected_permuted_shape
), f"actual permuted shape: {actual_permuted_shape}, expected permuted shape: {expected_permuted_shape}"
assert_expected(actual_permuted_shape, expected_permuted_shape)

def test_preprocess_channel_dim_assertion(self):
with self.assertRaises(ValueError):
encoded_flat, permuted_shape = self.vq._preprocess(self.encoded[:, :4, :])
_, _ = self.vq._preprocess(self.encoded[:, :4, :])

def test_postprocess(self):
quantized = self.vq._postprocess(self.input_tensor_flat, torch.Size([2, 2, 3]))
actual_quantized_shape = torch.tensor(quantized.shape)
expected_quantized_shape = torch.tensor([2, 3, 2])

assert torch.equal(
actual_quantized_shape, expected_quantized_shape
), f"actual quantized shape: {actual_quantized_shape}, expected quantized shape: {expected_quantized_shape}"
assert_expected(actual_quantized_shape, expected_quantized_shape)

def test_init_embedding_and_preprocess(self):
assert not self.vq._is_embedding_init, "embedding init flag not False initially"

_, _ = self.vq._init_embedding_and_preprocess(self.encoded)

assert self.vq._is_embedding_init, "embedding init flag not True after init"

actual_weight = self.vq.embedding
expected_weight = torch.Tensor(
[
[2.0, -1.0, 0.0, 2.0, 0.0],
[2.0, 1.0, 0.0, 1.0, 1.0],
[0.0, 1.0, -1.0, 2.0, -1.0],
[1.0, 0.0, -1.0, -1.0, 1.0],
]
)
assert_expected(actual_weight, expected_weight)

actual_code_avg = self.vq.code_avg
expected_code_avg = actual_weight
assert_expected(actual_code_avg, expected_code_avg)

actual_code_usage = self.vq.code_usage
expected_code_usage = torch.ones(self.num_embeddings)
assert_expected(actual_code_usage, expected_code_usage)

def test_ema_update_embedding(self):
_ = self.vq(self.encoded)

actual_weight = self.vq.embedding
expected_weight = torch.Tensor(
[
[0.7647, -1.4118, 0.0000, 1.5882, 0.0000],
[2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
[-0.4118, 1.4118, -0.5882, 1.1765, -1.4118],
[1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
]
)
assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)

actual_code_avg = self.vq.code_avg
expected_code_avg = torch.Tensor(
[
[1.3000, -2.4000, 0.0000, 2.7000, 0.0000],
[2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
[-0.7000, 2.4000, -1.0000, 2.0000, -2.4000],
[1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
]
)
assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)

actual_code_usage = self.vq.code_usage
expected_code_usage = torch.Tensor([1.7000, 1.0000, 1.7000, 1.0000])
assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)

def test_register_buffer_tensors(self):
out = self.vq(self.encoded)
out.quantized.sum().backward()

msg_has_grad = "tensor assigned to buffer but accumulated grad"
assert not self.vq.code_avg.grad, msg_has_grad
assert not self.vq.code_usage.grad, msg_has_grad
assert not self.vq.embedding.grad, msg_has_grad

assert not list(
self.vq.parameters()
), "buffer variables incorrectly assigned as params"

def test_init_embedding_smaller_encoded(self):
encoded_small = self.encoded[:1, :, :2]
out = self.vq(encoded_small)
encoded_small_flat = out.encoded_flat
embed = self.vq.embedding
# Check for each embedding vector if there is one equal encoded vector + noise
for emb in embed:
assert any(
[
torch.isclose(emb, enc, rtol=0, atol=0.01).all()
for enc in encoded_small_flat
]
), "embedding initialized from encoder output incorrectly"
123 changes: 111 additions & 12 deletions torchmultimodal/modules/layers/quantization.py
@@ -24,24 +24,62 @@ class Quantization(nn.Module):
Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
to generate high-fidelity images, videos, and audio data.
The embedding weights are trained with exponential moving average updates as described
in original paper.
Code was largely inspired by a PyTorch implementation of the author's original code, found here:
https://colab.research.google.com/github/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
and by the implementation in MUGEN (Hayes et al. 2022), found here:
https://github.com/mugen-org/MUGEN_baseline/blob/main/lib/models/video_vqvae/vqvae.py
Args:
num_embeddings (int): the number of vectors in the embedding space
embedding_dim (int): the dimensionality of the embedding vectors
Inputs:
x (Tensor): Tensor containing a batch of encoder outputs.
z (Tensor): Tensor containing a batch of encoder outputs.
Expects dimensions to be batch x channel x n dims.
"""

def __init__(self, num_embeddings: int, embedding_dim: int):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
decay: float = 0.99,
epsilon: float = 1e-7,
):
super().__init__()
# Embedding weights and parameters for EMA update will be registered to buffer, as they
# will not be updated by the optimizer but are still model parameters.
# code_usage and code_avg correspond with N and m, respectively, from Oord et al.
randn_init_embedding = torch.randn(num_embeddings, embedding_dim)
self.register_buffer("embedding", randn_init_embedding.clone())
self.register_buffer("code_usage", torch.zeros(num_embeddings))
self.register_buffer("code_avg", randn_init_embedding.clone())

self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings

self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
self.embedding.weight.data.uniform_(
-1 / self.num_embeddings, 1 / self.num_embeddings
)
self._decay = decay
# Used in Laplace smoothing of code usage
self._epsilon = epsilon

# Flag to track if we need to initialize embedding with encoder output
self._is_embedding_init = False

def _tile(self, x):
# Repeat encoder vectors in cases where the encoder output does not have enough vectors
# to initialize the codebook on first forward pass
num_encoder_vectors, num_channels = x.shape
if num_encoder_vectors < self.embedding_dim:
num_repeats = (
self.num_embeddings + num_encoder_vectors - 1
) // num_encoder_vectors
# Add a small amount of noise to repeated vectors
std = 0.01 / torch.sqrt(torch.tensor(num_channels))
x = x.repeat(num_repeats, 1)
x = x + torch.randn_like(x) * std
return x

def _preprocess(self, encoded: Tensor) -> Tuple[Tensor, Size]:
# Rearrange from batch x channel x n dims to batch x n dims x channel
@@ -69,24 +107,85 @@ def _postprocess(self, quantized_flat: Tensor, permuted_shape: Size) -> Tensor:

return quantized

def _init_embedding_and_preprocess(self, z: Tensor) -> Tuple[Tensor, Size]:
# Embedding should be initialized with random output vectors from the encoder
# on the first forward pass for faster convergence, as in VideoGPT (Yan et al. 2021)
#
# This requires preprocessing the encoder output, so return this as well.

self._is_embedding_init = True

# Flatten encoder outputs, tile to match num embeddings, get random encoder outputs
encoded_flat, permuted_shape = self._preprocess(z)
encoded_flat_tiled = self._tile(encoded_flat)
idx = torch.randperm(encoded_flat_tiled.shape[0])
encoded_flat_rand = encoded_flat_tiled[idx][: self.num_embeddings]

# Initialize embedding and intermediate values for EMA updates
self.embedding = encoded_flat_rand
self.code_avg = encoded_flat_rand
self.code_usage = torch.ones(self.num_embeddings)

return encoded_flat, permuted_shape

def _ema_update_embedding(self, encoded_flat: Tensor, codebook_indices: Tensor):
# Closed form solution of codebook loss, ||e - E(x)||^2, is simply the average
# of the encoder output. However, we can't compute this in minibatches, so we
# must use exponential moving average.

# Convert indices to one hot encoding
codebook_onehot = nn.functional.one_hot(
codebook_indices, num_classes=self.num_embeddings
).type(torch.float)
# Count how often each embedding vector was looked up
codebook_selection_count = torch.sum(codebook_onehot, 0)
# Update usage value for each embedding vector
self.code_usage = self.code_usage * self._decay + codebook_selection_count * (
1 - self._decay
)
# Laplace smoothing of codebook usage - to prevent zero counts
n = torch.sum(self.code_usage)
self.code_usage = (
(self.code_usage + self._epsilon)
/ (n + self.num_embeddings * self._epsilon)
* n
)
# Get all encoded vectors attracted to each embedding vector
encoded_per_codebook = torch.matmul(codebook_onehot.t(), encoded_flat)
# Update each embedding vector with new encoded vectors that are attracted to it,
# divided by its usage to yield the mean of encoded vectors that choose it
self.code_avg = (
self.code_avg * self._decay + (1 - self._decay) * encoded_per_codebook
)
self.embedding = self.code_avg / self.code_usage.unsqueeze(1)

def _quantize(self, encoded_flat: Tensor) -> Tuple[Tensor, Tensor]:
# Calculate distances from each encoder output vector to each embedding vector, ||x - emb||^2
distances = torch.cdist(encoded_flat, self.embedding.weight, p=2.0) ** 2
# Calculate distances from each encoder, E(x), output vector to each embedding vector, e, ||E(x) - e||^2
distances = torch.cdist(encoded_flat, self.embedding, p=2.0) ** 2

# Encoding - select closest embedding vectors
codebook_indices = torch.argmin(distances, dim=1)

# Quantize
quantized_flat = self.embedding(codebook_indices)
quantized_flat = self.embedding[codebook_indices]

# Use exponential moving average to update the embedding instead of a codebook loss,
# as suggested by Oord et al. 2017 and Razavi et al. 2019.
if self.training:
self._ema_update_embedding(encoded_flat, codebook_indices)

# Straight through estimator
quantized_flat = encoded_flat + (quantized_flat - encoded_flat).detach()

return quantized_flat, codebook_indices

def forward(self, x: Tensor) -> QuantizationOutput:
# Reshape and flatten encoder output for quantization
encoded_flat, permuted_shape = self._preprocess(x)
def forward(self, z: Tensor) -> QuantizationOutput:
# First check if embedding is initialized correctly
if not self._is_embedding_init and self.training:
encoded_flat, permuted_shape = self._init_embedding_and_preprocess(z)
else:
# Reshape and flatten encoder output for quantization
encoded_flat, permuted_shape = self._preprocess(z)

# Quantization via nearest neighbor lookup
quantized_flat, codebook_indices = self._quantize(encoded_flat)
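
For context, a minimal usage sketch of the updated layer (illustrative, not part of the commit; the constructor arguments and the batch x channel x n-dims input convention follow the docstring above, and the gradient check mirrors `test_register_buffer_tensors`):

```python
import torch
from torchmultimodal.modules.layers.quantization import Quantization

vq = Quantization(num_embeddings=4, embedding_dim=5, decay=0.99, epsilon=1e-7)

# Fake "encoder output": batch=2, channel=embedding_dim=5, 3 positions
z = torch.randn(2, 5, 3, requires_grad=True)

out = vq(z)  # the first training-mode forward also seeds the codebook from z
print(out.quantized.shape)  # same shape as z: torch.Size([2, 5, 3])

# Straight-through estimator: gradients flow back to the encoder output
out.quantized.sum().backward()
print(z.grad is not None)  # True
```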