Commit ca3d943

minor bug fixes
1 parent f17c0a2 commit ca3d943

File tree

4 files changed: +51 −64 lines changed

modules.py

Lines changed: 39 additions & 35 deletions
@@ -19,40 +19,6 @@ def weights_init(m):
         m.bias.data.fill_(0)
 
 
-class ResBlock(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.block = nn.Sequential(
-            nn.ReLU(True),
-            nn.Conv2d(dim, dim, 3, 1, 1),
-            nn.ReLU(True),
-            nn.Conv2d(dim, dim, 1),
-        )
-
-    def forward(self, x):
-        return x + self.block(x)
-
-
-class VQEmbedding(nn.Module):
-    def __init__(self, K, D):
-        super().__init__()
-        self.embedding = nn.Embedding(K, D)
-        self.embedding.weight.data.uniform_(-1./K, 1./K)
-
-    def forward(self, z_e_x):
-        # z_e_x - (B, D, H, W)
-        # emb - (K, D)
-
-        emb = self.embedding.weight
-        dists = torch.pow(
-            z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
-            2
-        ).sum(2)
-
-        latents = dists.min(1)[1]
-        return latents
-
-
 class VAE(nn.Module):
     def __init__(self, input_dim, dim, z_dim):
         super().__init__()
@@ -90,11 +56,48 @@ def forward(self, x):
         return x_tilde, kl_div
 
 
-class VectorQuantizedAE(nn.Module):
+class VQEmbedding(nn.Module):
+    def __init__(self, K, D):
+        super().__init__()
+        self.embedding = nn.Embedding(K, D)
+        self.embedding.weight.data.uniform_(-1./K, 1./K)
+
+    def forward(self, z_e_x):
+        # z_e_x - (B, D, H, W)
+        # emb - (K, D)
+
+        emb = self.embedding.weight
+        dists = torch.pow(
+            z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
+            2
+        ).sum(2)
+
+        latents = dists.min(1)[1]
+        return latents
+
+
+class ResBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 3, 1, 1),
+            nn.BatchNorm2d(dim),
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 1),
+            nn.BatchNorm2d(dim)
+        )
+
+    def forward(self, x):
+        return x + self.block(x)
+
+
+class VectorQuantizedVAE(nn.Module):
     def __init__(self, input_dim, dim, K=512):
         super().__init__()
         self.encoder = nn.Sequential(
             nn.Conv2d(input_dim, dim, 4, 2, 1),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.Conv2d(dim, dim, 4, 2, 1),
             ResBlock(dim),
@@ -108,6 +111,7 @@ def __init__(self, input_dim, dim, K=512):
             ResBlock(dim),
             nn.ReLU(True),
             nn.ConvTranspose2d(dim, dim, 4, 2, 1),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
             nn.Tanh()
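
Note: the relocated VQEmbedding is unchanged apart from its position in the file; it still performs a plain nearest-neighbour lookup against the codebook. A minimal standalone sketch of that lookup follows. Batch size and spatial size here are illustrative assumptions; K and D mirror the values configured in vqvae.py.

import torch
import torch.nn as nn

K, D = 512, 256
embedding = nn.Embedding(K, D)
embedding.weight.data.uniform_(-1. / K, 1. / K)

z_e_x = torch.randn(8, D, 7, 7)   # encoder output: (B, D, H, W); shapes assumed
emb = embedding.weight            # codebook: (K, D)

# Squared L2 distance from every spatial position to every code,
# then argmin over the K axis gives the discrete latent indices.
dists = torch.pow(
    z_e_x.unsqueeze(1) - emb[None, :, :, None, None],  # broadcasts to (B, K, D, H, W)
    2
).sum(2)                          # (B, K, H, W)
latents = dists.min(1)[1]         # (B, H, W), integer codes in [0, K)
print(latents.shape)              # torch.Size([8, 7, 7])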

pixelcnn_baseline.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 DATASET = 'FashionMNIST' # CIFAR10 | MNIST | FashionMNIST
 NUM_WORKERS = 4
 
-LATENT_SHAPE = (28, 28) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
+IMAGE_SHAPE = (28, 28) # (32, 32) | (28, 28)
 INPUT_DIM = 3 # 3 (RGB) | 1 (Grayscale)
 K = 256
 DIM = 64
@@ -100,7 +100,7 @@ def generate_samples():
     label = torch.arange(10).expand(10, 10).contiguous().view(-1)
     label = label.long().cuda()
 
-    x_tilde = model.generate(label, shape=LATENT_SHAPE, batch_size=100)
+    x_tilde = model.generate(label, shape=IMAGE_SHAPE, batch_size=100)
     images = x_tilde.cpu().data.float() / (K - 1)
 
     save_image(
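
Note: the rename from LATENT_SHAPE to IMAGE_SHAPE reflects that the baseline PixelCNN models pixels directly, so the sampling grid is the image resolution itself rather than a latent grid. A sketch of the call site after the rename, assuming model, label, and K are set up as in this script:

IMAGE_SHAPE = (28, 28) # (32, 32) | (28, 28)
x_tilde = model.generate(label, shape=IMAGE_SHAPE, batch_size=100)
images = x_tilde.cpu().data.float() / (K - 1)  # map integer levels in [0, K) to [0, 1]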

pixelcnn_prior.py

Lines changed: 7 additions & 24 deletions
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from torchvision import datasets, transforms
-from modules import AutoEncoder, GatedPixelCNN, to_scalar
+from modules import VectorQuantizedVAE, GatedPixelCNN
 import numpy as np
 from torchvision.utils import save_image
 import time
@@ -11,11 +11,11 @@
 N_EPOCHS = 100
 PRINT_INTERVAL = 100
 ALWAYS_SAVE = True
-DATASET = 'CIFAR10' # CIFAR10 | MNIST | FashionMNIST
+DATASET = 'MNIST' # CIFAR10 | MNIST | FashionMNIST
 NUM_WORKERS = 4
 
-LATENT_SHAPE = (8, 8) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
-INPUT_DIM = 3 # 3 (RGB) | 1 (Grayscale)
+LATENT_SHAPE = (7, 7) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
+INPUT_DIM = 1 # 3 (RGB) | 1 (Grayscale)
 DIM = 64
 VAE_DIM = 256
 N_LAYERS = 15
@@ -43,7 +43,7 @@
     num_workers=NUM_WORKERS, pin_memory=True
 )
 
-autoencoder = AutoEncoder(INPUT_DIM, VAE_DIM, K).to(DEVICE)
+autoencoder = VectorQuantizedVAE(INPUT_DIM, VAE_DIM, K).to(DEVICE)
 autoencoder.load_state_dict(
     torch.load('models/{}_vqvae.pt'.format(DATASET))
 )
@@ -78,7 +78,7 @@ def train():
         loss.backward()
         opt.step()
 
-        train_loss.append(to_scalar(loss))
+        train_loss.append(loss.item())
 
         if (batch_idx + 1) % PRINT_INTERVAL == 0:
             print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
@@ -104,7 +104,7 @@ def test():
             logits.view(-1, K),
             latents.view(-1)
        )
-        val_loss.append(to_scalar(loss))
+        val_loss.append(loss.item())
 
     print('Validation Completed!\tLoss: {} Time: {}'.format(
         np.asarray(val_loss).mean(0),
@@ -128,25 +128,8 @@ def generate_samples():
     )
 
 
-def generate_reconstructions():
-    x, _ = test_loader.__iter__().next()
-    x = x[:32].to(DEVICE)
-
-    latents, _ = autoencoder.encode(x)
-    x_tilde, _ = autoencoder.decode(latents)
-    x_cat = torch.cat([x, x_tilde], 0)
-    images = (x_cat.cpu().data + 1) / 2
-
-    save_image(
-        images,
-        'samples/reconstructions_{}.png'.format(DATASET),
-        nrow=8
-    )
-
-
 BEST_LOSS = 999
 LAST_SAVED = -1
-generate_reconstructions()
 for epoch in range(1, N_EPOCHS):
     print("\nEpoch {}:".format(epoch))
     train()
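
Note: replacing the custom to_scalar helper with Tensor.item() relies on standard PyTorch behaviour: item() returns the contents of a one-element tensor as a plain Python number, detached from the autograd graph. A quick illustration:

import torch

loss = torch.tensor(0.4375)  # stand-in for a scalar training loss
val = loss.item()            # plain Python float, no graph attached
print(type(val), val)        # <class 'float'> 0.4375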

vqvae.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
 from torchvision.utils import save_image
 from torch.distributions.normal import Normal
 
-from modules import VectorQuantizedAE, to_scalar
+from modules import VectorQuantizedVAE, to_scalar
 
 
 BATCH_SIZE = 32
@@ -21,7 +21,7 @@
 DIM = 256
 K = 512
 LAMDA = 1
-LR = 3e-4
+LR = 1e-3
 
 DEVICE = torch.device('cuda') # torch.device('cpu')
 
@@ -49,7 +49,7 @@
     num_workers=NUM_WORKERS, pin_memory=True
 )
 
-model = VectorQuantizedAE(INPUT_DIM, DIM, K).to(DEVICE)
+model = VectorQuantizedVAE(INPUT_DIM, DIM, K).to(DEVICE)
 print(model)
 opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)
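
Note: a minimal sketch of the model setup under the new class name. DIM, K, and LR follow the constants shown in this file; INPUT_DIM is an assumed value, since its definition is not part of this diff.

import torch
from modules import VectorQuantizedVAE

INPUT_DIM = 3               # assumption: 3 for RGB input, 1 for grayscale
DIM, K, LR = 256, 512, 1e-3

device = torch.device('cuda')  # torch.device('cpu')
model = VectorQuantizedVAE(INPUT_DIM, DIM, K).to(device)
opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)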
