Skip to content

Commit 4ee173d

Browse files
Improve code and add PixelCNN
1 parent eccdbef commit 4ee173d

File tree

3 files changed

+281
-40
lines changed

3 files changed

+281
-40
lines changed

main.py

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,40 @@
88
import time
99

1010

11-
kwargs = {'num_workers': 1, 'pin_memory': True}
11+
# ---- Hyper-parameters -------------------------------------------------------
BATCH_SIZE = 128        # mini-batch size used by both data loaders
NUM_WORKERS = 4         # worker processes per DataLoader
LR = 2e-4               # Adam learning rate
K = 256                 # codebook size handed to AutoEncoder
LAMDA = 0.25            # weight applied to the commitment loss
PRINT_INTERVAL = 100    # intended number of iterations between log lines
N_EPOCHS = 100          # number of training epochs


# Only tensor conversion is applied. There is deliberately no Normalize step,
# so images reach the model exactly as ToTensor produces them.
preproc_transform = transforms.Compose([
    transforms.ToTensor(),
])

# Training batches are reshuffled every epoch; iterating the dataset in the
# same fixed order each epoch gives correlated, repeating mini-batches.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        '../data/cifar10/', train=True, download=True,
        transform=preproc_transform,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Evaluation keeps a fixed order so that the sample grid and the validation
# loss are computed over the same batches every epoch.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        '../data/cifar10/', train=False,
        transform=preproc_transform,
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

model = AutoEncoder(K).cuda()
opt = torch.optim.Adam(model.parameters(), lr=LR)
3342

3443

35-
def train(epoch):
44+
def train():
3645
train_loss = []
3746
for batch_idx, (data, _) in enumerate(train_loader):
3847
start_time = time.time()
@@ -43,43 +52,75 @@ def train(epoch):
4352
x_tilde, z_e_x, z_q_x = model(x)
4453
z_q_x.retain_grad()
4554

46-
loss_recons = F.l1_loss(x_tilde, x)
55+
loss_recons = F.mse_loss(x_tilde, x)
4756
loss_recons.backward(retain_graph=True)
4857

4958
# Straight-through estimator
5059
z_e_x.backward(z_q_x.grad, retain_graph=True)
5160

5261
# Vector quantization objective
62+
model.embedding.zero_grad()
5363
loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
5464
loss_vq.backward(retain_graph=True)
5565

5666
# Commitment objective
57-
loss_commit = 0.25 * F.mse_loss(z_e_x, z_q_x.detach())
67+
loss_commit = LAMDA * F.mse_loss(z_e_x, z_q_x.detach())
5868
loss_commit.backward()
5969
opt.step()
6070

6171
train_loss.append(to_scalar([loss_recons, loss_vq]))
6272

6373
if (batch_idx + 1) % 100 == 0:
64-
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
65-
epoch, batch_idx * len(data), len(train_loader.dataset),
74+
print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
75+
batch_idx * len(data), len(train_loader.dataset),
6676
100. * batch_idx / len(train_loader),
6777
np.asarray(train_loss)[-100:].mean(0),
6878
time.time() - start_time
6979
))
7080

7181

7282
def test():
    """Run the model over every batch of ``test_loader`` and report the losses.

    For each batch the reconstruction MSE and the vector-quantization MSE are
    computed and passed through ``to_scalar``. The per-batch values are
    averaged over all batches, printed together with the elapsed time, and
    returned.

    Returns:
        The mean over batches of ``[loss_recons, loss_vq]`` as produced by
        ``np.asarray(...).mean(0)``.
    """
    began_at = time.time()
    per_batch = []
    for data, _ in test_loader:
        x = Variable(data, volatile=True).cuda()
        x_tilde, z_e_x, z_q_x = model(x)
        recons = F.mse_loss(x_tilde, x)
        # The encoder output is detached so only the codebook side is measured.
        vq = F.mse_loss(z_q_x, z_e_x.detach())
        per_batch.append(to_scalar([recons, vq]))

    mean_losses = np.asarray(per_batch).mean(0)
    print('\nValidation Completed!\tLoss: {} Time: {:5.3f}'.format(
        mean_losses,
        time.time() - began_at
    ))
    return mean_losses
97+
98+
99+
def generate_samples():
    """Save a grid of test images next to their reconstructions.

    The first 32 images of the first ``test_loader`` batch are pushed through
    ``model``; the originals followed by the reconstructions are written to
    ``./sample_cifar.png`` with 8 images per row.
    """
    # Use the built-in next()/iter() protocol rather than calling a
    # Python 2-style ``.next()`` method on the loader's iterator.
    x, _ = next(iter(test_loader))
    x = Variable(x[:32]).cuda()
    x_tilde, _, _ = model(x)

    # No (x + 1) / 2 rescaling is needed: preproc_transform applies no
    # Normalize step, so the images are saved as they come out.
    x_cat = torch.cat([x, x_tilde], 0)
    images = x_cat.cpu().data
    save_image(images, './sample_cifar.png', nrow=8)
