Vector Quantized Embeddings #2040
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected
from torch import tensor
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings


@pytest.fixture(autouse=True)
def random_seed():
    torch.manual_seed(4)


class TestVectorQuantizedEmbeddings:
    @pytest.fixture
    def num_embeddings(self):
        return 4

    @pytest.fixture
    def embedding_dim(self):
        return 5

    @pytest.fixture
    def codebook(self, num_embeddings, embedding_dim):
        def vq(learnable):
            return VectorQuantizedEmbeddings(
                num_embeddings=num_embeddings,
                embedding_dim=embedding_dim,
                decay=0.3,
                learnable=learnable,
            )

        return vq

    @pytest.fixture
    def encoded(self):
        # This is 2x3x5
        encoded = tensor(
            [
                [
                    [-1.0, 2.0, 0.0, 0.0, -2.0],
                    [0.0, 1.0, -1.0, 2.0, -1.0],
                    [1.0, 0.0, -1.0, -1.0, 1.0],
                ],
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, -1.0, 0.0, 2.0, 0.0],
                    [-1.0, -2.0, 0.0, 1.0, 0.0],
                ],
            ]
        )
        encoded.requires_grad_()

        return encoded

    @pytest.fixture
    def embedding_weights(self):
        # This is 4x5
        return tensor(
            [
                [1.0, 0.0, -1.0, -1.0, 2.0],
                [2.0, -2.0, 0.0, 0.0, 1.0],
                [2.0, 1.0, 0.0, 1.0, 1.0],
                [-1.0, -2.0, 0.0, 2.0, 0.0],
            ]
        )

    def test_quantized_output(self, codebook, encoded, embedding_weights):
        vq = codebook(learnable=False)
        vq.embedding = embedding_weights
        actual = vq(encoded)

        expected_quantized = tensor(
            [
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [1.0, 0.0, -1.0, -1.0, 2.0],
                ],
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, -2.0, 0.0, 0.0, 1.0],
                    [-1.0, -2.0, 0.0, 2.0, 0.0],
                ],
            ]
        )
        expected_token_ids = tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type(
            torch.LongTensor
        )

        assert_expected(actual[0], expected_quantized)
        assert_expected(actual[1], expected_token_ids)

    def test_ema_update_embedding(self, codebook, encoded, embedding_weights):
        vq = codebook(learnable=True)
        vq.embedding = embedding_weights
        encoded_flat = encoded.view(-1, encoded.shape[-1])
        distances = torch.cdist(encoded_flat, vq.embedding, p=2.0) ** 2
        codebook_indices = torch.argmin(distances, dim=1)
        vq._ema_update_embedding(encoded_flat, codebook_indices)

        actual_weight = vq.embedding
        expected_weight = tensor(
            [
                [2.0000, -1.0000, 0.0000, 2.0000, 0.0000],
                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
                [0.5647, 1.3760, -0.3936, 1.1213, -0.7635],
                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
            ]
        )
        assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)

        actual_code_avg = vq.code_avg
        expected_code_avg = tensor(
            [
                [0.4176, 0.3790, -0.7551, -0.6548, 1.0419],
                [1.3309, -0.3437, 0.2303, 1.1865, 0.1305],
                [1.1859, 2.8897, -0.8265, 2.3547, -1.6033],
                [-0.9834, -0.7490, -0.3521, 0.5825, 0.4301],
            ]
        )
        assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)

        actual_code_usage = vq.code_usage
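        # With decay = 0.3 and usage starting at zero, the EMA reduces to 0.7 * counts;
        # the fixture's nearest-neighbor counts are [1, 1, 3, 1], giving [0.7, 0.7, 2.1, 0.7].
        # Laplace smoothing with epsilon = 1e-7 changes this only negligibly.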
        expected_code_usage = tensor([0.7000, 0.7000, 2.1000, 0.7000])
        assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)

    def test_codebook_restart(self, codebook, encoded, embedding_weights):
        vq = codebook(learnable=True)
        vq.embedding = embedding_weights
        # Use only the embedding vector at index = 1 and force restarts.
        # Slightly modify encoded_flat to make sure vectors restart to something new.
        encoded_flat = encoded.view(-1, encoded.shape[-1])
        encoded_noise = encoded_flat + torch.randn_like(encoded_flat)
        codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long)
        vq._ema_update_embedding(encoded_noise, codebook_indices_low_usage)

        # Check if the embedding contains restarts
        for i, emb in enumerate(vq.embedding):
            # We used only the embedding vector with index = 1, so check it was not restarted
            if i == 1:
                assert_expected(
                    emb,
                    vq.code_avg[1] / vq.code_usage[1],
                    rtol=0,
                    atol=1e-4,
                )
            # Compare each embedding vector to each encoded vector.
            # If at least one matches, then a restart happened.
            else:
                assert any(
                    [
                        torch.isclose(emb, enc, rtol=0, atol=1e-4).all()
                        for enc in encoded_noise
                    ]
                ), "embedding restarted from encoder output incorrectly"

    def test_lookup(self, codebook, embedding_weights):
        vq = codebook(learnable=False)
        vq.embedding = embedding_weights
        indices_flat = tensor([[0, 1]])  # (b, seq_len)
        indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
        actual_quantized_flat = vq.lookup(indices_flat)
        actual_quantized = vq.lookup(indices_shaped)
        expected_quantized_flat = tensor(
            [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
        )
        expected_quantized = tensor(
            [
                [
                    [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
                    [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
                ]
            ]
        )
        assert_expected(
            actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
        )
        assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)
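The expected token ids in test_quantized_output follow directly from the squared Euclidean distances between each encoder output and the fixture codebook. A minimal sketch of the assignment for the first encoded vector (variable names here are illustrative, not part of the change):

import torch

# Fixture values from the test above: the first encoder output and the 4x5 codebook
encoded_vec = torch.tensor([-1.0, 2.0, 0.0, 0.0, -2.0])
codebook = torch.tensor(
    [
        [1.0, 0.0, -1.0, -1.0, 2.0],
        [2.0, -2.0, 0.0, 0.0, 1.0],
        [2.0, 1.0, 0.0, 1.0, 1.0],
        [-1.0, -2.0, 0.0, 2.0, 0.0],
    ]
)

# Squared L2 distances to the four codebook vectors: [26., 34., 20., 24.]
distances = torch.cdist(encoded_vec.unsqueeze(0), codebook, p=2.0) ** 2
# Index 2 is the nearest code, matching expected_token_ids[0][0] and the
# first row of expected_quantized
assert torch.argmin(distances, dim=1).item() == 2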
@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class VectorQuantizedEmbeddings(nn.Module):
    """
    Vector quantized embedding layer that takes in the output of an encoder
    and performs a nearest-neighbor lookup in the embedding space.
    Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
    to generate high-fidelity images, videos, and audio data.
    The embedding weights are trained with exponential moving average updates as described
    in the original paper.

    Code was adapted from torchmultimodal's `Codebook module
    <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/modules/layers/codebook.py>`_.

    Args:
        num_embeddings (int): Number of vectors in the embedding space.
        embedding_dim (int): Dimensionality of the embedding vectors.
        decay (float, optional): Factor used in the exponential moving average update of the embeddings.
            Defaults to ``0.99``.
        codebook_usage_threshold (float, optional): Threshold for the average number of times an embedding vector
            is chosen below which it will be re-initialized. Defaults to ``1.0``.
        learnable (bool): If True, additionally register codebook usage and codebook average buffers
            for EMA updates during training. If False, only the embedding weights are registered, for use
            in a frozen module. Default is False.
        epsilon (float, optional): Noise used in Laplace smoothing of codebook usage. Defaults to ``1e-7``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        decay: float = 0.99,
        codebook_usage_threshold: float = 1.0,
        learnable: bool = False,
        epsilon: float = 1e-7,
    ) -> None:
        super().__init__()
        # Embedding weights and EMA statistics are registered as buffers, as they
        # are not updated by the optimizer but are still part of the model state.
        # code_usage and code_avg correspond to N and m, respectively, from Oord et al.
        randn_init_embedding = torch.randn(num_embeddings, embedding_dim)
        self.register_buffer("embedding", randn_init_embedding.clone())
        if learnable:
            self.register_buffer("code_usage", torch.zeros(num_embeddings))
            self.register_buffer("code_avg", randn_init_embedding.clone())

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.learnable = learnable

        self._decay = decay
        # Used in Laplace smoothing of code usage
        self._epsilon = epsilon

        # Threshold for randomly resetting unused embedding vectors
        self.codebook_usage_threshold = codebook_usage_threshold

    def _tile(self, x: Tensor, n: int) -> Tensor:
        # Repeat vectors in x if x has fewer than n vectors
        num_vectors, num_channels = x.shape
        if num_vectors < n:
            num_repeats = (n + num_vectors - 1) // num_vectors
            # Add a small amount of noise to repeated vectors
            std = 0.01 / torch.sqrt(torch.tensor(num_channels))
            x = x.repeat(num_repeats, 1)
            x = x + torch.randn_like(x) * std
        return x

    def _get_random_vectors(self, x: Tensor, n: int) -> Tensor:
        # Gets n random row vectors from the 2D tensor x
        x_tiled = self._tile(x, n)
        idx = torch.randperm(x_tiled.shape[0])
        x_rand = x_tiled[idx][:n]
        return x_rand

    def _ema_update_embedding(self, z: Tensor, codebook_indices: Tensor) -> None:
        # The closed-form solution of the codebook loss, ||e - E(x)||^2, is simply the average
        # of the encoder outputs assigned to each code. However, we can't compute this over
        # minibatches, so we must use an exponential moving average.
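        # In symbols, with decay d, per-code assignment counts n_i, and the sum of the
        # encoder outputs assigned to code i written as sum_j z_{i,j}:
        #   code_usage_i <- d * code_usage_i + (1 - d) * n_i
        #   code_avg_i   <- d * code_avg_i   + (1 - d) * sum_j z_{i,j}
        #   embedding_i   = code_avg_i / code_usage_i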

        # Convert indices to a one-hot encoding
        codebook_onehot = nn.functional.one_hot(
            codebook_indices, num_classes=self.num_embeddings
        ).type(torch.float)
        # Count how often each embedding vector was looked up
        codebook_selection_count = torch.sum(codebook_onehot, 0)
        # Update the usage value for each embedding vector
        self.code_usage.mul_(self._decay).add_(
            codebook_selection_count, alpha=(1 - self._decay)
        )
        # Laplace smoothing of codebook usage - to prevent zero counts
        n = torch.sum(self.code_usage)
        self.code_usage.add_(self._epsilon).divide_(
            n + self.num_embeddings * self._epsilon
        ).mul_(n)
        # Get all encoded vectors attracted to each embedding vector
        encoded_per_codebook = torch.matmul(codebook_onehot.t(), z)
        # Update each embedding vector with the new encoded vectors attracted to it,
        # divided by its usage to yield the mean of the encoded vectors that chose it
        self.code_avg.mul_(self._decay).add_(
            encoded_per_codebook, alpha=(1 - self._decay)
        )
        self.embedding = self.code_avg / self.code_usage.unsqueeze(1)
        # Reset any embedding vectors whose usage falls below the threshold to random encoded vectors
        z_rand = self._get_random_vectors(z, self.num_embeddings)
        self.embedding = torch.where(
            self.code_usage.unsqueeze(1) >= self.codebook_usage_threshold,
            self.embedding,
            z_rand,
        )

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            z (Tensor): Tensor containing a batch of encoder outputs of shape ``(b, s, d)``, where
                b is the batch size, s is the sequence length or time, and d is ``embedding_dim``.

        Returns:
            Tuple[Tensor, Tensor]: The quantized input and the embedding vector ids that were used.

        Raises:
            ValueError: if the input embedding dimension does not match the embedding dimension of the module
        """
        bsz, seq_len, z_embed_dim = z.shape
        if z_embed_dim != self.embedding_dim:
            raise ValueError(
                f"Expected last dimension of input tensor ({z_embed_dim}) to be embedding size of {self.embedding_dim}"
            )

        # Flatten into the batch dimension
        z_flat = z.view(-1, z_embed_dim)
        # Calculate distances from each encoder output vector, E(x), to each embedding vector, e: ||E(x) - e||^2
        distances = torch.cdist(z_flat, self.embedding, p=2.0) ** 2

        # Encoding - select the closest embedding vectors, shape [b * s, ]
        token_ids_flat = torch.argmin(distances, dim=1)

        # Quantize - shape [b * s, d]
        quantized_flat = self.lookup(token_ids_flat)

        # Use an exponential moving average to update the embedding instead of a codebook loss,
        # as suggested by Oord et al. 2017 and Razavi et al. 2019.
        if self.training and self.learnable:
            self._ema_update_embedding(z_flat, token_ids_flat)

        # Straight-through estimator
        quantized_flat = z_flat + (quantized_flat - z_flat).detach()

        # Reshape to original - [b, s, d] and [b, s]
        quantized = quantized_flat.view(bsz, seq_len, z_embed_dim)
        token_ids = token_ids_flat.view(bsz, seq_len)

        return quantized, token_ids

    def extra_repr(self) -> str:
        return "num_embeddings={}, embedding_dim={}".format(
            self.num_embeddings, self.embedding_dim
        )

    def lookup(self, token_ids: Tensor) -> Tensor:
        # Returns the embeddings of shape [b, s, d]
        return F.embedding(token_ids, self.embedding)
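For context, a minimal usage sketch of the new module (the codebook size, embedding dimension, and variable names below are illustrative, not taken from this PR):

import torch
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings

vq = VectorQuantizedEmbeddings(
    num_embeddings=512,   # illustrative codebook size
    embedding_dim=64,     # must match the encoder output dimension
    decay=0.99,
    learnable=True,       # enable EMA codebook updates while training
)
vq.train()

encoder_out = torch.randn(2, 10, 64)    # (batch, seq_len, embedding_dim)
quantized, token_ids = vq(encoder_out)  # shapes (2, 10, 64) and (2, 10)

# The codebook can also be queried directly by token id, e.g. at inference time
vq.eval()
embeddings = vq.lookup(token_ids)       # shape (2, 10, 64)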