diff --git a/tests/torchtune/modules/test_vq_embeddings.py b/tests/torchtune/modules/test_vq_embeddings.py
new file mode 100644
index 0000000000..b8c1e83286
--- /dev/null
+++ b/tests/torchtune/modules/test_vq_embeddings.py
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+import torch
+from tests.test_utils import assert_expected
+from torch import tensor
+from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings
+
+
+@pytest.fixture(autouse=True)
+def random_seed():
+    torch.manual_seed(4)
+
+
+class TestVectorQuantizedEmbeddings:
+    @pytest.fixture
+    def num_embeddings(self):
+        return 4
+
+    @pytest.fixture
+    def embedding_dim(self):
+        return 5
+
+    @pytest.fixture
+    def embedding_weights(self):
+        # This is 4x5
+        return tensor(
+            [
+                [1.0, 0.0, -1.0, -1.0, 2.0],
+                [2.0, -2.0, 0.0, 0.0, 1.0],
+                [2.0, 1.0, 0.0, 1.0, 1.0],
+                [-1.0, -2.0, 0.0, 2.0, 0.0],
+            ]
+        )
+
+    @pytest.fixture
+    def codebook(self, num_embeddings, embedding_dim, embedding_weights):
+        vq = VectorQuantizedEmbeddings(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+        )
+        vq.embedding.data = embedding_weights
+        return vq
+
+    @pytest.fixture
+    def encoded(self):
+        # This is 2x3x5
+        encoded = tensor(
+            [
+                [
+                    [-1.0, 2.0, 0.0, 0.0, -2.0],
+                    [0.0, 1.0, -1.0, 2.0, -1.0],
+                    [1.0, 0.0, -1.0, -1.0, 1.0],
+                ],
+                [
+                    [2.0, 1.0, 0.0, 1.0, 1.0],
+                    [2.0, -1.0, 0.0, 2.0, 0.0],
+                    [-1.0, -2.0, 0.0, 1.0, 0.0],
+                ],
+            ]
+        )
+        encoded.requires_grad_()
+
+        return encoded
+
+    def test_quantized_output(self, codebook, encoded):
+        actual = codebook(encoded)
+
+        expected_quantized = tensor(
+            [
+                [
+                    [2.0, 1.0, 0.0, 1.0, 1.0],
+                    [2.0, 1.0, 0.0, 1.0, 1.0],
+                    [1.0, 0.0, -1.0, -1.0, 2.0],
+                ],
+                [
+                    [2.0, 1.0, 0.0, 1.0, 1.0],
+                    [2.0, -2.0, 0.0, 0.0, 1.0],
+                    [-1.0, -2.0, 0.0, 2.0, 0.0],
+                ],
+            ]
+        )
+        expected_token_ids = tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type(
+            torch.LongTensor
+        )
+
+        assert_expected(actual[0], expected_quantized)
+        assert_expected(actual[1], expected_token_ids)
+
+    def test_decode(self, codebook):
+        indices_flat = tensor([[0, 1]])  # (b, seq_len)
+        indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
+        actual_quantized_flat = codebook.decode(indices_flat)
+        actual_quantized = codebook.decode(indices_shaped)
+        expected_quantized_flat = tensor(
+            [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
+        )
+        expected_quantized = tensor(
+            [
+                [
+                    [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
+                    [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
+                ]
+            ]
+        )
+        assert_expected(
+            actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
+        )
+        assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)
diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py
index 29c014c33e..338c9ef53e 100644
--- a/torchtune/modules/__init__.py
+++ b/torchtune/modules/__init__.py
@@ -29,6 +29,7 @@
     TransformerSelfAttentionLayer,
 )
 from .vision_transformer import VisionTransformer
+from .vq_embeddings import VectorQuantizedEmbeddings
 
 __all__ = [
     "MultiHeadAttention",
@@ -38,6 +39,7 @@
     "KVCache",
     "RotaryPositionalEmbeddings",
     "VisionRotaryPositionalEmbeddings",
+    "VectorQuantizedEmbeddings",
     "RMSNorm",
     "TiedLinear",
     "Fp32LayerNorm",
diff --git a/torchtune/modules/vq_embeddings.py b/torchtune/modules/vq_embeddings.py
new file mode 100644
index 0000000000..14d6cef995
--- /dev/null
+++ b/torchtune/modules/vq_embeddings.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+
+class VectorQuantizedEmbeddings(nn.Module):
+    """
+    Vector quantized embedding layer that takes in the output of an encoder
+    and performs a nearest-neighbor lookup in the embedding space.
+    Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
+    to generate high-fidelity images, videos, and audio data.
+
+    This module currently does not support pre-training of the embeddings via EMA.
+
+    Code was adapted from torchmultimodal's `Codebook module
+    `_.
+
+    Args:
+        num_embeddings (int): Number of vectors in the embedding space.
+        embedding_dim (int): Dimensionality of the embedding vectors.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+    ) -> None:
+        super().__init__()
+        self.embedding = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+
+    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            z (Tensor): Tensor containing a batch of encoder outputs of shape ``(b, s, d)``, where
+                b is batch size, s is sequence length or time, and d is ``embedding_dim``.
+
+        Returns:
+            Tuple[Tensor, Tensor]: The quantized input and the embedding vector ids that were used.
+
+        Raises:
+            ValueError: if input embedding dimension does not match embedding dimension of module
+        """
+        bsz, seq_len, z_embed_dim = z.shape
+        if z_embed_dim != self.embedding_dim:
+            raise ValueError(
+                f"Expected last dimension of input tensor ({z_embed_dim}) to be embedding size of {self.embedding_dim}"
+            )
+
+        # Flatten into batch dimension
+        z_flat = z.view(-1, z_embed_dim)
+        # Calculate distances from each encoder, E(x), output vector to each embedding vector, e, ||E(x) - e||^2
+        distances = torch.cdist(z_flat, self.embedding, p=2.0) ** 2
+
+        # Encoding - select closest embedding vectors, shape [b * s, ]
+        token_ids_flat = torch.argmin(distances, dim=1)
+
+        # Quantize - shape [b * s, d]
+        quantized_flat = self.decode(token_ids_flat)
+
+        # Straight through estimator
+        quantized_flat = z_flat + (quantized_flat - z_flat).detach()
+
+        # Reshape to original - [b, s, d] and [b, s]
+        quantized = quantized_flat.view(bsz, seq_len, z_embed_dim)
+        token_ids = token_ids_flat.view(bsz, seq_len)
+
+        return quantized, token_ids
+
+    def extra_repr(self) -> str:
+        return "num_embeddings={}, embedding_dim={}".format(
+            self.num_embeddings, self.embedding_dim
+        )
+
+    def decode(self, token_ids: Tensor) -> Tensor:
+        # Returns the embeddings of shape [b, s, d]
+        return F.embedding(token_ids, self.embedding)
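
For reference, a minimal usage sketch of the new module (not part of the patch). Since the PR allocates the codebook with `torch.empty` and does not add EMA pre-training, the sketch assigns illustrative random weights directly, mirroring what the test fixture does; the shapes and values below are arbitrary examples.

```python
import torch

from torchtune.modules import VectorQuantizedEmbeddings

# Codebook with 4 vectors of dimensionality 5. The parameter is allocated
# uninitialized by the module, so example weights are assigned by hand here.
vq = VectorQuantizedEmbeddings(num_embeddings=4, embedding_dim=5)
vq.embedding.data = torch.randn(4, 5)

# Encoder output of shape (batch, seq_len, embedding_dim)
z = torch.randn(2, 3, 5, requires_grad=True)

# Nearest-neighbor lookup; gradients flow back to z via the straight-through estimator
quantized, token_ids = vq(z)
print(quantized.shape)  # torch.Size([2, 3, 5])
print(token_ids.shape)  # torch.Size([2, 3])

# Token ids can be mapped back to codebook vectors
decoded = vq.decode(token_ids)
print(decoded.shape)  # torch.Size([2, 3, 5])
```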