81109

82110

83-
for i in range(100):
84-
train(i)
85-
test()
111+
# Best validation reconstruction loss seen so far, and the epoch it came from.
# Starting from +inf guarantees the first epoch is always saved, whatever the
# scale of the loss.
BEST_LOSS = float('inf')
LAST_SAVED = -1
# Epochs are numbered from 1, so the upper bound must be N_EPOCHS + 1 for the
# loop to actually run N_EPOCHS times.
for epoch in range(1, N_EPOCHS + 1):
    print("Epoch {}:".format(epoch))
    train()
    # test() returns the mean [reconstruction, VQ] losses; only the
    # reconstruction term drives model selection.
    cur_loss, _ = test()

    if cur_loss <= BEST_LOSS:
        BEST_LOSS = cur_loss
        LAST_SAVED = epoch
        print("Saving model!")
        torch.save(model.state_dict(), 'best_autoencoder.pt')
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))
125+
126+
generate_samples()

modules.py

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import torch
22
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
35

46

57
def to_scalar(arr):
@@ -9,13 +11,19 @@ def to_scalar(arr):
911
return arr.cpu().data.tolist()[0]
1012

1113

14+
def weights_init(m):
    """Xavier-initialize convolution weights and zero their biases.

    Meant to be passed to ``nn.Module.apply``. Any module whose class name
    contains ``'Conv'`` gets Xavier-uniform weights and, if it has one, a
    zero bias. All other modules are left untouched.

    Args:
        m: the module visited by ``apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' not in classname:
        return

    # Prefer the in-place-suffixed initializer; fall back to the older name
    # so the function keeps working on PyTorch releases that predate it.
    init_fn = getattr(nn.init, 'xavier_uniform_', None)
    if init_fn is None:
        init_fn = nn.init.xavier_uniform
    init_fn(m.weight.data)

    # Convolutions created with bias=False expose ``bias`` as None; skip them
    # instead of failing on ``None.data``.
    if getattr(m, 'bias', None) is not None:
        m.bias.data.fill_(0)
20+
1221
class ResBlock(nn.Module):
1322
def __init__(self, dim):
1423
super(ResBlock, self).__init__()
1524
self.block = nn.Sequential(
1625
nn.ReLU(True),
1726
nn.Conv2d(dim, dim, 3, 1, 1),
18-
nn.BatchNorm2d(dim),
1927
nn.ReLU(True),
2028
nn.Conv2d(dim, dim, 1)
2129
)
@@ -25,39 +33,33 @@ def forward(self, x):
2533

2634

2735
class AutoEncoder(nn.Module):
28-
def __init__(self):
36+
def __init__(self, K=512, dim=256):
    """Build the encoder, the codebook and the decoder.

    Args:
        K: number of codebook vectors in ``self.embedding``.
        dim: channel width of the encoder/decoder feature maps, which is
            also the size of each codebook vector. Defaults to 256, the
            value that was previously hard-coded.
    """
    super(AutoEncoder, self).__init__()

    # Two stride-2 convolutions (kernel 4, padding 1) each halve the
    # spatial resolution; they are followed by two residual blocks.
    self.encoder = nn.Sequential(
        nn.Conv2d(3, dim, 4, 2, 1),
        nn.ReLU(True),
        nn.Conv2d(dim, dim, 4, 2, 1),
        nn.ReLU(True),
        ResBlock(dim),
        ResBlock(dim),
    )

    # Codebook entries start as Gaussian noise scaled by 1/K.
    self.embedding = nn.Embedding(K, dim)
    self.embedding.weight.data.copy_(1. / K * torch.randn(K, dim))

    # Mirror of the encoder: two residual blocks, then two stride-2
    # transposed convolutions. The final Sigmoid bounds outputs to (0, 1).
    self.decoder = nn.Sequential(
        ResBlock(dim),
        ResBlock(dim),
        nn.ConvTranspose2d(dim, dim, 4, 2, 1),
        nn.ReLU(True),
        nn.ConvTranspose2d(dim, 3, 4, 2, 1),
        nn.Sigmoid()
    )

    # Re-initialize every convolution; weights_init skips other modules.
    self.apply(weights_init)
60+
61+
def encode(self, x):
5962
z_e_x = self.encoder(x)
60-
B, C, H, W = z_e_x.size()
6163

6264
z_e_x_transp = z_e_x.permute(0, 2, 3, 1) # (B, H, W, C)
6365
emb = self.embedding.weight.transpose(0, 1) # (C, K)
@@ -66,8 +68,78 @@ def forward(self, x):
6668
2
6769
).sum(-2)
6870
latents = dists.min(-1)[1]
71+
return latents, z_e_x
6972

70-
z_q_x = self.embedding(latents.view(latents.size(0), -1))
71-
z_q_x = z_q_x.view(B, H, W, C).permute(0, 3, 1, 2)
73+
def decode(self, latents):
    """Turn a grid of codebook indices into a reconstruction.

    Args:
        latents: integer index tensor. It must be 3-D for the 4-axis
            permute below, i.e. laid out as (B, H, W).

    Returns:
        A pair ``(x_tilde, z_q_x)``: the decoder output and the looked-up
        codebook vectors arranged as (B, C, H, W).
    """
    batch = latents.size(0)
    flat_codes = latents.view(batch, -1)              # (B, H * W)
    looked_up = self.embedding(flat_codes)            # (B, H * W, C)

    # Restore the spatial grid, letting view() infer the channel axis.
    grid_shape = latents.size() + (-1, )
    z_q_x = looked_up.view(*grid_shape).permute(0, 3, 1, 2)  # (B, C, H, W)

    return self.decoder(z_q_x), z_q_x
79+
80+
def forward(self, x):
    """Encode ``x`` to discrete codes and decode them again.

    Returns:
        A triple ``(x_tilde, z_e_x, z_q_x)``: the reconstruction from
        ``decode``, the second value returned by ``encode``, and the
        quantized vectors returned by ``decode``.
    """
    codes, encoded = self.encode(x)
    reconstruction, quantized = self.decode(codes)
    return reconstruction, encoded, quantized
84+
85+
86+
class MaskedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is multiplied by a fixed 0/1 mask.

    The mask zeroes every kernel row below the centre row and, within the
    centre row, every column to the right of the centre. With mask type
    ``'A'`` the centre position itself is zeroed as well; with ``'B'`` it is
    kept.

    Args:
        mask_type: ``'A'`` or ``'B'``.
        *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)
        assert mask_type in {'A', 'B'}

        kernel_h, kernel_w = self.weight.size()[-2:]
        centre_row = kernel_h // 2
        # Type 'B' keeps the centre column, so zeroing starts one step later.
        first_hidden_col = kernel_w // 2 + (1 if mask_type == 'B' else 0)

        # Same shape, dtype and device as the weight, filled with ones.
        mask = self.weight.data.clone().fill_(1)
        mask[:, :, centre_row, first_hidden_col:] = 0
        mask[:, :, centre_row + 1:] = 0
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Zero the masked kernel entries in place, then convolve ``x``."""
        self.weight.data.mul_(self.mask)
        return super(MaskedConv2d, self).forward(x)
