
Commit bc99680

making code neater
1 parent 64f4b56 commit bc99680

5 files changed: 112 additions (+), 83 deletions (−)

main.py

Lines changed: 28 additions & 21 deletions
@@ -2,50 +2,52 @@
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from modules import AutoEncoder, to_scalar
-from torch.autograd import Variable
 import numpy as np
 from torchvision.utils import save_image
 import time
 
 
 BATCH_SIZE = 128
+N_EPOCHS = 100
+PRINT_INTERVAL = 100
+DATASET = 'FashionMNIST'  # CIFAR10 | MNIST | FashionMNIST
 NUM_WORKERS = 4
-LR = 2e-4
+
+INPUT_DIM = 1  # 3 (RGB) | 1 (Grayscale)
+DIM = 256
 K = 512
 LAMDA = 1
-PRINT_INTERVAL = 100
-N_EPOCHS = 100
+LR = 2e-4
 
 
 preproc_transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
 ])
 train_loader = torch.utils.data.DataLoader(
-    datasets.CIFAR10(
-        '../data/cifar10/', train=True, download=True,
+    eval('datasets.'+DATASET)(
+        '../data/{}/'.format(DATASET), train=True, download=True,
         transform=preproc_transform,
     ), batch_size=BATCH_SIZE, shuffle=False,
     num_workers=NUM_WORKERS, pin_memory=True
 )
-
 test_loader = torch.utils.data.DataLoader(
-    datasets.CIFAR10(
-        '../data/cifar10/', train=False,
+    eval('datasets.'+DATASET)(
+        '../data/{}/'.format(DATASET), train=False,
         transform=preproc_transform
     ), batch_size=BATCH_SIZE, shuffle=False,
     num_workers=NUM_WORKERS, pin_memory=True
 )
 
-model = AutoEncoder(K).cuda()
+model = AutoEncoder(INPUT_DIM, DIM, K).cuda()
 opt = torch.optim.Adam(model.parameters(), lr=LR)
 
 
 def train():
     train_loss = []
-    for batch_idx, (data, _) in enumerate(train_loader):
+    for batch_idx, (x, _) in enumerate(train_loader):
         start_time = time.time()
-        x = Variable(data, requires_grad=False).cuda()
+        x = x.cuda()
 
         opt.zero_grad()
 
@@ -70,20 +72,20 @@ def train():
 
         train_loss.append(to_scalar([loss_recons, loss_vq]))
 
-        if (batch_idx + 1) % 100 == 0:
+        if (batch_idx + 1) % PRINT_INTERVAL == 0:
             print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
-                batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader),
-                np.asarray(train_loss)[-100:].mean(0),
+                batch_idx * len(x), len(train_loader.dataset),
+                PRINT_INTERVAL * batch_idx / len(train_loader),
+                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
                 time.time() - start_time
             ))
 
 
 def test():
     start_time = time.time()
     val_loss = []
-    for batch_idx, (data, _) in enumerate(test_loader):
-        x = Variable(data, volatile=True).cuda()
+    for batch_idx, (x, _) in enumerate(test_loader):
+        x = x.cuda()
         x_tilde, z_e_x, z_q_x = model(x)
         loss_recons = F.mse_loss(x_tilde, x)
         loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
@@ -98,12 +100,17 @@ def test():
 
 def generate_samples():
     x, _ = test_loader.__iter__().next()
-    x = Variable(x[:32]).cuda()
+    x = x[:32].cuda()
     x_tilde, _, _ = model(x)
 
     x_cat = torch.cat([x, x_tilde], 0)
     images = (x_cat.cpu().data + 1) / 2
-    save_image(images, './sample_cifar.png', nrow=8)
+
+    save_image(
+        images,
+        'samples/reconstructions_{}.png'.format(DATASET),
+        nrow=8
+    )
 
 
 BEST_LOSS = 999
@@ -117,7 +124,7 @@ def generate_samples():
         BEST_LOSS = cur_loss
         LAST_SAVED = epoch
         print("Saving model!")
-        torch.save(model.state_dict(), 'best_autoencoder.pt')
+        torch.save(model.state_dict(), 'models/{}_autoencoder.pt'.format(DATASET))
     else:
         print("Not saving model! Last saved: {}".format(LAST_SAVED))
 
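Note (not part of this commit): a minimal sketch of how the new configuration constants are meant to drive data loading and model construction. getattr(datasets, DATASET) is shown purely as an illustration of what the eval('datasets.' + DATASET) lookup above resolves to; the paths and constants are the ones defined in the diff, and the sketch assumes a CUDA device, matching the .cuda() calls in the repo.

# Hypothetical sketch: config-driven dataset selection plus the new AutoEncoder signature.
import torch
from torchvision import datasets, transforms
from modules import AutoEncoder  # as defined in modules.py of this repo

DATASET = 'FashionMNIST'   # CIFAR10 | MNIST | FashionMNIST
INPUT_DIM = 1              # 3 for CIFAR10 (RGB), 1 for the MNIST variants (grayscale)
DIM, K = 256, 512          # hidden width and codebook size, as in main.py

dataset_cls = getattr(datasets, DATASET)          # e.g. datasets.FashionMNIST
train_set = dataset_cls(
    '../data/{}/'.format(DATASET), train=True, download=True,
    transform=transforms.ToTensor(),
)
loader = torch.utils.data.DataLoader(train_set, batch_size=128)

model = AutoEncoder(INPUT_DIM, DIM, K).cuda()     # input channels, hidden width, codebook size
x, _ = next(iter(loader))
x_tilde, z_e_x, z_q_x = model(x.cuda())           # reconstruction, pre- and post-quantization codes

Switching DATASET to 'CIFAR10' would also require INPUT_DIM = 3, as the comments in the diff indicate.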

modules.py

Lines changed: 26 additions & 24 deletions
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable
 
 
 def to_scalar(arr):
@@ -33,27 +32,27 @@ def forward(self, x):
 
 
 class AutoEncoder(nn.Module):
-    def __init__(self, K=512):
+    def __init__(self, input_dim, dim, K=512):
         super(AutoEncoder, self).__init__()
         self.encoder = nn.Sequential(
-            nn.Conv2d(3, 256, 4, 2, 1),
+            nn.Conv2d(input_dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.Conv2d(256, 256, 4, 2, 1),
-            ResBlock(256),
-            ResBlock(256),
+            nn.Conv2d(dim, dim, 4, 2, 1),
+            ResBlock(dim),
+            ResBlock(dim),
         )
 
-        self.embedding = nn.Embedding(K, 256)
+        self.embedding = nn.Embedding(K, dim)
         # self.embedding.weight.data.copy_(1./K * torch.randn(K, 256))
         self.embedding.weight.data.uniform_(-1./K, 1./K)
 
         self.decoder = nn.Sequential(
-            ResBlock(256),
-            ResBlock(256),
+            ResBlock(dim),
+            ResBlock(dim),
             nn.ReLU(True),
-            nn.ConvTranspose2d(256, 256, 4, 2, 1),
+            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.ConvTranspose2d(256, 3, 4, 2, 1),
+            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
             nn.Tanh()
         )
 
@@ -94,12 +93,16 @@ def forward(self, x):
 
 
 class GatedMaskedConv2d(nn.Module):
-    def __init__(self, mask_type, dim, kernel, residual=True):
+    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10):
         super().__init__()
         assert kernel % 2 == 1, print("Kernel size must be odd")
         self.mask_type = mask_type
         self.residual = residual
 
+        self.class_cond_embedding = nn.Embedding(
+            n_classes, 2 * dim
+        )
+
         kernel_shp = (kernel // 2 + 1, kernel)  # (ceil(n/2), n)
         padding_shp = (kernel // 2, kernel // 2)
         self.vert_stack = nn.Conv2d(
@@ -124,19 +127,20 @@ def make_causal(self):
         self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
         self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column
 
-    def forward(self, x_v, x_h):
+    def forward(self, x_v, x_h, h):
         if self.mask_type == 'A':
             self.make_causal()
 
+        h = self.class_cond_embedding(h)
         h_vert = self.vert_stack(x_v)
         h_vert = h_vert[:, :, :x_v.size(-1), :]
-        out_v = self.gate(h_vert)
+        out_v = self.gate(h_vert + h[:, :, None, None])
 
         h_horiz = self.horiz_stack(x_h)
         h_horiz = h_horiz[:, :, :, :x_h.size(-2)]
         v2h = self.vert_to_horiz(h_vert)
 
-        out = self.gate(v2h + h_horiz)
+        out = self.gate(v2h + h_horiz + h[:, :, None, None])
         if self.residual:
             out_h = self.horiz_resid(out) + x_h
         else:
@@ -174,25 +178,23 @@ def __init__(self, input_dim=256, dim=64, n_layers=15):
             nn.Conv2d(dim, input_dim, 1)
         )
 
-    def forward(self, x):
+    def forward(self, x, label):
         shp = x.size() + (-1, )
         x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
         x = x.permute(0, 3, 1, 2)  # (B, C, W, W)
 
         x_v, x_h = (x, x)
         for i, layer in enumerate(self.layers):
-            x_v, x_h = layer(x_v, x_h)
+            x_v, x_h = layer(x_v, x_h, label)
 
         return self.output_conv(x_h)
 
-    def generate(self, batch_size=64):
-        x = Variable(
-            torch.zeros(64, 8, 8).long()
-        ).cuda()
+    def generate(self, label, shape=(8, 8), batch_size=64):
+        x = torch.zeros(batch_size, *shape).long().cuda()
 
-        for i in range(8):
-            for j in range(8):
-                logits = self.forward(x)
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                logits = self.forward(x, label)
                 probs = F.softmax(logits[:, :, i, j], -1)
                 x.data[:, i, j].copy_(
                     probs.multinomial(1).squeeze().data
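Note (not part of this commit): these modules.py changes make the gated PixelCNN prior class-conditional. Each GatedMaskedConv2d embeds a class id into a 2 * dim vector and adds it, broadcast over the spatial dimensions via h[:, :, None, None], inside both the vertical and horizontal gates, and forward/generate now take a label argument. A hedged usage sketch follows; the class that owns forward and generate is not named in the hunks shown, so GatedPixelCNN below is only a placeholder, and input_dim=512 assumes the prior models the K = 512 codebook configured in main.py.

# Hypothetical usage sketch of the class-conditional prior (placeholder class name,
# assumed codebook size of 512, CUDA device assumed as elsewhere in the repo).
import torch
from modules import GatedPixelCNN  # placeholder name for the prior class in modules.py

prior = GatedPixelCNN(input_dim=512, dim=64, n_layers=15).cuda()

labels = torch.randint(0, 10, (64,)).long().cuda()         # one class id per sample (n_classes=10 default)
latents = torch.randint(0, 512, (64, 8, 8)).long().cuda()  # discrete VQ code indices, shape (B, 8, 8)

logits = prior(latents, labels)                            # (B, 512, 8, 8): per-position logits over the codebook
samples = prior.generate(labels, shape=(8, 8), batch_size=64)  # class-conditional autoregressive sampling

Feeding the same label embedding into both gates is what lets a single prior be steered to generate latents for a chosen class at sampling time.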
