From beb8e7b46f1548184d4fc6552417d5daee10ba8a Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Wed, 20 Nov 2024 15:47:04 -0800 Subject: [PATCH 1/4] first --- tests/torchtune/modules/test_vq_embeddings.py | 192 ++++++++++++++++++ torchtune/modules/vq_embeddings.py | 165 +++++++++++++++ 2 files changed, 357 insertions(+) create mode 100644 tests/torchtune/modules/test_vq_embeddings.py create mode 100644 torchtune/modules/vq_embeddings.py diff --git a/tests/torchtune/modules/test_vq_embeddings.py b/tests/torchtune/modules/test_vq_embeddings.py new file mode 100644 index 0000000000..317c01574b --- /dev/null +++ b/tests/torchtune/modules/test_vq_embeddings.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest + +import torch +from tests.test_utils import assert_expected +from torch import tensor +from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings + + +@pytest.fixture(autouse=True) +def random_seed(): + torch.manual_seed(4) + + +class TestVectorQuantizedEmbeddings: + @pytest.fixture + def num_embeddings(self): + return 4 + + @pytest.fixture + def embedding_dim(self): + return 3 + + @pytest.fixture + def encoded(self): + # This is 2x5x3 + encoded = tensor( + [ + [ + [-1.0, 0.0, 1.0], + [2.0, 1.0, 0.0], + [0.0, -1.0, -1.0], + [0.0, 2.0, -1.0], + [-2.0, -1.0, 1.0], + ], + [ + [2.0, 2.0, -1.0], + [1.0, -1.0, -2.0], + [0.0, 0.0, 0.0], + [1.0, 2.0, 1.0], + [1.0, 0.0, 0.0], + ], + ] + ) + encoded.requires_grad_() + + return encoded + + @pytest.fixture + def embedding_weights(self): + # This is 4x3 + return tensor( + [ + [1.0, 0.0, -1.0], + [2.0, -2.0, 0.0], + [2.0, 1.0, 0.0], + [-1.0, -2.0, 0.0], + ] + ) + + @pytest.fixture + def input_tensor_flat(self): + # This is 4x3 + return tensor( + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] + ) + + @pytest.fixture + def codebook(self, num_embeddings, embedding_dim): + return VectorQuantizedEmbeddings( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + decay=0.3, + ) + + def test_quantized_output(self, codebook, embedding_weights, encoded): + codebook.embedding = embedding_weights + actual = codebook(encoded) + + # This is shape (2,5,3) + expected_quantized = tensor( + [ + [ + [2.0, 2.0, 1.0], + [1.0, 1.0, 0.0], + [0.0, 0.0, -1.0], + [1.0, 1.0, -1.0], + [1.0, 1.0, 2.0], + ], + [ + [2.0, 2.0, -1.0], + [1.0, -2.0, -2.0], + [0.0, 0.0, 0.0], + [1.0, 0.0, 2.0], + [1.0, 1.0, 0.0], + ], + ] + ) + expected_token_ids = tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type( + torch.LongTensor + ) + + assert_expected(actual[0], expected_quantized) + assert_expected(actual[1], expected_token_ids) + + def test_ema_update_embedding(self, num_embeddings, embedding_dim, encoded): + codebook = VectorQuantizedEmbeddings( + num_embeddings, embedding_dim, learnable=True + ) + distances = torch.cdist(encoded, codebook.embedding, p=2.0) ** 2 + codebook_indices = torch.argmin(distances, dim=1) + codebook._ema_update_embedding(encoded, codebook_indices) + + actual_weight = codebook.embedding + expected_weight = tensor( + [ + [0.7647, -1.4118, 0.0000, 1.5882, 0.0000], + [2.0000, 1.0000, 0.0000, 1.0000, 1.0000], + [-0.4118, 1.4118, -0.5882, 1.1765, -1.4118], + [1.0000, 0.0000, -1.0000, -1.0000, 1.0000], + ] + ) + assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4) + + actual_code_avg = codebook.code_avg + expected_code_avg = 
tensor( + [ + [1.3000, -2.4000, 0.0000, 2.7000, 0.0000], + [2.0000, 1.0000, 0.0000, 1.0000, 1.0000], + [-0.7000, 2.4000, -1.0000, 2.0000, -2.4000], + [1.0000, 0.0000, -1.0000, -1.0000, 1.0000], + ] + ) + assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4) + + actual_code_usage = codebook.code_usage + expected_code_usage = tensor([1.7000, 1.0000, 1.7000, 1.0000]) + assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4) + + def test_codebook_restart(self, codebook, encoded): + # Use only embedding vector at index = 1 and force restarts. + # Slightly modify encoded_flat to make sure vectors restart to something new + encoded_noise = encoded + torch.randn_like(encoded) + codebook_indices_low_usage = torch.ones(encoded.shape[0], dtype=torch.long) + codebook._ema_update_embedding(encoded_noise, codebook_indices_low_usage) + + # Check if embedding contains restarts + for i, emb in enumerate(codebook.embedding): + # We used only emb vector with index = 1, so check it was not restarted + if i == 1: + assert_expected( + emb, + codebook.code_avg[1] / codebook.code_usage[1], + rtol=0, + atol=1e-4, + ) + # Compare each embedding vector to each encoded vector. + # If at least one match, then restart happened. + else: + assert any( + [ + torch.isclose(emb, enc, rtol=0, atol=1e-4).all() + for enc in encoded_noise + ] + ), "embedding restarted from encoder output incorrectly" + + def test_lookup(self, codebook, embedding_weights): + codebook.embedding = embedding_weights + indices_flat = tensor([[0, 1]]) # (b, seq_len) + indices_shaped = tensor([[[0, 1], [2, 3]]]) # (b, shape) + actual_quantized_flat = codebook.lookup(indices_flat) + actual_quantized = codebook.lookup(indices_shaped) + expected_quantized_flat = tensor( + [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]] + ) + expected_quantized = tensor( + [ + [ + [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]], + [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]], + ] + ] + ) + assert_expected( + actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4 + ) + assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4) diff --git a/torchtune/modules/vq_embeddings.py b/torchtune/modules/vq_embeddings.py new file mode 100644 index 0000000000..f36101e49e --- /dev/null +++ b/torchtune/modules/vq_embeddings.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Tuple + +import torch +from torch import nn, Tensor +from torch.nn import functional as F + + +class VectorQuantizedEmbeddings(nn.Module): + """ + Vector quantized embedding layer that takes in the output of an encoder + and performs a nearest-neighbor lookup in the embedding space. + Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf) + to generate high-fidelity images, videos, and audio data. + The embedding weights are trained with exponential moving average updates as described + in original paper. + + Code was adapted from torchmultimodal's `Codebook module + `_. + + Args: + num_embeddings (int): Number of vectors in the embedding space. + embedding_dim (int): Dimensionality of the embedding vectors. + decay (float, optional): Factor used in exponential moving average update of the embeddings. + Defaults to ``0.99``. 
+ codebook_usage_threshold (float, optional): Threshold for the average number of times an embedding vector + is chosen below which it will be re-initialized. Defaults to ``1.0``. + learnable (bool): If True, register embedding weights, codebook usage, and codebook average to buffer + for EMA updates during training. If False, only register embedding weights as an nn.Parameter, for use + in a frozen module. Default is False. + epsilon (float, optional): Noise used in Laplace smoothing of codebook usage. Defaults to ``1e-7``. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + decay: float = 0.99, + codebook_usage_threshold: float = 1.0, + learnable: bool = False, + epsilon: float = 1e-7, + ) -> None: + super().__init__() + # Embedding weights and parameters for EMA update will be registered to buffer, as they + # will not be updated by the optimizer but are still model parameters. + # code_usage and code_avg correspond with N and m, respectively, from Oord et al. + randn_init_embedding = torch.randn(num_embeddings, embedding_dim) + if learnable: + self.register_buffer("embedding", randn_init_embedding.clone()) + self.register_buffer("code_usage", torch.zeros(num_embeddings)) + self.register_buffer("code_avg", randn_init_embedding.clone()) + else: + self.register_parameter("embedding", nn.Parameter(randn_init_embedding)) + + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.learnable = learnable + + self._decay = decay + # Used in Laplace smoothing of code usage + self._epsilon = epsilon + + # Threshold for randomly reseting unused embedding vectors + self.codebook_usage_threshold = codebook_usage_threshold + + def _tile(self, x: Tensor, n: int) -> Tensor: + # Repeat vectors in x if x has less than n vectors + num_vectors, num_channels = x.shape + if num_vectors < n: + num_repeats = (n + num_vectors - 1) // num_vectors + # Add a small amount of noise to repeated vectors + std = 0.01 / torch.sqrt(torch.tensor(num_channels)) + x = x.repeat(num_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _get_random_vectors(self, x: Tensor, n: int) -> Tensor: + # Gets n random row vectors from 2D tensor x + x_tiled = self._tile(x, n) + idx = torch.randperm(x_tiled.shape[0]) + x_rand = x_tiled[idx][:n] + return x_rand + + def _ema_update_embedding(self, z: Tensor, codebook_indices: Tensor) -> None: + # Closed form solution of codebook loss, ||e - E(x)||^2, is simply the average + # of the encoder output. However, we can't compute this in minibatches, so we + # must use exponential moving average. 
+ + # Convert indices to one hot encoding + codebook_onehot = nn.functional.one_hot( + codebook_indices, num_classes=self.num_embeddings + ).type(torch.float) + # Count how often each embedding vector was looked up + codebook_selection_count = torch.sum(codebook_onehot, 0) + # Update usage value for each embedding vector + self.code_usage.mul_(self._decay).add_( + codebook_selection_count, alpha=(1 - self._decay) + ) + # Laplace smoothing of codebook usage - to prevent zero counts + n = torch.sum(self.code_usage) + self.code_usage.add_(self._epsilon).divide_( + n + self.num_embeddings * self._epsilon + ).mul_(n) + # Get all encoded vectors attracted to each embedding vector + encoded_per_codebook = torch.matmul(codebook_onehot.t(), z) + # Update each embedding vector with new encoded vectors that are attracted to it, + # divided by its usage to yield the mean of encoded vectors that choose it + self.code_avg.mul_(self._decay).add_( + encoded_per_codebook, alpha=(1 - self._decay) + ) + self.embedding = self.code_avg / self.code_usage.unsqueeze(1) + # Reset any embedding vectors that fall below threshold usage with random encoded vectors + z_rand = self._get_random_vectors(z, self.num_embeddings) + self.embedding = torch.where( + self.code_usage.unsqueeze(1) >= self.codebook_usage_threshold, + self.embedding, + z_rand, + ) + + def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + z (Tensor): Tensor containing a batch of encoder outputs of shape ``(b, s, d)``, where + b is batch size, s is sequence length or time, and d is ``embedding_dim``. + + Returns: + An instance of :class:`~torchmultimodal.modules.layers.codebook.CodebookOutput`. + """ + if z.shape[-1] != self.embedding_dim: + raise ValueError( + f"Expected last dimension of input tensor ({z.shape[-1]}) to be embedding size of {self.embedding_dim}" + ) + + # Calculate distances from each encoder, E(x), output vector to each embedding vector, e, ||E(x) - e||^2 + distances = torch.cdist(z, self.embedding, p=2.0) ** 2 + + # Encoding - select closest embedding vectors, shape [b, s] + token_ids = torch.argmin(distances, dim=1) + + # Quantize - shape [b, s, d] + quantized = self.lookup(token_ids) + + # Use exponential moving average to update the embedding instead of a codebook loss, + # as suggested by Oord et al. 2017 and Razavi et al. 2019. 
+ if self.training and self.learnable: + self._ema_update_embedding(z, token_ids) + + # Straight through estimator + quantized = z + (quantized - z).detach() + + return quantized, token_ids + + def extra_repr(self) -> str: + return "num_embeddings={}, embedding_dim={}".format( + self.num_embeddings, self.embedding_dim + ) + + def lookup(self, token_ids: Tensor) -> Tensor: + # Returns the embeddings of shape [b, s, d] + return F.embedding(token_ids, self.embedding) From efdc2289c364ba2b461e2790c4bbe9040e567ae1 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Wed, 20 Nov 2024 16:31:53 -0800 Subject: [PATCH 2/4] bug fixes --- torchtune/modules/__init__.py | 2 ++ torchtune/modules/vq_embeddings.py | 32 ++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 8c4bed6e21..26dfb11afc 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -30,6 +30,7 @@ TransformerSelfAttentionLayer, ) from .vision_transformer import VisionTransformer +from .vq_embeddings import VectorQuantizedEmbeddings __all__ = [ "MultiHeadAttention", @@ -39,6 +40,7 @@ "KVCache", "RotaryPositionalEmbeddings", "VisionRotaryPositionalEmbeddings", + "VectorQuantizedEmbeddings", "RMSNorm", "TiedLinear", "Fp32LayerNorm", diff --git a/torchtune/modules/vq_embeddings.py b/torchtune/modules/vq_embeddings.py index f36101e49e..9788aefe10 100644 --- a/torchtune/modules/vq_embeddings.py +++ b/torchtune/modules/vq_embeddings.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Mapping, Optional, Tuple +from typing import Tuple import torch from torch import nn, Tensor @@ -129,29 +129,39 @@ def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]: b is batch size, s is sequence length or time, and d is ``embedding_dim``. Returns: - An instance of :class:`~torchmultimodal.modules.layers.codebook.CodebookOutput`. + Tuple[Tensor, Tensor]: The quantized input and the embedding vector ids that were used. + + Raises: + ValueError: if input embedding dimension does not match embedding dimension of module """ - if z.shape[-1] != self.embedding_dim: + bsz, seq_len, z_embed_dim = z.shape + if z_embed_dim != self.embedding_dim: raise ValueError( - f"Expected last dimension of input tensor ({z.shape[-1]}) to be embedding size of {self.embedding_dim}" + f"Expected last dimension of input tensor ({z_embed_dim}) to be embedding size of {self.embedding_dim}" ) + # Flatten into batch dimension + z_flat = z.view(-1, z_embed_dim) # Calculate distances from each encoder, E(x), output vector to each embedding vector, e, ||E(x) - e||^2 - distances = torch.cdist(z, self.embedding, p=2.0) ** 2 + distances = torch.cdist(z_flat, self.embedding, p=2.0) ** 2 - # Encoding - select closest embedding vectors, shape [b, s] - token_ids = torch.argmin(distances, dim=1) + # Encoding - select closest embedding vectors, shape [b * s, ] + token_ids_flat = torch.argmin(distances, dim=1) - # Quantize - shape [b, s, d] - quantized = self.lookup(token_ids) + # Quantize - shape [b * s, d] + quantized_flat = self.lookup(token_ids_flat) # Use exponential moving average to update the embedding instead of a codebook loss, # as suggested by Oord et al. 2017 and Razavi et al. 2019. 
if self.training and self.learnable: - self._ema_update_embedding(z, token_ids) + self._ema_update_embedding(z_flat, token_ids_flat) # Straight through estimator - quantized = z + (quantized - z).detach() + quantized_flat = z_flat + (quantized_flat - z_flat).detach() + + # Reshape to original - [b, s, d] and [b, s] + quantized = quantized_flat.view(bsz, seq_len, z_embed_dim) + token_ids = token_ids_flat.view(bsz, seq_len) return quantized, token_ids From 5f5dcf644edcbee2278796c3656354d0254c78de Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Fri, 22 Nov 2024 18:01:43 -0800 Subject: [PATCH 3/4] fix tests --- tests/torchtune/modules/test_vq_embeddings.py | 135 +++++++++--------- torchtune/modules/vq_embeddings.py | 4 +- 2 files changed, 65 insertions(+), 74 deletions(-) diff --git a/tests/torchtune/modules/test_vq_embeddings.py b/tests/torchtune/modules/test_vq_embeddings.py index 317c01574b..825bbd6cd9 100644 --- a/tests/torchtune/modules/test_vq_embeddings.py +++ b/tests/torchtune/modules/test_vq_embeddings.py @@ -24,26 +24,34 @@ def num_embeddings(self): @pytest.fixture def embedding_dim(self): - return 3 + return 5 + + @pytest.fixture + def codebook(self, num_embeddings, embedding_dim): + def vq(learnable): + return VectorQuantizedEmbeddings( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + decay=0.3, + learnable=learnable, + ) + + return vq @pytest.fixture def encoded(self): - # This is 2x5x3 + # This is 2x3x5 encoded = tensor( [ [ - [-1.0, 0.0, 1.0], - [2.0, 1.0, 0.0], - [0.0, -1.0, -1.0], - [0.0, 2.0, -1.0], - [-2.0, -1.0, 1.0], + [-1.0, 2.0, 0.0, 0.0, -2.0], + [0.0, 1.0, -1.0, 2.0, -1.0], + [1.0, 0.0, -1.0, -1.0, 1.0], ], [ - [2.0, 2.0, -1.0], - [1.0, -1.0, -2.0], - [0.0, 0.0, 0.0], - [1.0, 2.0, 1.0], - [1.0, 0.0, 0.0], + [2.0, 1.0, 0.0, 1.0, 1.0], + [2.0, -1.0, 0.0, 2.0, 0.0], + [-1.0, -2.0, 0.0, 1.0, 0.0], ], ] ) @@ -53,51 +61,32 @@ def encoded(self): @pytest.fixture def embedding_weights(self): - # This is 4x3 + # This is 4x5 return tensor( [ - [1.0, 0.0, -1.0], - [2.0, -2.0, 0.0], - [2.0, 1.0, 0.0], - [-1.0, -2.0, 0.0], + [1.0, 0.0, -1.0, -1.0, 2.0], + [2.0, -2.0, 0.0, 0.0, 1.0], + [2.0, 1.0, 0.0, 1.0, 1.0], + [-1.0, -2.0, 0.0, 2.0, 0.0], ] ) - @pytest.fixture - def input_tensor_flat(self): - # This is 4x3 - return tensor( - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] - ) + def test_quantized_output(self, codebook, encoded, embedding_weights): + vq = codebook(learnable=False) + vq.embedding = embedding_weights + actual = vq(encoded) - @pytest.fixture - def codebook(self, num_embeddings, embedding_dim): - return VectorQuantizedEmbeddings( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - decay=0.3, - ) - - def test_quantized_output(self, codebook, embedding_weights, encoded): - codebook.embedding = embedding_weights - actual = codebook(encoded) - - # This is shape (2,5,3) expected_quantized = tensor( [ [ - [2.0, 2.0, 1.0], - [1.0, 1.0, 0.0], - [0.0, 0.0, -1.0], - [1.0, 1.0, -1.0], - [1.0, 1.0, 2.0], + [2.0, 1.0, 0.0, 1.0, 1.0], + [2.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, -1.0, -1.0, 2.0], ], [ - [2.0, 2.0, -1.0], - [1.0, -2.0, -2.0], - [0.0, 0.0, 0.0], - [1.0, 0.0, 2.0], - [1.0, 1.0, 0.0], + [2.0, 1.0, 0.0, 1.0, 1.0], + [2.0, -2.0, 0.0, 0.0, 1.0], + [-1.0, -2.0, 0.0, 2.0, 0.0], ], ] ) @@ -108,54 +97,57 @@ def test_quantized_output(self, codebook, embedding_weights, encoded): assert_expected(actual[0], expected_quantized) assert_expected(actual[1], expected_token_ids) - def test_ema_update_embedding(self, num_embeddings, 
embedding_dim, encoded): - codebook = VectorQuantizedEmbeddings( - num_embeddings, embedding_dim, learnable=True - ) - distances = torch.cdist(encoded, codebook.embedding, p=2.0) ** 2 + def test_ema_update_embedding(self, codebook, encoded, embedding_weights): + vq = codebook(learnable=True) + vq.embedding = embedding_weights + encoded_flat = encoded.view(-1, encoded.shape[-1]) + distances = torch.cdist(encoded_flat, vq.embedding, p=2.0) ** 2 codebook_indices = torch.argmin(distances, dim=1) - codebook._ema_update_embedding(encoded, codebook_indices) + vq._ema_update_embedding(encoded_flat, codebook_indices) - actual_weight = codebook.embedding + actual_weight = vq.embedding expected_weight = tensor( [ - [0.7647, -1.4118, 0.0000, 1.5882, 0.0000], + [2.0000, -1.0000, 0.0000, 2.0000, 0.0000], [2.0000, 1.0000, 0.0000, 1.0000, 1.0000], - [-0.4118, 1.4118, -0.5882, 1.1765, -1.4118], + [0.5647, 1.3760, -0.3936, 1.1213, -0.7635], [1.0000, 0.0000, -1.0000, -1.0000, 1.0000], ] ) assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4) - actual_code_avg = codebook.code_avg + actual_code_avg = vq.code_avg expected_code_avg = tensor( [ - [1.3000, -2.4000, 0.0000, 2.7000, 0.0000], - [2.0000, 1.0000, 0.0000, 1.0000, 1.0000], - [-0.7000, 2.4000, -1.0000, 2.0000, -2.4000], - [1.0000, 0.0000, -1.0000, -1.0000, 1.0000], + [0.4176, 0.3790, -0.7551, -0.6548, 1.0419], + [1.3309, -0.3437, 0.2303, 1.1865, 0.1305], + [1.1859, 2.8897, -0.8265, 2.3547, -1.6033], + [-0.9834, -0.7490, -0.3521, 0.5825, 0.4301], ] ) assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4) - actual_code_usage = codebook.code_usage - expected_code_usage = tensor([1.7000, 1.0000, 1.7000, 1.0000]) + actual_code_usage = vq.code_usage + expected_code_usage = tensor([0.7000, 0.7000, 2.1000, 0.7000]) assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4) - def test_codebook_restart(self, codebook, encoded): + def test_codebook_restart(self, codebook, encoded, embedding_weights): + vq = codebook(learnable=True) + vq.embedding = embedding_weights # Use only embedding vector at index = 1 and force restarts. 
# Slightly modify encoded_flat to make sure vectors restart to something new - encoded_noise = encoded + torch.randn_like(encoded) - codebook_indices_low_usage = torch.ones(encoded.shape[0], dtype=torch.long) - codebook._ema_update_embedding(encoded_noise, codebook_indices_low_usage) + encoded_flat = encoded.view(-1, encoded.shape[-1]) + encoded_noise = encoded_flat + torch.randn_like(encoded_flat) + codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long) + vq._ema_update_embedding(encoded_noise, codebook_indices_low_usage) # Check if embedding contains restarts - for i, emb in enumerate(codebook.embedding): + for i, emb in enumerate(vq.embedding): # We used only emb vector with index = 1, so check it was not restarted if i == 1: assert_expected( emb, - codebook.code_avg[1] / codebook.code_usage[1], + vq.code_avg[1] / vq.code_usage[1], rtol=0, atol=1e-4, ) @@ -170,11 +162,12 @@ def test_codebook_restart(self, codebook, encoded): ), "embedding restarted from encoder output incorrectly" def test_lookup(self, codebook, embedding_weights): - codebook.embedding = embedding_weights + vq = codebook(learnable=False) + vq.embedding = embedding_weights indices_flat = tensor([[0, 1]]) # (b, seq_len) indices_shaped = tensor([[[0, 1], [2, 3]]]) # (b, shape) - actual_quantized_flat = codebook.lookup(indices_flat) - actual_quantized = codebook.lookup(indices_shaped) + actual_quantized_flat = vq.lookup(indices_flat) + actual_quantized = vq.lookup(indices_shaped) expected_quantized_flat = tensor( [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]] ) diff --git a/torchtune/modules/vq_embeddings.py b/torchtune/modules/vq_embeddings.py index 9788aefe10..12a4918f73 100644 --- a/torchtune/modules/vq_embeddings.py +++ b/torchtune/modules/vq_embeddings.py @@ -50,12 +50,10 @@ def __init__( # will not be updated by the optimizer but are still model parameters. # code_usage and code_avg correspond with N and m, respectively, from Oord et al. 
randn_init_embedding = torch.randn(num_embeddings, embedding_dim) + self.register_buffer("embedding", randn_init_embedding.clone()) if learnable: - self.register_buffer("embedding", randn_init_embedding.clone()) self.register_buffer("code_usage", torch.zeros(num_embeddings)) self.register_buffer("code_avg", randn_init_embedding.clone()) - else: - self.register_parameter("embedding", nn.Parameter(randn_init_embedding)) self.embedding_dim = embedding_dim self.num_embeddings = num_embeddings From 1a0841a1b4392e47e9ec954992641d3dbea11b40 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 2 Dec 2024 12:50:26 -0800 Subject: [PATCH 4/4] remove EMA training --- tests/torchtune/modules/test_vq_embeddings.py | 115 ++++-------------- torchtune/modules/vq_embeddings.py | 99 +-------------- 2 files changed, 28 insertions(+), 186 deletions(-) diff --git a/tests/torchtune/modules/test_vq_embeddings.py b/tests/torchtune/modules/test_vq_embeddings.py index 825bbd6cd9..b8c1e83286 100644 --- a/tests/torchtune/modules/test_vq_embeddings.py +++ b/tests/torchtune/modules/test_vq_embeddings.py @@ -27,15 +27,24 @@ def embedding_dim(self): return 5 @pytest.fixture - def codebook(self, num_embeddings, embedding_dim): - def vq(learnable): - return VectorQuantizedEmbeddings( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - decay=0.3, - learnable=learnable, - ) + def embedding_weights(self): + # This is 4x5 + return tensor( + [ + [1.0, 0.0, -1.0, -1.0, 2.0], + [2.0, -2.0, 0.0, 0.0, 1.0], + [2.0, 1.0, 0.0, 1.0, 1.0], + [-1.0, -2.0, 0.0, 2.0, 0.0], + ] + ) + @pytest.fixture + def codebook(self, num_embeddings, embedding_dim, embedding_weights): + vq = VectorQuantizedEmbeddings( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + ) + vq.embedding.data = embedding_weights return vq @pytest.fixture @@ -59,22 +68,8 @@ def encoded(self): return encoded - @pytest.fixture - def embedding_weights(self): - # This is 4x5 - return tensor( - [ - [1.0, 0.0, -1.0, -1.0, 2.0], - [2.0, -2.0, 0.0, 0.0, 1.0], - [2.0, 1.0, 0.0, 1.0, 1.0], - [-1.0, -2.0, 0.0, 2.0, 0.0], - ] - ) - - def test_quantized_output(self, codebook, encoded, embedding_weights): - vq = codebook(learnable=False) - vq.embedding = embedding_weights - actual = vq(encoded) + def test_quantized_output(self, codebook, encoded): + actual = codebook(encoded) expected_quantized = tensor( [ @@ -97,77 +92,11 @@ def test_quantized_output(self, codebook, encoded, embedding_weights): assert_expected(actual[0], expected_quantized) assert_expected(actual[1], expected_token_ids) - def test_ema_update_embedding(self, codebook, encoded, embedding_weights): - vq = codebook(learnable=True) - vq.embedding = embedding_weights - encoded_flat = encoded.view(-1, encoded.shape[-1]) - distances = torch.cdist(encoded_flat, vq.embedding, p=2.0) ** 2 - codebook_indices = torch.argmin(distances, dim=1) - vq._ema_update_embedding(encoded_flat, codebook_indices) - - actual_weight = vq.embedding - expected_weight = tensor( - [ - [2.0000, -1.0000, 0.0000, 2.0000, 0.0000], - [2.0000, 1.0000, 0.0000, 1.0000, 1.0000], - [0.5647, 1.3760, -0.3936, 1.1213, -0.7635], - [1.0000, 0.0000, -1.0000, -1.0000, 1.0000], - ] - ) - assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4) - - actual_code_avg = vq.code_avg - expected_code_avg = tensor( - [ - [0.4176, 0.3790, -0.7551, -0.6548, 1.0419], - [1.3309, -0.3437, 0.2303, 1.1865, 0.1305], - [1.1859, 2.8897, -0.8265, 2.3547, -1.6033], - [-0.9834, -0.7490, -0.3521, 0.5825, 0.4301], - ] - ) - assert_expected(actual_code_avg, 
expected_code_avg, rtol=0.0, atol=1e-4) - - actual_code_usage = vq.code_usage - expected_code_usage = tensor([0.7000, 0.7000, 2.1000, 0.7000]) - assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4) - - def test_codebook_restart(self, codebook, encoded, embedding_weights): - vq = codebook(learnable=True) - vq.embedding = embedding_weights - # Use only embedding vector at index = 1 and force restarts. - # Slightly modify encoded_flat to make sure vectors restart to something new - encoded_flat = encoded.view(-1, encoded.shape[-1]) - encoded_noise = encoded_flat + torch.randn_like(encoded_flat) - codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long) - vq._ema_update_embedding(encoded_noise, codebook_indices_low_usage) - - # Check if embedding contains restarts - for i, emb in enumerate(vq.embedding): - # We used only emb vector with index = 1, so check it was not restarted - if i == 1: - assert_expected( - emb, - vq.code_avg[1] / vq.code_usage[1], - rtol=0, - atol=1e-4, - ) - # Compare each embedding vector to each encoded vector. - # If at least one match, then restart happened. - else: - assert any( - [ - torch.isclose(emb, enc, rtol=0, atol=1e-4).all() - for enc in encoded_noise - ] - ), "embedding restarted from encoder output incorrectly" - - def test_lookup(self, codebook, embedding_weights): - vq = codebook(learnable=False) - vq.embedding = embedding_weights + def test_decode(self, codebook): indices_flat = tensor([[0, 1]]) # (b, seq_len) indices_shaped = tensor([[[0, 1], [2, 3]]]) # (b, shape) - actual_quantized_flat = vq.lookup(indices_flat) - actual_quantized = vq.lookup(indices_shaped) + actual_quantized_flat = codebook.decode(indices_flat) + actual_quantized = codebook.decode(indices_shaped) expected_quantized_flat = tensor( [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]] ) diff --git a/torchtune/modules/vq_embeddings.py b/torchtune/modules/vq_embeddings.py index 12a4918f73..14d6cef995 100644 --- a/torchtune/modules/vq_embeddings.py +++ b/torchtune/modules/vq_embeddings.py @@ -17,8 +17,8 @@ class VectorQuantizedEmbeddings(nn.Module): and performs a nearest-neighbor lookup in the embedding space. Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf) to generate high-fidelity images, videos, and audio data. - The embedding weights are trained with exponential moving average updates as described - in original paper. + + This module currently does not support pre-training of the embeddings via EMA. Code was adapted from torchmultimodal's `Codebook module `_. @@ -26,99 +26,17 @@ class VectorQuantizedEmbeddings(nn.Module): Args: num_embeddings (int): Number of vectors in the embedding space. embedding_dim (int): Dimensionality of the embedding vectors. - decay (float, optional): Factor used in exponential moving average update of the embeddings. - Defaults to ``0.99``. - codebook_usage_threshold (float, optional): Threshold for the average number of times an embedding vector - is chosen below which it will be re-initialized. Defaults to ``1.0``. - learnable (bool): If True, register embedding weights, codebook usage, and codebook average to buffer - for EMA updates during training. If False, only register embedding weights as an nn.Parameter, for use - in a frozen module. Default is False. - epsilon (float, optional): Noise used in Laplace smoothing of codebook usage. Defaults to ``1e-7``. 
""" def __init__( self, num_embeddings: int, embedding_dim: int, - decay: float = 0.99, - codebook_usage_threshold: float = 1.0, - learnable: bool = False, - epsilon: float = 1e-7, ) -> None: super().__init__() - # Embedding weights and parameters for EMA update will be registered to buffer, as they - # will not be updated by the optimizer but are still model parameters. - # code_usage and code_avg correspond with N and m, respectively, from Oord et al. - randn_init_embedding = torch.randn(num_embeddings, embedding_dim) - self.register_buffer("embedding", randn_init_embedding.clone()) - if learnable: - self.register_buffer("code_usage", torch.zeros(num_embeddings)) - self.register_buffer("code_avg", randn_init_embedding.clone()) - - self.embedding_dim = embedding_dim + self.embedding = nn.Parameter(torch.empty(num_embeddings, embedding_dim)) self.num_embeddings = num_embeddings - self.learnable = learnable - - self._decay = decay - # Used in Laplace smoothing of code usage - self._epsilon = epsilon - - # Threshold for randomly reseting unused embedding vectors - self.codebook_usage_threshold = codebook_usage_threshold - - def _tile(self, x: Tensor, n: int) -> Tensor: - # Repeat vectors in x if x has less than n vectors - num_vectors, num_channels = x.shape - if num_vectors < n: - num_repeats = (n + num_vectors - 1) // num_vectors - # Add a small amount of noise to repeated vectors - std = 0.01 / torch.sqrt(torch.tensor(num_channels)) - x = x.repeat(num_repeats, 1) - x = x + torch.randn_like(x) * std - return x - - def _get_random_vectors(self, x: Tensor, n: int) -> Tensor: - # Gets n random row vectors from 2D tensor x - x_tiled = self._tile(x, n) - idx = torch.randperm(x_tiled.shape[0]) - x_rand = x_tiled[idx][:n] - return x_rand - - def _ema_update_embedding(self, z: Tensor, codebook_indices: Tensor) -> None: - # Closed form solution of codebook loss, ||e - E(x)||^2, is simply the average - # of the encoder output. However, we can't compute this in minibatches, so we - # must use exponential moving average. 
- - # Convert indices to one hot encoding - codebook_onehot = nn.functional.one_hot( - codebook_indices, num_classes=self.num_embeddings - ).type(torch.float) - # Count how often each embedding vector was looked up - codebook_selection_count = torch.sum(codebook_onehot, 0) - # Update usage value for each embedding vector - self.code_usage.mul_(self._decay).add_( - codebook_selection_count, alpha=(1 - self._decay) - ) - # Laplace smoothing of codebook usage - to prevent zero counts - n = torch.sum(self.code_usage) - self.code_usage.add_(self._epsilon).divide_( - n + self.num_embeddings * self._epsilon - ).mul_(n) - # Get all encoded vectors attracted to each embedding vector - encoded_per_codebook = torch.matmul(codebook_onehot.t(), z) - # Update each embedding vector with new encoded vectors that are attracted to it, - # divided by its usage to yield the mean of encoded vectors that choose it - self.code_avg.mul_(self._decay).add_( - encoded_per_codebook, alpha=(1 - self._decay) - ) - self.embedding = self.code_avg / self.code_usage.unsqueeze(1) - # Reset any embedding vectors that fall below threshold usage with random encoded vectors - z_rand = self._get_random_vectors(z, self.num_embeddings) - self.embedding = torch.where( - self.code_usage.unsqueeze(1) >= self.codebook_usage_threshold, - self.embedding, - z_rand, - ) + self.embedding_dim = embedding_dim def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]: """ @@ -147,12 +65,7 @@ def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]: token_ids_flat = torch.argmin(distances, dim=1) # Quantize - shape [b * s, d] - quantized_flat = self.lookup(token_ids_flat) - - # Use exponential moving average to update the embedding instead of a codebook loss, - # as suggested by Oord et al. 2017 and Razavi et al. 2019. - if self.training and self.learnable: - self._ema_update_embedding(z_flat, token_ids_flat) + quantized_flat = self.decode(token_ids_flat) # Straight through estimator quantized_flat = z_flat + (quantized_flat - z_flat).detach() @@ -168,6 +81,6 @@ def extra_repr(self) -> str: self.num_embeddings, self.embedding_dim ) - def lookup(self, token_ids: Tensor) -> Tensor: + def decode(self, token_ids: Tensor) -> Tensor: # Returns the embeddings of shape [b, s, d] return F.embedding(token_ids, self.embedding)
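
A minimal usage sketch of the module as it lands after PATCH 4/4, for review context only (not part of the diff). The codebook sizes and the random fill of the embedding weights below are illustrative assumptions; the patch creates the codebook with torch.empty, so in practice the weights would come from a pretrained checkpoint.

    import torch
    from torchtune.modules import VectorQuantizedEmbeddings

    # Hypothetical sizes, chosen only for illustration.
    vq = VectorQuantizedEmbeddings(num_embeddings=1024, embedding_dim=512)

    # The module initializes self.embedding with torch.empty; fill it here just so
    # the example runs. Real usage would load pretrained codebook weights instead.
    with torch.no_grad():
        vq.embedding.normal_()

    encoder_out = torch.randn(2, 16, 512)   # (batch, seq_len, embedding_dim)
    quantized, token_ids = vq(encoder_out)  # shapes (2, 16, 512) and (2, 16)
    decoded = vq.decode(token_ids)          # matches quantized numerically; only the gradient path differs
    assert decoded.shape == encoder_out.shape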