99+
100+
101+
class PixelCNN(nn.Module):
    """Stack of masked convolutions over a grid of discrete codes.

    Integer codes are embedded, passed through ``n_layers`` masked
    convolution blocks (the first with mask type 'A', the rest with 'B') and
    mapped by a 1x1 convolution to 256 logits per position.

    Args:
        dim: embedding size and channel width of every hidden layer.
        n_layers: number of masked convolution blocks.
    """

    def __init__(self, dim=64, n_layers=4):
        super(PixelCNN, self).__init__()
        # Store the width that was actually requested, not a fixed 64.
        self.dim = dim

        # Embeds each of the 256 possible input codes into ``dim`` channels.
        self.embedding = nn.Embedding(256, dim)

        layers = []
        for i in range(n_layers):
            # Only the first layer uses mask 'A', which hides the centre
            # position; later layers use mask 'B', which keeps it.
            mask_type = 'A' if i == 0 else 'B'
            layers.extend([
                MaskedConv2d(mask_type, dim, dim, 7, 1, 3, bias=False),
                nn.BatchNorm2d(dim),
                nn.ReLU(True)
            ])

        # Output layer: one logit per code value at every position.
        layers.append(nn.Conv2d(dim, 256, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-position logits for an integer code grid.

        Args:
            x: integer tensor; it must be 3-D for the 4-axis permute below,
                i.e. laid out as (B, H, W).

        Returns:
            The output of ``self.net``, with 256 channels.
        """
        shp = x.size() + (-1, )
        x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
        return self.net(x)

    def generate(self, batch_size=64):
        """Sample ``batch_size`` 8x8 code grids one position at a time.

        Positions are visited in row-major order; at each one the network is
        run on the grid filled so far and a code is drawn from the softmax
        over that position's logits.

        Args:
            batch_size: number of grids to sample.

        Returns:
            A ``(batch_size, 8, 8)`` long tensor of sampled codes, placed on
            the GPU when the model's parameters are on the GPU.
        """
        x = Variable(torch.zeros(batch_size, 8, 8).long())
        # Follow the model's device instead of unconditionally moving to CUDA.
        if self.embedding.weight.is_cuda:
            x = x.cuda()

        for i in range(8):
            for j in range(8):
                logits = self.forward(x)
                probs = F.softmax(logits[:, :, i, j], -1)
                # squeeze(-1) drops only the sample axis, so a batch of one
                # keeps its batch dimension.
                x.data[:, i, j].copy_(
                    probs.multinomial(1).squeeze(-1).data
                )
        return x

0 commit comments

Comments
 (0)