Skip to content

Commit 3874acb

Browse files
adding pixelcnn baseline on images
1 parent 9aab18d commit 3874acb

File tree

4 files changed

+138
-9
lines changed

4 files changed

+138
-9
lines changed

main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@
77
import time
88

99

10-
BATCH_SIZE = 128
10+
BATCH_SIZE = 32
1111
N_EPOCHS = 100
1212
PRINT_INTERVAL = 100
13-
DATASET = 'FashionMNIST' # CIFAR10 | MNIST | FashionMNIST
13+
DATASET = 'CIFAR10' # CIFAR10 | MNIST | FashionMNIST
1414
NUM_WORKERS = 4
1515

16-
INPUT_DIM = 1 # 3 (RGB) | 1 (Grayscale)
16+
INPUT_DIM = 3 # 3 (RGB) | 1 (Grayscale)
1717
DIM = 256
1818
K = 512
1919
LAMDA = 1
20-
LR = 2e-4
20+
LR = 3e-4
2121

2222

2323
preproc_transform = transforms.Compose([

pixelcnn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
import time
88

99

10-
BATCH_SIZE = 64
10+
BATCH_SIZE = 32
1111
N_EPOCHS = 100
1212
PRINT_INTERVAL = 100
1313
ALWAYS_SAVE = True
14-
DATASET = 'FashionMNIST' # CIFAR10 | MNIST | FashionMNIST
14+
DATASET = 'CIFAR10' # CIFAR10 | MNIST | FashionMNIST
1515
NUM_WORKERS = 4
1616

17-
LATENT_SHAPE = (7, 7) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
18-
INPUT_DIM = 1 # 3 (RGB) | 1 (Grayscale)
19-
DIM = 64
17+
LATENT_SHAPE = (8, 8) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
18+
INPUT_DIM = 3 # 3 (RGB) | 1 (Grayscale)
19+
DIM = 128
2020
VAE_DIM = 256
2121
N_LAYERS = 15
2222
K = 512

pixelcnn_baseline.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import torch
2+
import torch.nn as nn
3+
from torchvision import datasets, transforms
4+
from modules import GatedPixelCNN
5+
import numpy as np
6+
from torchvision.utils import save_image
7+
import time
8+
9+
10+
BATCH_SIZE = 32
11+
N_EPOCHS = 100
12+
PRINT_INTERVAL = 100
13+
ALWAYS_SAVE = True
14+
DATASET = 'FashionMNIST' # CIFAR10 | MNIST | FashionMNIST
15+
NUM_WORKERS = 4
16+
17+
LATENT_SHAPE = (28, 28) # (8, 8) -> 32x32 images, (7, 7) -> 28x28 images
18+
INPUT_DIM = 3 # 3 (RGB) | 1 (Grayscale)
19+
K = 256
20+
DIM = 64
21+
N_LAYERS = 15
22+
LR = 3e-4
23+
24+
25+
train_loader = torch.utils.data.DataLoader(
26+
eval('datasets.'+DATASET)(
27+
'../data/{}/'.format(DATASET), train=True, download=True,
28+
transform=transforms.ToTensor(),
29+
), batch_size=BATCH_SIZE, shuffle=False,
30+
num_workers=NUM_WORKERS, pin_memory=True
31+
)
32+
test_loader = torch.utils.data.DataLoader(
33+
eval('datasets.'+DATASET)(
34+
'../data/{}/'.format(DATASET), train=False,
35+
transform=transforms.ToTensor(),
36+
), batch_size=BATCH_SIZE, shuffle=False,
37+
num_workers=NUM_WORKERS, pin_memory=True
38+
)
39+
40+
model = GatedPixelCNN(K, DIM, N_LAYERS).cuda()
41+
criterion = nn.CrossEntropyLoss().cuda()
42+
opt = torch.optim.Adam(model.parameters(), lr=LR)
43+
44+
45+
def train():
46+
train_loss = []
47+
for batch_idx, (x, label) in enumerate(train_loader):
48+
start_time = time.time()
49+
x = (x[:, 0] * (K-1)).long().cuda()
50+
label = label.cuda()
51+
52+
# Train PixelCNN with images
53+
logits = model(x, label)
54+
logits = logits.permute(0, 2, 3, 1).contiguous()
55+
56+
loss = criterion(
57+
logits.view(-1, K),
58+
x.view(-1)
59+
)
60+
61+
opt.zero_grad()
62+
loss.backward()
63+
opt.step()
64+
65+
train_loss.append(loss.item())
66+
67+
if (batch_idx + 1) % PRINT_INTERVAL == 0:
68+
print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
69+
batch_idx * len(x), len(train_loader.dataset),
70+
PRINT_INTERVAL * batch_idx / len(train_loader),
71+
np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
72+
time.time() - start_time
73+
))
74+
75+
76+
def test():
77+
start_time = time.time()
78+
val_loss = []
79+
with torch.no_grad():
80+
for batch_idx, (x, label) in enumerate(test_loader):
81+
x = (x[:, 0] * (K-1)).long().cuda()
82+
label = label.cuda()
83+
84+
logits = model(x, label)
85+
logits = logits.permute(0, 2, 3, 1).contiguous()
86+
loss = criterion(
87+
logits.view(-1, K),
88+
x.view(-1)
89+
)
90+
val_loss.append(loss.item())
91+
92+
print('Validation Completed!\tLoss: {} Time: {}'.format(
93+
np.asarray(val_loss).mean(0),
94+
time.time() - start_time
95+
))
96+
return np.asarray(val_loss).mean(0)
97+
98+
99+
def generate_samples():
100+
label = torch.arange(10).expand(10, 10).contiguous().view(-1)
101+
label = label.long().cuda()
102+
103+
x_tilde = model.generate(label, shape=LATENT_SHAPE, batch_size=100)
104+
images = x_tilde.cpu().data.float() / (K - 1)
105+
106+
save_image(
107+
images[:, None],
108+
'samples/pixelcnn_baseline_samples_{}.png'.format(DATASET),
109+
nrow=10
110+
)
111+
112+
113+
BEST_LOSS = 999
114+
LAST_SAVED = -1
115+
for epoch in range(1, N_EPOCHS):
116+
print("\nEpoch {}:".format(epoch))
117+
train()
118+
cur_loss = test()
119+
120+
if ALWAYS_SAVE or cur_loss <= BEST_LOSS:
121+
BEST_LOSS = cur_loss
122+
LAST_SAVED = epoch
123+
124+
print("Saving model!")
125+
torch.save(model.state_dict(), 'models/{}_pixelcnn.pt'.format(DATASET))
126+
else:
127+
print("Not saving model! Last saved: {}".format(LAST_SAVED))
128+
129+
generate_samples()

samples/samples_FashionMNIST.png

-748 Bytes
Loading

0 commit comments

Comments
 (0)