185 changes: 185 additions & 0 deletions tests/torchtune/modules/test_vq_embeddings.py
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected
from torch import tensor
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings


@pytest.fixture(autouse=True)
def random_seed():
torch.manual_seed(4)


class TestVectorQuantizedEmbeddings:
@pytest.fixture
def num_embeddings(self):
return 4

@pytest.fixture
def embedding_dim(self):
return 5

@pytest.fixture
def codebook(self, num_embeddings, embedding_dim):
def vq(learnable):
return VectorQuantizedEmbeddings(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
decay=0.3,
learnable=learnable,
)

return vq

@pytest.fixture
def encoded(self):
# This is 2x3x5
encoded = tensor(
[
[
[-1.0, 2.0, 0.0, 0.0, -2.0],
[0.0, 1.0, -1.0, 2.0, -1.0],
[1.0, 0.0, -1.0, -1.0, 1.0],
],
[
[2.0, 1.0, 0.0, 1.0, 1.0],
[2.0, -1.0, 0.0, 2.0, 0.0],
[-1.0, -2.0, 0.0, 1.0, 0.0],
],
]
)
encoded.requires_grad_()

return encoded

@pytest.fixture
def embedding_weights(self):
# This is 4x5
return tensor(
[
[1.0, 0.0, -1.0, -1.0, 2.0],
[2.0, -2.0, 0.0, 0.0, 1.0],
[2.0, 1.0, 0.0, 1.0, 1.0],
[-1.0, -2.0, 0.0, 2.0, 0.0],
]
)

def test_quantized_output(self, codebook, encoded, embedding_weights):
vq = codebook(learnable=False)
vq.embedding = embedding_weights
actual = vq(encoded)

expected_quantized = tensor(
[
[
[2.0, 1.0, 0.0, 1.0, 1.0],
[2.0, 1.0, 0.0, 1.0, 1.0],
[1.0, 0.0, -1.0, -1.0, 2.0],
],
[
[2.0, 1.0, 0.0, 1.0, 1.0],
[2.0, -2.0, 0.0, 0.0, 1.0],
[-1.0, -2.0, 0.0, 2.0, 0.0],
],
]
)
expected_token_ids = tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type(
torch.LongTensor
)

assert_expected(actual[0], expected_quantized)
assert_expected(actual[1], expected_token_ids)

def test_ema_update_embedding(self, codebook, encoded, embedding_weights):
vq = codebook(learnable=True)
vq.embedding = embedding_weights
encoded_flat = encoded.view(-1, encoded.shape[-1])
distances = torch.cdist(encoded_flat, vq.embedding, p=2.0) ** 2
codebook_indices = torch.argmin(distances, dim=1)
vq._ema_update_embedding(encoded_flat, codebook_indices)

actual_weight = vq.embedding
expected_weight = tensor(
[
[2.0000, -1.0000, 0.0000, 2.0000, 0.0000],
[2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
[0.5647, 1.3760, -0.3936, 1.1213, -0.7635],
[1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
]
)
assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)

actual_code_avg = vq.code_avg
expected_code_avg = tensor(
[
[0.4176, 0.3790, -0.7551, -0.6548, 1.0419],
[1.3309, -0.3437, 0.2303, 1.1865, 0.1305],
[1.1859, 2.8897, -0.8265, 2.3547, -1.6033],
[-0.9834, -0.7490, -0.3521, 0.5825, 0.4301],
]
)
assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)

actual_code_usage = vq.code_usage
expected_code_usage = tensor([0.7000, 0.7000, 2.1000, 0.7000])
assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)

def test_codebook_restart(self, codebook, encoded, embedding_weights):
vq = codebook(learnable=True)
vq.embedding = embedding_weights
# Use only embedding vector at index = 1 and force restarts.
# Slightly modify encoded_flat to make sure vectors restart to something new
encoded_flat = encoded.view(-1, encoded.shape[-1])
encoded_noise = encoded_flat + torch.randn_like(encoded_flat)
codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long)
vq._ema_update_embedding(encoded_noise, codebook_indices_low_usage)

# Check if embedding contains restarts
for i, emb in enumerate(vq.embedding):
# We used only emb vector with index = 1, so check it was not restarted
if i == 1:
assert_expected(
emb,
vq.code_avg[1] / vq.code_usage[1],
rtol=0,
atol=1e-4,
)
# Compare each embedding vector to each encoded vector.
# If at least one matches, then a restart happened.
else:
assert any(
[
torch.isclose(emb, enc, rtol=0, atol=1e-4).all()
for enc in encoded_noise
]
), "embedding restarted from encoder output incorrectly"

def test_lookup(self, codebook, embedding_weights):
vq = codebook(learnable=False)
vq.embedding = embedding_weights
indices_flat = tensor([[0, 1]]) # (b, seq_len)
indices_shaped = tensor([[[0, 1], [2, 3]]]) # (b, shape)
actual_quantized_flat = vq.lookup(indices_flat)
actual_quantized = vq.lookup(indices_shaped)
expected_quantized_flat = tensor(
[[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
)
expected_quantized = tensor(
[
[
[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
[[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
]
]
)
assert_expected(
actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
)
assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)
2 changes: 2 additions & 0 deletions torchtune/modules/__init__.py
@@ -30,6 +30,7 @@
TransformerSelfAttentionLayer,
)
from .vision_transformer import VisionTransformer
from .vq_embeddings import VectorQuantizedEmbeddings

__all__ = [
"MultiHeadAttention",
@@ -39,6 +40,7 @@
"KVCache",
"RotaryPositionalEmbeddings",
"VisionRotaryPositionalEmbeddings",
"VectorQuantizedEmbeddings",
"RMSNorm",
"TiedLinear",
"Fp32LayerNorm",
173 changes: 173 additions & 0 deletions torchtune/modules/vq_embeddings.py
@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class VectorQuantizedEmbeddings(nn.Module):
"""
Vector quantized embedding layer that takes in the output of an encoder
and performs a nearest-neighbor lookup in the embedding space.
Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
to generate high-fidelity images, videos, and audio data.
The embedding weights are trained with exponential moving average updates as described
in the original paper.

Code was adapted from torchmultimodal's `Codebook module
<https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/modules/layers/codebook.py>`_.

Args:
num_embeddings (int): Number of vectors in the embedding space.
embedding_dim (int): Dimensionality of the embedding vectors.
decay (float, optional): Factor used in exponential moving average update of the embeddings.
Defaults to ``0.99``.
codebook_usage_threshold (float, optional): Threshold on the running usage count of an embedding
vector; vectors whose usage falls below it are re-initialized with random encoder outputs. Defaults to ``1.0``.
learnable (bool): If True, additionally register codebook usage and codebook average buffers and
update the embedding weights with EMA during training. If False, only the embedding weights are
registered (as a buffer), for use in a frozen module. Default is False.
epsilon (float, optional): Noise used in Laplace smoothing of codebook usage. Defaults to ``1e-7``.
"""

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
decay: float = 0.99,
codebook_usage_threshold: float = 1.0,
learnable: bool = False,
epsilon: float = 1e-7,
) -> None:
super().__init__()
# Embedding weights and the EMA-update statistics are registered as buffers, since they
# are not updated by the optimizer but are still part of the model state.
# code_usage and code_avg correspond to N and m, respectively, from Oord et al.
randn_init_embedding = torch.randn(num_embeddings, embedding_dim)
self.register_buffer("embedding", randn_init_embedding.clone())
if learnable:
Contributor:
I don't really like this variable because it could disagree with the nn.Module.training attribute. Is this only needed for EMA? Why wouldn't the EMA code handle this?

self.register_buffer("code_usage", torch.zeros(num_embeddings))
self.register_buffer("code_avg", randn_init_embedding.clone())

self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.learnable = learnable

self._decay = decay
# Used in Laplace smoothing of code usage
self._epsilon = epsilon

# Threshold for randomly resetting unused embedding vectors
self.codebook_usage_threshold = codebook_usage_threshold

def _tile(self, x: Tensor, n: int) -> Tensor:
# Repeat vectors in x if x has fewer than n vectors
num_vectors, num_channels = x.shape
if num_vectors < n:
num_repeats = (n + num_vectors - 1) // num_vectors
# Add a small amount of noise to repeated vectors
std = 0.01 / torch.sqrt(torch.tensor(num_channels))
x = x.repeat(num_repeats, 1)
x = x + torch.randn_like(x) * std
return x

def _get_random_vectors(self, x: Tensor, n: int) -> Tensor:
# Gets n random row vectors from 2D tensor x
x_tiled = self._tile(x, n)
idx = torch.randperm(x_tiled.shape[0])
x_rand = x_tiled[idx][:n]
return x_rand

def _ema_update_embedding(self, z: Tensor, codebook_indices: Tensor) -> None:
# The closed-form solution of the codebook loss, ||e - E(x)||^2, is simply the average
# of the encoder outputs assigned to each codebook vector. However, we can't compute
# this over minibatches, so we use an exponential moving average instead.
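# Concretely, with decay d, batch assignment counts n_i for code i, and encoder outputs z_j
# assigned to code i:
#   N_i <- d * N_i + (1 - d) * n_i          (code_usage)
#   m_i <- d * m_i + (1 - d) * sum_j z_j    (code_avg)
#   e_i  = m_i / N_i                        (embedding)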

# Convert indices to one hot encoding
codebook_onehot = nn.functional.one_hot(
codebook_indices, num_classes=self.num_embeddings
).type(torch.float)
# Count how often each embedding vector was looked up
codebook_selection_count = torch.sum(codebook_onehot, 0)
# Update usage value for each embedding vector
self.code_usage.mul_(self._decay).add_(
codebook_selection_count, alpha=(1 - self._decay)
)
# Laplace smoothing of codebook usage - to prevent zero counts
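#   N_i <- (N_i + eps) * n / (n + K * eps), with n = sum_i N_i and K = num_embeddings,
# so that no usage count is exactly zero when code_avg is divided by code_usage below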
n = torch.sum(self.code_usage)
self.code_usage.add_(self._epsilon).divide_(
n + self.num_embeddings * self._epsilon
).mul_(n)
# Get all encoded vectors attracted to each embedding vector
encoded_per_codebook = torch.matmul(codebook_onehot.t(), z)
# Update each embedding vector with new encoded vectors that are attracted to it,
# divided by its usage to yield the mean of encoded vectors that choose it
self.code_avg.mul_(self._decay).add_(
encoded_per_codebook, alpha=(1 - self._decay)
)
self.embedding = self.code_avg / self.code_usage.unsqueeze(1)
# Reset any embedding vectors whose usage falls below the threshold to random encoded vectors
z_rand = self._get_random_vectors(z, self.num_embeddings)
self.embedding = torch.where(
self.code_usage.unsqueeze(1) >= self.codebook_usage_threshold,
self.embedding,
z_rand,
)

def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
z (Tensor): Tensor containing a batch of encoder outputs of shape ``(b, s, d)``, where
b is batch size, s is sequence length or time, and d is ``embedding_dim``.

Returns:
Tuple[Tensor, Tensor]: The quantized input and the embedding vector ids that were used.

Raises:
ValueError: if the input embedding dimension does not match the embedding dimension of the module
"""
bsz, seq_len, z_embed_dim = z.shape
if z_embed_dim != self.embedding_dim:
raise ValueError(
f"Expected last dimension of input tensor ({z_embed_dim}) to be embedding size of {self.embedding_dim}"
)

# Flatten into batch dimension
z_flat = z.view(-1, z_embed_dim)
# Calculate distances from each encoder output vector, E(x), to each embedding vector, e: ||E(x) - e||^2
distances = torch.cdist(z_flat, self.embedding, p=2.0) ** 2

# Encoding - select closest embedding vectors, shape [b * s, ]
token_ids_flat = torch.argmin(distances, dim=1)

# Quantize - shape [b * s, d]
quantized_flat = self.lookup(token_ids_flat)

# Use exponential moving average to update the embedding instead of a codebook loss,
# as suggested by Oord et al. 2017 and Razavi et al. 2019.
if self.training and self.learnable:
self._ema_update_embedding(z_flat, token_ids_flat)

# Straight through estimator
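# The next line keeps the forward value of quantized_flat (the nearest codebook vectors)
# while making its gradient w.r.t. z_flat the identity, so encoder gradients bypass the
# non-differentiable argmin above.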
quantized_flat = z_flat + (quantized_flat - z_flat).detach()

# Reshape to original - [b, s, d] and [b, s]
quantized = quantized_flat.view(bsz, seq_len, z_embed_dim)
token_ids = token_ids_flat.view(bsz, seq_len)

return quantized, token_ids

def extra_repr(self) -> str:
return "num_embeddings={}, embedding_dim={}".format(
self.num_embeddings, self.embedding_dim
)

def lookup(self, token_ids: Tensor) -> Tensor:
Contributor:
nit: do we want to use "lookup" here or something more consistent with the rest of the library like decode?

# Returns the embeddings of shape [b, s, d]
return F.embedding(token_ids, self.embedding)
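
A minimal usage sketch (illustrative, not part of this diff), assuming the export added in torchtune/modules/__init__.py above. The encoder output and the downstream loss are hypothetical stand-ins, used only to show that gradients reach the encoder through the straight-through estimator while the codebook itself is updated by EMA in training mode.

import torch
from torchtune.modules import VectorQuantizedEmbeddings

vq = VectorQuantizedEmbeddings(
    num_embeddings=4,
    embedding_dim=5,
    decay=0.99,
    learnable=True,
)
vq.train()  # EMA codebook updates only run when self.training and learnable are both True

# Hypothetical encoder output of shape (batch, seq_len, embedding_dim)
encoder_out = torch.randn(2, 3, 5, requires_grad=True)
quantized, token_ids = vq(encoder_out)  # shapes (2, 3, 5) and (2, 3)

# Any loss computed on `quantized` (here a stand-in for a decoder/reconstruction loss)
# backpropagates to encoder_out through the straight-through estimator, since the
# codebook itself receives no gradient.
loss = quantized.pow(2).mean()
loss.backward()
assert encoder_out.grad is not None

# Token ids can be mapped back to embedding vectors with lookup()
decoded = vq.lookup(token_ids)  # (2, 3, 5)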