
Commit 9aab18d

Merge pull request #1 from ritheshkumar95/cifar
Pushing CIFAR branch into master and merging
2 parents 3a2ad71 + bc99680

File tree

6 files changed: +419 -61 lines

main.py

Lines changed: 83 additions & 31 deletions
@@ -2,78 +2,130 @@
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 from modules import AutoEncoder, to_scalar
-from torch.autograd import Variable
 import numpy as np
 from torchvision.utils import save_image
 import time
 
 
-kwargs = {'num_workers': 2, 'pin_memory': True}
-train_loader = torch.utils.data.DataLoader(
-    datasets.FashionMNIST(
-        'data/FashionMNIST/', train=True, download=True,
-        transform=transforms.ToTensor()
-    ), batch_size=64, shuffle=False, **kwargs
-)
+BATCH_SIZE = 128
+N_EPOCHS = 100
+PRINT_INTERVAL = 100
+DATASET = 'FashionMNIST'  # CIFAR10 | MNIST | FashionMNIST
+NUM_WORKERS = 4
+
+INPUT_DIM = 1  # 3 (RGB) | 1 (Grayscale)
+DIM = 256
+K = 512
+LAMDA = 1
+LR = 2e-4
+
 
+preproc_transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+])
+train_loader = torch.utils.data.DataLoader(
+    eval('datasets.'+DATASET)(
+        '../data/{}/'.format(DATASET), train=True, download=True,
+        transform=preproc_transform,
+    ), batch_size=BATCH_SIZE, shuffle=False,
+    num_workers=NUM_WORKERS, pin_memory=True
+)
 test_loader = torch.utils.data.DataLoader(
-    datasets.FashionMNIST(
-        'data/FashionMNIST/', train=False,
-        transform=transforms.ToTensor()
-    ), batch_size=32, shuffle=False, **kwargs
+    eval('datasets.'+DATASET)(
+        '../data/{}/'.format(DATASET), train=False,
+        transform=preproc_transform
+    ), batch_size=BATCH_SIZE, shuffle=False,
+    num_workers=NUM_WORKERS, pin_memory=True
 )
-test_data = list(test_loader)
 
-model = AutoEncoder().cuda()
-opt = torch.optim.Adam(model.parameters(), lr=3e-4)
+model = AutoEncoder(INPUT_DIM, DIM, K).cuda()
+opt = torch.optim.Adam(model.parameters(), lr=LR)
 
 
-def train(epoch):
+def train():
     train_loss = []
-    for batch_idx, (data, _) in enumerate(train_loader):
+    for batch_idx, (x, _) in enumerate(train_loader):
         start_time = time.time()
-        x = Variable(data, requires_grad=False).cuda()
+        x = x.cuda()
 
         opt.zero_grad()
 
         x_tilde, z_e_x, z_q_x = model(x)
         z_q_x.retain_grad()
 
-        loss_recons = F.binary_cross_entropy(x_tilde, x)
+        loss_recons = F.mse_loss(x_tilde, x)
         loss_recons.backward(retain_graph=True)
 
         # Straight-through estimator
         z_e_x.backward(z_q_x.grad, retain_graph=True)
 
         # Vector quantization objective
+        model.embedding.zero_grad()
         loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
         loss_vq.backward(retain_graph=True)
 
         # Commitment objective
-        loss_commit = 0.25 * F.mse_loss(z_e_x, z_q_x.detach())
+        loss_commit = LAMDA * F.mse_loss(z_e_x, z_q_x.detach())
         loss_commit.backward()
         opt.step()
 
         train_loss.append(to_scalar([loss_recons, loss_vq]))
 
-        if (batch_idx + 1) % 100 == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader),
-                np.asarray(train_loss)[-100:].mean(0),
+        if (batch_idx + 1) % PRINT_INTERVAL == 0:
+            print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
+                batch_idx * len(x), len(train_loader.dataset),
+                PRINT_INTERVAL * batch_idx / len(train_loader),
+                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
                 time.time() - start_time
             ))
 
 
 def test():
-    x = Variable(test_data[0][0]).cuda()
+    start_time = time.time()
+    val_loss = []
+    for batch_idx, (x, _) in enumerate(test_loader):
+        x = x.cuda()
+        x_tilde, z_e_x, z_q_x = model(x)
+        loss_recons = F.mse_loss(x_tilde, x)
+        loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
+        val_loss.append(to_scalar([loss_recons, loss_vq]))
+
+    print('\nValidation Completed!\tLoss: {} Time: {:5.3f}'.format(
+        np.asarray(val_loss).mean(0),
+        time.time() - start_time
+    ))
+    return np.asarray(val_loss).mean(0)
+
+
+def generate_samples():
+    x, _ = test_loader.__iter__().next()
+    x = x[:32].cuda()
     x_tilde, _, _ = model(x)
 
     x_cat = torch.cat([x, x_tilde], 0)
-    images = x_cat.cpu().data
-    save_image(images, './sample_fashion_mnist.png', nrow=8)
+    images = (x_cat.cpu().data + 1) / 2
+
+    save_image(
+        images,
+        'samples/reconstructions_{}.png'.format(DATASET),
+        nrow=8
+    )
+
+
+BEST_LOSS = 999
+LAST_SAVED = -1
+for epoch in range(1, N_EPOCHS):
+    print("Epoch {}:".format(epoch))
+    train()
+    cur_loss, _ = test()
 
+    if cur_loss <= BEST_LOSS:
+        BEST_LOSS = cur_loss
+        LAST_SAVED = epoch
+        print("Saving model!")
+        torch.save(model.state_dict(), 'models/{}_autoencoder.pt'.format(DATASET))
+    else:
+        print("Not saving model! Last saved: {}".format(LAST_SAVED))
 
-for i in range(100):
-    train(i)
-    test()
+    generate_samples()
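
Editor's note on two idioms in the merged script: eval('datasets.'+DATASET) executes a string to resolve the dataset class, and generate_samples() uses test_loader.__iter__().next(), a Python 2-style call that newer PyTorch data-loader iterators no longer provide. Below is a minimal sketch (not code from this commit) of equivalent Python 3 alternatives, assuming the same constants as above; DATASET is set to 'CIFAR10' here so the three-channel Normalize matches the image channels.

import torch
from torchvision import datasets, transforms

DATASET = 'CIFAR10'   # must name a class in torchvision.datasets
BATCH_SIZE = 128
NUM_WORKERS = 4

preproc_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# getattr() looks the class up by name and fails loudly on a bad name,
# without executing an arbitrary string the way eval() does.
dataset_cls = getattr(datasets, DATASET)
test_loader = torch.utils.data.DataLoader(
    dataset_cls('../data/{}/'.format(DATASET), train=False,
                transform=preproc_transform),
    batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True
)

# Python 3 replacement for test_loader.__iter__().next():
x, _ = next(iter(test_loader))

For the grayscale options (MNIST, FashionMNIST), recent torchvision also requires the Normalize mean and std to match the channel count, e.g. (0.5,) and (0.5,).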

modules.py

Lines changed: 173 additions & 30 deletions
@@ -1,59 +1,202 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 def to_scalar(arr):
     if type(arr) == list:
-        return [x.cpu().data.tolist()[0] for x in arr]
+        return [x.item() for x in arr]
     else:
-        return arr.cpu().data.tolist()[0]
+        return arr.item()
 
 
-def euclidean_distance(z_e_x, emb):
-    dists = torch.pow(
-        z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
-        2
-    ).sum(2)
-    return dists
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.xavier_uniform_(m.weight.data)
+        m.bias.data.fill_(0)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, dim):
+        super(ResBlock, self).__init__()
+        self.block = nn.Sequential(
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 3, 1, 1),
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 1),
+        )
+
+    def forward(self, x):
+        return x + self.block(x)
 
 
 class AutoEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, input_dim, dim, K=512):
         super(AutoEncoder, self).__init__()
         self.encoder = nn.Sequential(
-            nn.Conv2d(1, 16, 4, 2, 1),
-            nn.BatchNorm2d(16),
-            nn.ReLU(True),
-            nn.Conv2d(16, 32, 4, 2, 1),
-            nn.BatchNorm2d(32),
+            nn.Conv2d(input_dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.Conv2d(32, 64, 1, 1, 0),
-            nn.BatchNorm2d(64),
+            nn.Conv2d(dim, dim, 4, 2, 1),
+            ResBlock(dim),
+            ResBlock(dim),
         )
 
-        self.embedding = nn.Embedding(512, 64)
+        self.embedding = nn.Embedding(K, dim)
+        # self.embedding.weight.data.copy_(1./K * torch.randn(K, 256))
+        self.embedding.weight.data.uniform_(-1./K, 1./K)
 
         self.decoder = nn.Sequential(
-            nn.Conv2d(64, 32, 1, 1, 0),
-            nn.BatchNorm2d(32),
+            ResBlock(dim),
+            ResBlock(dim),
             nn.ReLU(True),
-            nn.ConvTranspose2d(32, 16, 4, 2, 1),
-            nn.BatchNorm2d(16),
+            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.ConvTranspose2d(16, 1, 4, 2, 1),
-            nn.Sigmoid()
+            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
+            nn.Tanh()
         )
 
-    def forward(self, x):
+        self.apply(weights_init)
+
+    def encode(self, x):
         z_e_x = self.encoder(x)
-        B, C, H, W = z_e_x.size()
 
-        dists = euclidean_distance(z_e_x, self.embedding.weight)
-        latents = dists.min(1)[1]
+        z_e_x_transp = z_e_x.permute(0, 2, 3, 1)  # (B, H, W, C)
+        emb = self.embedding.weight.transpose(0, 1)  # (C, K)
+        dists = torch.pow(
+            z_e_x_transp.unsqueeze(4) - emb[None, None, None],
+            2
+        ).sum(-2)
+        latents = dists.min(-1)[1]
+        return latents, z_e_x
 
+    def decode(self, latents):
         shp = latents.size() + (-1, )
-        z_q_x = self.embedding(latents.view(-1)).view(*shp)
-        z_q_x = z_q_x.permute(0, 3, 1, 2)
-
+        z_q_x = self.embedding(latents.view(latents.size(0), -1))  # (B * H * W, C)
+        z_q_x = z_q_x.view(*shp).permute(0, 3, 1, 2)  # (B, C, H, W)
         x_tilde = self.decoder(z_q_x)
+        return x_tilde, z_q_x
+
+    def forward(self, x):
+        latents, z_e_x = self.encode(x)
+        x_tilde, z_q_x = self.decode(latents)
         return x_tilde, z_e_x, z_q_x
+
+
+class GatedActivation(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x, y = x.chunk(2, dim=1)
+        return F.tanh(x) * F.sigmoid(y)
+
+
+class GatedMaskedConv2d(nn.Module):
+    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10):
+        super().__init__()
+        assert kernel % 2 == 1, print("Kernel size must be odd")
+        self.mask_type = mask_type
+        self.residual = residual
+
+        self.class_cond_embedding = nn.Embedding(
+            n_classes, 2 * dim
+        )
+
+        kernel_shp = (kernel // 2 + 1, kernel)  # (ceil(n/2), n)
+        padding_shp = (kernel // 2, kernel // 2)
+        self.vert_stack = nn.Conv2d(
+            dim, dim * 2,
+            kernel_shp, 1, padding_shp
+        )
+
+        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
+
+        kernel_shp = (1, kernel // 2 + 1)
+        padding_shp = (0, kernel // 2)
+        self.horiz_stack = nn.Conv2d(
+            dim, dim * 2,
+            kernel_shp, 1, padding_shp
+        )
+
+        self.horiz_resid = nn.Conv2d(dim, dim, 1)
+
+        self.gate = GatedActivation()
+
+    def make_causal(self):
+        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
+        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column
+
+    def forward(self, x_v, x_h, h):
+        if self.mask_type == 'A':
+            self.make_causal()
+
+        h = self.class_cond_embedding(h)
+        h_vert = self.vert_stack(x_v)
+        h_vert = h_vert[:, :, :x_v.size(-1), :]
+        out_v = self.gate(h_vert + h[:, :, None, None])
+
+        h_horiz = self.horiz_stack(x_h)
+        h_horiz = h_horiz[:, :, :, :x_h.size(-2)]
+        v2h = self.vert_to_horiz(h_vert)
+
+        out = self.gate(v2h + h_horiz + h[:, :, None, None])
+        if self.residual:
+            out_h = self.horiz_resid(out) + x_h
+        else:
+            out_h = self.horiz_resid(out)
+
+        return out_v, out_h
+
+
+class GatedPixelCNN(nn.Module):
+    def __init__(self, input_dim=256, dim=64, n_layers=15):
+        super().__init__()
+        self.dim = 64
+
+        # Create embedding layer to embed input
+        self.embedding = nn.Embedding(input_dim, dim)
+
+        # Building the PixelCNN layer by layer
+        self.layers = nn.ModuleList()
+
+        # Initial block with Mask-A convolution
+        # Rest with Mask-B convolutions
+        for i in range(n_layers):
+            mask_type = 'A' if i == 0 else 'B'
+            kernel = 7 if i == 0 else 3
+            residual = False if i == 0 else True
+
+            self.layers.append(
+                GatedMaskedConv2d(mask_type, dim, kernel, residual)
+            )
+
+        # Add the output layer
+        self.output_conv = nn.Sequential(
+            nn.Conv2d(dim, dim, 1),
+            nn.ReLU(True),
+            nn.Conv2d(dim, input_dim, 1)
+        )
+
+    def forward(self, x, label):
+        shp = x.size() + (-1, )
+        x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
+        x = x.permute(0, 3, 1, 2)  # (B, C, W, W)
+
+        x_v, x_h = (x, x)
+        for i, layer in enumerate(self.layers):
+            x_v, x_h = layer(x_v, x_h, label)
+
+        return self.output_conv(x_h)
+
+    def generate(self, label, shape=(8, 8), batch_size=64):
+        x = torch.zeros(batch_size, *shape).long().cuda()
+
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                logits = self.forward(x, label)
+                probs = F.softmax(logits[:, :, i, j], -1)
+                x.data[:, i, j].copy_(
+                    probs.multinomial(1).squeeze().data
+                )
+        return x
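
Editor's note: below is a minimal CPU smoke test of the merged modules, a sketch rather than code from this commit. It assumes 32x32 three-channel inputs (the two stride-2 convolutions then give an 8x8 latent grid) and illustrative constructor arguments. Two quirks to be aware of: GatedPixelCNN stores self.dim = 64 regardless of its dim argument, and GatedActivation's F.tanh / F.sigmoid are deprecated aliases in newer PyTorch (torch.tanh / torch.sigmoid are the replacements); neither affects the shapes checked here.

import torch
from modules import AutoEncoder, GatedPixelCNN

# VQ-VAE: encode, snap each spatial position to its nearest codebook
# entry, then decode from the quantized grid.
model = AutoEncoder(input_dim=3, dim=256, K=512)
x = torch.randn(2, 3, 32, 32)
x_tilde, z_e_x, z_q_x = model(x)
latents, _ = model.encode(x)
assert x_tilde.shape == x.shape           # reconstruction matches the input
assert z_e_x.shape == (2, 256, 8, 8)      # continuous encoder output
assert z_q_x.shape == z_e_x.shape         # quantized codes, same grid
assert latents.shape == (2, 8, 8)         # one codebook index per grid cell

# Class-conditional PixelCNN prior over the 8x8 grid of codebook indices.
# (generate() is left out here: it calls .cuda() internally, so it needs a GPU.)
prior = GatedPixelCNN(input_dim=512, dim=64, n_layers=15)
codes = torch.randint(0, 512, (2, 8, 8))
label = torch.randint(0, 10, (2,))
logits = prior(codes, label)
assert logits.shape == (2, 512, 8, 8)     # one logit per codebook entry

The masked-convolution cropping in GatedMaskedConv2d (x_v.size(-1) for the vertical stack, x_h.size(-2) for the horizontal one) mixes up height and width indices, which is harmless for the square latent grids used here but would break on non-square inputs.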
