
Commit 94453af
refactoring code
1 parent: df45f31

6 files changed: 245 additions, 38 deletions

modules.py (69 additions, 19 deletions)

@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributions.normal import Normal
+from torch.distributions import kl_divergence
 
 
 def to_scalar(arr):
@@ -19,7 +21,7 @@ def weights_init(m):
 
 class ResBlock(nn.Module):
     def __init__(self, dim):
-        super(ResBlock, self).__init__()
+        super().__init__()
         self.block = nn.Sequential(
             nn.ReLU(True),
             nn.Conv2d(dim, dim, 3, 1, 1),
@@ -31,9 +33,66 @@ def forward(self, x):
         return x + self.block(x)
 
 
-class AutoEncoder(nn.Module):
+class VQEmbedding(nn.Module):
+    def __init__(self, K, D):
+        super().__init__()
+        self.embedding = nn.Embedding(K, D)
+        self.embedding.weight.data.uniform_(-1./K, 1./K)
+
+    def forward(self, z_e_x):
+        # z_e_x - (B, D, H, W)
+        # emb - (K, D)
+
+        emb = self.embedding.weight
+        dists = torch.pow(
+            z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
+            2
+        ).sum(2)
+
+        latents = dists.min(1)[1]
+        return latents
+
+
+class VAE(nn.Module):
+    def __init__(self, input_dim, dim, z_dim):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Conv2d(input_dim, dim, 4, 2, 1),
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 4, 2, 1),
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 5, 1, 0),
+            nn.ReLU(True),
+            nn.Conv2d(dim, z_dim * 2, 3, 1, 0),
+        )
+
+        self.decoder = nn.Sequential(
+            nn.ConvTranspose2d(z_dim, dim, 3, 1, 0),
+            nn.ReLU(True),
+            nn.ConvTranspose2d(dim, dim, 5, 1, 0),
+            nn.ReLU(True),
+            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
+            nn.ReLU(True),
+            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
+            nn.Tanh()
+        )
+
+        self.apply(weights_init)
+
+    def forward(self, x):
+        mu, logvar = self.encoder(x).chunk(2, dim=1)
+
+        q_z_x = Normal(mu, logvar.mul(.5).exp())
+        p_z = Normal(torch.zeros_like(mu), torch.ones_like(logvar))
+        kl_div = kl_divergence(q_z_x, p_z).sum(1).mean()
+
+        x_tilde = self.decoder(q_z_x.rsample())
+        return x_tilde, kl_div
+
+
+class VectorQuantizedAE(nn.Module):
     def __init__(self, input_dim, dim, K=512):
-        super(AutoEncoder, self).__init__()
+        super().__init__()
         self.encoder = nn.Sequential(
             nn.Conv2d(input_dim, dim, 4, 2, 1),
             nn.ReLU(True),
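
Note on the new VQEmbedding: the nearest-codebook lookup is a single broadcasted squared-distance computation followed by an argmin over the K codes. A minimal shape check, as a sketch only (the batch size, D=256, and the 8x8 grid below are illustrative assumptions, not values fixed by this commit):

import torch
from modules import VQEmbedding

codebook = VQEmbedding(K=512, D=256)
z_e_x = torch.randn(8, 256, 8, 8)   # (B, D, H, W) encoder output; shapes assumed

# z_e_x.unsqueeze(1)          -> (B, 1, D, H, W)
# emb[None, :, :, None, None] -> (1, K, D, 1, 1)
# squared diffs summed over D -> (B, K, H, W); argmin over K picks a code index
latents = codebook(z_e_x)
assert latents.shape == (8, 8, 8)   # one int64 codebook index per spatial position
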
@@ -42,9 +101,7 @@ def __init__(self, input_dim, dim, K=512):
             ResBlock(dim),
         )
 
-        self.embedding = nn.Embedding(K, dim)
-        # self.embedding.weight.data.copy_(1./K * torch.randn(K, 256))
-        self.embedding.weight.data.uniform_(-1./K, 1./K)
+        self.codebook = VQEmbedding(K, dim)
 
         self.decoder = nn.Sequential(
             ResBlock(dim),
@@ -60,20 +117,11 @@ def __init__(self, input_dim, dim, K=512):
 
     def encode(self, x):
         z_e_x = self.encoder(x)
-
-        z_e_x_transp = z_e_x.permute(0, 2, 3, 1) # (B, H, W, C)
-        emb = self.embedding.weight.transpose(0, 1) # (C, K)
-        dists = torch.pow(
-            z_e_x_transp.unsqueeze(4) - emb[None, None, None],
-            2
-        ).sum(-2)
-        latents = dists.min(-1)[1]
+        latents = self.codebook(z_e_x)
         return latents, z_e_x
 
     def decode(self, latents):
-        shp = latents.size() + (-1, )
-        z_q_x = self.embedding(latents.view(latents.size(0), -1)) # (B * H * W, C)
-        z_q_x = z_q_x.view(*shp).permute(0, 3, 1, 2) # (B, C, H, W)
+        z_q_x = self.codebook.embedding(latents).permute(0, 3, 1, 2) # (B, D, H, W)
         x_tilde = self.decoder(z_q_x)
         return x_tilde, z_q_x
 
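
Note on the refactor: encode() and decode() now route through self.codebook, so a round trip reads cleanly. A hedged sketch (the 32x32 input matches the "(8, 8) -> 32x32 images" comment in pixelcnn.py, and dim=256 mirrors VAE_DIM, but both are assumptions here, not code from the commit):

import torch
from modules import VectorQuantizedAE

model = VectorQuantizedAE(input_dim=3, dim=256, K=512)
x = torch.randn(8, 3, 32, 32)

latents, z_e_x = model.encode(x)        # discrete indices plus continuous encoder output
x_tilde, z_q_x = model.decode(latents)  # embed the indices, then decode
assert x_tilde.shape == x.shape
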

@@ -191,8 +239,10 @@ def forward(self, x, label):
 
     def generate(self, label, shape=(8, 8), batch_size=64):
         param = next(self.parameters())
-        x = torch.zeros((batch_size, *shape),
-                        dtype=torch.int64, device=param.device)
+        x = torch.zeros(
+            (batch_size, *shape),
+            dtype=torch.int64, device=param.device
+        )
 
         for i in range(shape[0]):
             for j in range(shape[1]):

pixelcnn.py (4 additions, 4 deletions)

@@ -16,11 +16,11 @@
 
 LATENT_SHAPE = (8, 8) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
 INPUT_DIM = 3 # 3 (RGB) | 1 (Grayscale)
-DIM = 128
+DIM = 64
 VAE_DIM = 256
 N_LAYERS = 15
 K = 512
-LR = 3e-4
+LR = 1e-3
 
 DEVICE = torch.device('cuda') # torch.device('cpu')
 
@@ -45,13 +45,13 @@
 
 autoencoder = AutoEncoder(INPUT_DIM, VAE_DIM, K).to(DEVICE)
 autoencoder.load_state_dict(
-    torch.load('models/{}_autoencoder.pt'.format(DATASET))
+    torch.load('models/{}_vqvae.pt'.format(DATASET))
 )
 autoencoder.eval()
 
 model = GatedPixelCNN(K, DIM, N_LAYERS).to(DEVICE)
 criterion = nn.CrossEntropyLoss().to(DEVICE)
-opt = torch.optim.Adam(model.parameters(), lr=LR)
+opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)
 
 
 def train():
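
Note: pixelcnn.py's train() body is not part of this diff. A sketch of how the frozen VQ-VAE and the GatedPixelCNN prior would fit together, written against the names defined above; the per-position logits shape (B, K, 8, 8) and the train_loader are assumptions, not code from this commit:

for x, label in train_loader:
    x, label = x.to(DEVICE), label.to(DEVICE)

    with torch.no_grad():
        latents, _ = autoencoder.encode(x)  # (B, 8, 8) int64 codebook indices

    logits = model(latents, label)          # assumed logits over the K codes: (B, K, 8, 8)
    loss = criterion(logits, latents)       # CrossEntropyLoss accepts (B, K, H, W) vs (B, H, W)

    opt.zero_grad()
    loss.backward()
    opt.step()
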

samples/reconstructions_CIFAR10.png (binary, 137 KB)

samples/reconstructions_MNIST.png (binary, 11.1 KB)

vae.py (145 additions, 0 deletions)

@@ -0,0 +1,145 @@
+import numpy as np
+import time
+
+import torch
+import torch.nn.functional as F
+from torch.distributions.normal import Normal
+
+from torchvision import datasets, transforms
+from torchvision.utils import save_image
+
+from modules import VAE
+
+
+BATCH_SIZE = 32
+N_EPOCHS = 100
+PRINT_INTERVAL = 500
+DATASET = 'CIFAR10' # CIFAR10 | MNIST | FashionMNIST
+NUM_WORKERS = 4
+
+INPUT_DIM = 3
+DIM = 256
+Z_DIM = 128
+LR = 3e-4
+
+
+preproc_transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+])
+train_loader = torch.utils.data.DataLoader(
+    eval('datasets.'+DATASET)(
+        '../data/{}/'.format(DATASET), train=True, download=True,
+        transform=preproc_transform,
+    ), batch_size=BATCH_SIZE, shuffle=False,
+    num_workers=NUM_WORKERS, pin_memory=True
+)
+test_loader = torch.utils.data.DataLoader(
+    eval('datasets.'+DATASET)(
+        '../data/{}/'.format(DATASET), train=False,
+        transform=preproc_transform
+    ), batch_size=BATCH_SIZE, shuffle=False,
+    num_workers=NUM_WORKERS, pin_memory=True
+)
+
+model = VAE(INPUT_DIM, DIM, Z_DIM).cuda()
+print(model)
+opt = torch.optim.Adam(model.parameters(), lr=LR, amsgrad=True)
+
+
+def train():
+    train_loss = []
+    model.train()
+    for batch_idx, (x, _) in enumerate(train_loader):
+        start_time = time.time()
+        x = x.cuda()
+
+        x_tilde, kl_d = model(x)
+        loss_recons = F.mse_loss(x_tilde, x, size_average=False) / x.size(0)
+        loss = loss_recons + kl_d
+
+        nll = -Normal(x_tilde, torch.ones_like(x_tilde)).log_prob(x)
+        log_px = nll.mean().item() - np.log(128) + kl_d.item()
+        log_px /= np.log(2)
+
+        opt.zero_grad()
+        loss.backward()
+        opt.step()
+
+        train_loss.append([log_px, loss.item()])
+
+        if (batch_idx + 1) % PRINT_INTERVAL == 0:
+            print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {:5.3f} ms/batch'.format(
+                batch_idx * len(x), len(train_loader.dataset),
+                PRINT_INTERVAL * batch_idx / len(train_loader),
+                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
+                1000 * (time.time() - start_time)
+            ))
+
+
+def test():
+    start_time = time.time()
+    val_loss = []
+    model.eval()
+    with torch.no_grad():
+        for batch_idx, (x, _) in enumerate(test_loader):
+            x = x.cuda()
+            x_tilde, kl_d = model(x)
+            loss_recons = F.mse_loss(x_tilde, x, size_average=False) / x.size(0)
+            loss = loss_recons + kl_d
+            val_loss.append(loss.item())
+
+    print('\nValidation Completed!\tLoss: {:5.4f} Time: {:5.3f} s'.format(
+        np.asarray(val_loss).mean(0),
+        time.time() - start_time
+    ))
+    return np.asarray(val_loss).mean(0)
+
+
+def generate_reconstructions():
+    model.eval()
+    x, _ = test_loader.__iter__().next()
+    x = x[:32].cuda()
+    x_tilde, kl_div = model(x)
+
+    x_cat = torch.cat([x, x_tilde], 0)
+    images = (x_cat.cpu().data + 1) / 2
+
+    save_image(
+        images,
+        'samples/vae_reconstructions_{}.png'.format(DATASET),
+        nrow=8
+    )
+
+
+def generate_samples():
+    model.eval()
+    z_e_x = torch.randn(64, Z_DIM, 1, 1).cuda()
+    x_tilde = model.decoder(z_e_x)
+
+    images = (x_tilde.cpu().data + 1) / 2
+
+    save_image(
+        images,
+        'samples/vae_samples_{}.png'.format(DATASET),
+        nrow=8
+    )
+
+
+BEST_LOSS = 99999
+LAST_SAVED = -1
+for epoch in range(1, N_EPOCHS):
+    print("Epoch {}:".format(epoch))
+    train()
+    cur_loss = test()
+
+    if cur_loss <= BEST_LOSS:
+        BEST_LOSS = cur_loss
+        LAST_SAVED = epoch
+        print("Saving model!")
+        torch.save(model.state_dict(), 'models/{}_vae.pt'.format(DATASET))
+    else:
+        print("Not saving model! Last saved: {}".format(LAST_SAVED))
+
+    generate_reconstructions()
+    generate_samples()
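
Note: the new VAE leans on torch.distributions for the KL term rather than spelling it out. A quick standalone check (not part of this commit) that kl_divergence on diagonal Gaussians matches the usual closed form 0.5 * (mu^2 + sigma^2 - log sigma^2 - 1):

import torch
from torch.distributions import Normal, kl_divergence

mu, logvar = torch.randn(4, 128, 1, 1), torch.randn(4, 128, 1, 1)
q = Normal(mu, logvar.mul(0.5).exp())
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))

kl_lib = kl_divergence(q, p)                             # library result
kl_closed = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)  # hand-derived closed form
assert torch.allclose(kl_lib, kl_closed, atol=1e-5)
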
