Skip to content

Commit 279c5e9

Browse files
Merge pull request #5 from ritheshkumar95/tristan/vq-function
Vector Quantization as a Pytorch Function
2 parents cc9c562 + 7cf7fa1 commit 279c5e9

File tree

7 files changed

+167
-27
lines changed

7 files changed

+167
-27
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ nosetests.xml
4545
coverage.xml
4646
*.cover
4747
.hypothesis/
48+
.pytest_cache/
4849

4950
# Translations
5051
*.mo

functions.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import torch
2+
from torch.autograd import Function
3+
4+
class VectorQuantization(Function):
5+
@staticmethod
6+
def forward(ctx, inputs, codebook):
7+
with torch.no_grad():
8+
embedding_size = codebook.size(1)
9+
inputs_size = inputs.size()
10+
inputs_flatten = inputs.view(-1, embedding_size)
11+
12+
codebook_sqr = torch.sum(codebook ** 2, dim=1)
13+
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
14+
15+
# Compute the distances to the codebook
16+
distances = torch.addmm(codebook_sqr + inputs_sqr,
17+
inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
18+
19+
_, indices_flatten = torch.min(distances, dim=1)
20+
indices = indices_flatten.view(*inputs_size[:-1])
21+
ctx.mark_non_differentiable(indices)
22+
23+
return indices
24+
25+
@staticmethod
26+
def backward(ctx, grad_output):
27+
raise RuntimeError('Trying to call `.grad()` on graph containing '
28+
'`VectorQuantization`. The function `VectorQuantization` '
29+
'is not differentiable. Use `VectorQuantizationStraightThrough` '
30+
'if you want a straight-through estimator of the gradient.')
31+
32+
class VectorQuantizationStraightThrough(Function):
33+
@staticmethod
34+
def forward(ctx, inputs, codebook):
35+
indices = vq(inputs, codebook)
36+
indices_flatten = indices.view(-1)
37+
ctx.save_for_backward(indices_flatten, codebook)
38+
ctx.mark_non_differentiable(indices_flatten)
39+
40+
codes_flatten = torch.index_select(codebook, dim=0,
41+
index=indices_flatten)
42+
codes = codes_flatten.view_as(inputs)
43+
44+
return (codes, indices_flatten)
45+
46+
@staticmethod
47+
def backward(ctx, grad_output, grad_indices):
48+
grad_inputs, grad_codebook = None, None
49+
50+
if ctx.needs_input_grad[0]:
51+
# Straight-through estimator
52+
grad_inputs = grad_output.clone()
53+
if ctx.needs_input_grad[1]:
54+
# Gradient wrt. the codebook
55+
indices, codebook = ctx.saved_tensors
56+
embedding_size = codebook.size(1)
57+
58+
grad_output_flatten = (grad_output.contiguous()
59+
.view(-1, embedding_size))
60+
grad_codebook = torch.zeros_like(codebook)
61+
grad_codebook.index_add_(0, indices, grad_output_flatten)
62+
63+
return (grad_inputs, grad_codebook)
64+
65+
vq = VectorQuantization.apply
66+
vq_st = VectorQuantizationStraightThrough.apply
67+
__all__ = [vq, vq_st]

miniimagenet_pixelcnn_prior.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def train(data_loader, model, prior, optimizer, args, writer):
1414
for images, labels in data_loader:
1515
with torch.no_grad():
1616
images = images.to(args.device)
17-
latents, _ = model.encode(images)
17+
latents = model.encode(images)
1818
latents = latents.detach()
1919

2020
labels = labels.to(args.device)
@@ -39,7 +39,7 @@ def test(data_loader, model, prior, args, writer):
3939
images = images.to(args.device)
4040
labels = labels.to(args.device)
4141

42-
latents, _ = model.encode(images)
42+
latents = model.encode(images)
4343
latents = latents.detach()
4444
logits = prior(latents, labels)
4545
logits = logits.permute(0, 2, 3, 1).contiguous()

miniimagenet_vqvae.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,16 @@ def train(data_loader, model, optimizer, args, writer):
1515

1616
optimizer.zero_grad()
1717
x_tilde, z_e_x, z_q_x = model(images)
18-
z_q_x.retain_grad()
1918

19+
# Reconstruction loss
2020
loss_recons = F.mse_loss(x_tilde, images)
21-
loss_recons.backward(retain_graph=True)
22-
23-
# Straight-through estimator
24-
z_e_x.backward(z_q_x.grad, retain_graph=True)
25-
2621
# Vector quantization objective
27-
model.codebook.embedding.zero_grad()
2822
loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
29-
loss_vq.backward(retain_graph=True)
30-
3123
# Commitment objective
32-
loss_commit = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
33-
loss_commit.backward()
24+
loss_commit = F.mse_loss(z_e_x, z_q_x.detach())
25+
26+
loss = loss_recons + loss_vq + args.beta * loss_commit
27+
loss.backward()
3428

3529
# Logs
3630
writer.add_scalar('loss/train/reconstruction', loss_recons.item(), args.steps)

modules.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from torch.distributions.normal import Normal
55
from torch.distributions import kl_divergence
66

7+
from functions import vq, vq_st
78

89
def to_scalar(arr):
910
if type(arr) == list:
@@ -73,17 +74,21 @@ def __init__(self, K, D):
7374
self.embedding.weight.data.uniform_(-1./K, 1./K)
7475

7576
def forward(self, z_e_x):
76-
# z_e_x - (B, D, H, W)
77-
# emb - (K, D)
77+
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
78+
latents = vq(z_e_x_, self.embedding.weight)
79+
return latents
7880

79-
emb = self.embedding.weight
80-
dists = torch.pow(
81-
z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
82-
2
83-
).sum(2)
81+
def straight_through(self, z_e_x):
82+
z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
83+
z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight)
84+
z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()
8485

