
Commit fa4504c

Replace quantization in VQEmbedding by vq function

1 parent 859d429

File tree: 4 files changed (+22, -28 lines)


miniimagenet_pixelcnn_prior.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@ def train(data_loader, model, prior, optimizer, args, writer):
     for images, labels in data_loader:
         with torch.no_grad():
             images = images.to(args.device)
-            latents, _ = model.encode(images)
+            latents = model.encode(images)
             latents = latents.detach()
 
             labels = labels.to(args.device)
@@ -39,7 +39,7 @@ def test(data_loader, model, prior, args, writer):
         images = images.to(args.device)
         labels = labels.to(args.device)
 
-        latents, _ = model.encode(images)
+        latents = model.encode(images)
         latents = latents.detach()
         logits = prior(latents, labels)
         logits = logits.permute(0, 2, 3, 1).contiguous()
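With this commit, model.encode(images) returns only the grid of discrete latent indices; the encoder output z_e_x is no longer part of the return value, so the prior scripts drop the tuple unpacking. A minimal sketch of the new contract (shapes are illustrative, not taken from the script):

    # Hypothetical shapes: a batch of B images in,
    # a (B, H', W') LongTensor of codebook indices out.
    latents = model.encode(images)
    logits = prior(latents.detach(), labels)

Since the quantizer selects codebook entries by argmin, the returned indices carry no gradient, so the latents.detach() calls that remain are effectively no-op safeguards.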

miniimagenet_vqvae.py

Lines changed: 5 additions & 11 deletions
@@ -15,22 +15,16 @@ def train(data_loader, model, optimizer, args, writer):
 
         optimizer.zero_grad()
         x_tilde, z_e_x, z_q_x = model(images)
-        z_q_x.retain_grad()
 
+        # Reconstruction loss
         loss_recons = F.mse_loss(x_tilde, images)
-        loss_recons.backward(retain_graph=True)
-
-        # Straight-through estimator
-        z_e_x.backward(z_q_x.grad, retain_graph=True)
-
         # Vector quantization objective
-        model.codebook.embedding.zero_grad()
         loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
-        loss_vq.backward(retain_graph=True)
-
         # Commitment objective
-        loss_commit = args.beta * F.mse_loss(z_e_x, z_q_x.detach())
-        loss_commit.backward()
+        loss_commit = F.mse_loss(z_e_x, z_q_x.detach())
+
+        loss = loss_recons + loss_vq + args.beta * loss_commit
+        loss.backward()
 
         # Logs
         writer.add_scalar('loss/train/reconstruction', loss_recons.item(), args.steps)
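The three separate backward passes, including the manual straight-through copy z_e_x.backward(z_q_x.grad), collapse into a single loss and a single backward call because the straight-through estimator now lives inside the backward of vq_st (see modules.py below). The combined objective is the standard VQ-VAE loss of van den Oord et al. (2017):

    loss = MSE(x_tilde, x) + MSE(z_q_x, sg[z_e_x]) + beta * MSE(z_e_x, sg[z_q_x])

where sg[.] is the stop-gradient, implemented here with .detach(): the second term pulls the selected codebook entries toward the encoder outputs, while the third keeps the encoder committed to the codebook. Note that args.beta is now applied when summing the terms rather than baked into loss_commit itself.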

modules.py

Lines changed: 14 additions & 14 deletions
@@ -4,6 +4,7 @@
 from torch.distributions.normal import Normal
 from torch.distributions import kl_divergence
 
+from functions import vq, vq_st
 
 def to_scalar(arr):
     if type(arr) == list:
@@ -73,18 +74,16 @@ def __init__(self, K, D):
         self.embedding.weight.data.uniform_(-1./K, 1./K)
 
     def forward(self, z_e_x):
-        # z_e_x - (B, D, H, W)
-        # emb - (K, D)
-
-        emb = self.embedding.weight
-        dists = torch.pow(
-            z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
-            2
-        ).sum(2)
-
-        latents = dists.min(1)[1]
+        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
+        latents = vq(z_e_x_, self.embedding.weight)
         return latents
 
+    def straight_through(self, z_e_x):
+        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
+        z_q_x_ = vq_st(z_e_x_, self.embedding.weight)
+        z_q_x = z_q_x_.permute(0, 3, 1, 2)
+        return z_q_x
+
 
 class ResBlock(nn.Module):
     def __init__(self, dim):
@@ -132,16 +131,17 @@ def __init__(self, input_dim, dim, K=512):
     def encode(self, x):
         z_e_x = self.encoder(x)
         latents = self.codebook(z_e_x)
-        return latents, z_e_x
+        return latents
 
     def decode(self, latents):
         z_q_x = self.codebook.embedding(latents).permute(0, 3, 1, 2)  # (B, D, H, W)
         x_tilde = self.decoder(z_q_x)
-        return x_tilde, z_q_x
+        return x_tilde
 
     def forward(self, x):
-        latents, z_e_x = self.encode(x)
-        x_tilde, z_q_x = self.decode(latents)
+        z_e_x = self.encoder(x)
+        z_q_x = self.codebook.straight_through(z_e_x)
+        x_tilde = self.decoder(z_q_x)
         return x_tilde, z_e_x, z_q_x
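The vq and vq_st functions come from functions.py, which is not part of this diff, so their exact implementation is not shown here. A minimal sketch of what they plausibly compute, assuming vq is a nearest-neighbour index lookup and vq_st is a custom autograd Function whose backward passes the gradient straight through to the encoder output while accumulating it into the selected codebook rows (so loss_vq can still train the embeddings):

import torch
from torch.autograd import Function

def vq(inputs, codebook):
    # inputs: (..., D) encoder outputs; codebook: (K, D).
    # Returns the index of the nearest codebook row for every input vector.
    with torch.no_grad():
        flat = inputs.reshape(-1, codebook.size(1))   # (N, D)
        distances = torch.cdist(flat, codebook)       # (N, K) pairwise L2
        indices = distances.argmin(dim=1)             # nearest code per vector
    return indices.view(inputs.shape[:-1])

class VQStraightThrough(Function):
    @staticmethod
    def forward(ctx, inputs, codebook):
        indices = vq(inputs, codebook).reshape(-1)
        ctx.save_for_backward(indices, codebook)
        codes = codebook.index_select(0, indices)     # (N, D) selected embeddings
        return codes.view_as(inputs)

    @staticmethod
    def backward(ctx, grad_output):
        indices, codebook = ctx.saved_tensors
        # Straight-through estimator: copy the gradient unchanged onto the
        # encoder output, as if quantization were the identity.
        grad_inputs = grad_output.clone()
        # Accumulate the gradient into the codebook rows that were selected,
        # so losses on z_q_x (e.g. loss_vq) update the embeddings.
        grad_codebook = torch.zeros_like(codebook)
        grad_codebook.index_add_(0, indices,
                                 grad_output.reshape(-1, codebook.size(1)))
        return grad_inputs, grad_codebook

def vq_st(inputs, codebook):
    return VQStraightThrough.apply(inputs, codebook)

With a backward like this, the single loss.backward() in miniimagenet_vqvae.py reaches the decoder, the encoder (via the copied gradient), and the codebook in one pass, which is what makes the training-loop simplification above possible.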

pixelcnn_prior.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def generate_samples():
     label = label.to(device=DEVICE, dtype=torch.int64)
 
     latents = model.generate(label, shape=LATENT_SHAPE, batch_size=100)
-    x_tilde, _ = autoencoder.decode(latents)
+    x_tilde = autoencoder.decode(latents)
    images = (x_tilde.cpu().data + 1) / 2
 
     save_image(
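As with encode(), decode() now returns only the reconstruction, so the sampling path needs no tuple unpacking either. The surrounding context also shows the (x_tilde + 1) / 2 rescaling, which maps the decoder output from [-1, 1] to the [0, 1] range that save_image expects by default, presumably because the decoder ends in a tanh.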