85-
latents = dists.min(1)[1]
86-
return latents
86+
z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
87+
dim=0, index=indices)
88+
z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
89+
z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()
90+
91+
return z_q_x, z_q_x_bar
8792

8893

8994
class ResBlock(nn.Module):
@@ -132,16 +137,17 @@ def __init__(self, input_dim, dim, K=512):
132137
def encode(self, x):
133138
z_e_x = self.encoder(x)
134139
latents = self.codebook(z_e_x)
135-
return latents, z_e_x
140+
return latents
136141

137142
def decode(self, latents):
138143
z_q_x = self.codebook.embedding(latents).permute(0, 3, 1, 2) # (B, D, H, W)
139144
x_tilde = self.decoder(z_q_x)
140-
return x_tilde, z_q_x
145+
return x_tilde
141146

142147
def forward(self, x):
143-
latents, z_e_x = self.encode(x)
144-
x_tilde, z_q_x = self.decode(latents)
148+
z_e_x = self.encoder(x)
149+
z_q_x_st, z_q_x = self.codebook.straight_through(z_e_x)
150+
x_tilde = self.decoder(z_q_x_st)
145151
return x_tilde, z_e_x, z_q_x
146152

147153

pixelcnn_prior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def generate_samples():
118118
label = label.to(device=DEVICE, dtype=torch.int64)
119119

120120
latents = model.generate(label, shape=LATENT_SHAPE, batch_size=100)
121-
x_tilde, _ = autoencoder.decode(latents)
121+
x_tilde = autoencoder.decode(latents)
122122
images = (x_tilde.cpu().data + 1) / 2
123123

124124
save_image(

test_functions.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import pytest
2+
3+
import numpy as np
4+
import torch
5+
6+
from functions import vq, vq_st
7+
8+
def test_vq_shape():
9+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
10+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
11+
indices = vq(inputs, codebook)
12+
13+
assert indices.size() == (2, 3, 5)
14+
assert not indices.requires_grad
15+
assert indices.dtype == torch.int64
16+
17+
def test_vq():
18+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
19+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
20+
indices = vq(inputs, codebook)
21+
22+
differences = inputs.unsqueeze(3) - codebook
23+
distances = torch.norm(differences, p=2, dim=4)
24+
25+
_, indices_torch = torch.min(distances, dim=3)
26+
27+
assert np.allclose(indices.numpy(), indices_torch.numpy())
28+
29+
def test_vq_st_shape():
30+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
31+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
32+
codes, indices = vq_st(inputs, codebook)
33+
34+
assert codes.size() == (2, 3, 5, 7)
35+
assert codes.requires_grad
36+
assert codes.dtype == torch.float32
37+
38+
assert indices.size() == (2 * 3 * 5,)
39+
assert not indices.requires_grad
40+
assert indices.dtype == torch.int64
41+
42+
def test_vq_st_gradient1():
43+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
44+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
45+
codes, _ = vq_st(inputs, codebook)
46+
47+
grad_output = torch.rand((2, 3, 5, 7))
48+
grad_inputs, = torch.autograd.grad(codes, inputs,
49+
grad_outputs=[grad_output])
50+
51+
# Straight-through estimator
52+
assert grad_inputs.size() == (2, 3, 5, 7)
53+
assert np.allclose(grad_output.numpy(), grad_inputs.numpy())
54+
55+
def test_vq_st_gradient2():
56+
inputs = torch.rand((2, 3, 5, 7), dtype=torch.float32, requires_grad=True)
57+
codebook = torch.rand((11, 7), dtype=torch.float32, requires_grad=True)
58+
codes, _ = vq_st(inputs, codebook)
59+
60+
indices = vq(inputs, codebook)
61+
codes_torch = torch.embedding(codebook, indices, padding_idx=-1,
62+
scale_grad_by_freq=False, sparse=False)
63+
64+
grad_output = torch.rand((2, 3, 5, 7), dtype=torch.float32)
65+
grad_codebook, = torch.autograd.grad(codes, codebook,
66+
grad_outputs=[grad_output])
67+
grad_codebook_torch, = torch.autograd.grad(codes_torch, codebook,
68+
grad_outputs=[grad_output])
69+
70+
# Gradient is the same as torch.embedding function
71+
assert grad_codebook.size() == (11, 7)
72+
assert np.allclose(grad_codebook.numpy(), grad_codebook_torch.numpy())

0 commit comments

Comments
 (0)