Skip to content

Commit 64f4b56

Browse files
cifar10 sort of works
1 parent 4a7db75 commit 64f4b56

File tree

3 files changed: 25 additions (+25) and 19 deletions (-19)

3 files changed: 25 additions (+25) and 19 deletions (-19)

main.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,14 +12,14 @@
1212
NUM_WORKERS = 4
1313
LR = 2e-4
1414
K = 512
15-
LAMDA = 0.25
15+
LAMDA = 1
1616
PRINT_INTERVAL = 100
1717
N_EPOCHS = 100
1818

1919

2020
preproc_transform = transforms.Compose([
2121
transforms.ToTensor(),
22-
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
22+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
2323
])
2424
train_loader = torch.utils.data.DataLoader(
2525
datasets.CIFAR10(
@@ -102,7 +102,7 @@ def generate_samples():
102102
x_tilde, _, _ = model(x)
103103

104104
x_cat = torch.cat([x, x_tilde], 0)
105-
images = x_cat.cpu().data
105+
images = (x_cat.cpu().data + 1) / 2
106106
save_image(images, './sample_cifar.png', nrow=8)
107107

108108

modules.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,15 +6,15 @@
66

77
def to_scalar(arr):
88
if type(arr) == list:
9-
return [x.cpu().data.tolist()[0] for x in arr]
9+
return [x.item() for x in arr]
1010
else:
11-
return arr.cpu().data.tolist()[0]
11+
return arr.item()
1212

1313

1414
def weights_init(m):
1515
classname = m.__class__.__name__
1616
if classname.find('Conv') != -1:
17-
nn.init.xavier_uniform(m.weight.data)
17+
nn.init.xavier_uniform_(m.weight.data)
1818
m.bias.data.fill_(0)
1919

2020

@@ -25,7 +25,7 @@ def __init__(self, dim):
2525
nn.ReLU(True),
2626
nn.Conv2d(dim, dim, 3, 1, 1),
2727
nn.ReLU(True),
28-
nn.Conv2d(dim, dim, 1)
28+
nn.Conv2d(dim, dim, 1),
2929
)
3030

3131
def forward(self, x):
@@ -39,21 +39,22 @@ def __init__(self, K=512):
3939
nn.Conv2d(3, 256, 4, 2, 1),
4040
nn.ReLU(True),
4141
nn.Conv2d(256, 256, 4, 2, 1),
42-
nn.ReLU(True),
4342
ResBlock(256),
4443
ResBlock(256),
4544
)
4645

4746
self.embedding = nn.Embedding(K, 256)
48-
self.embedding.weight.data.copy_(1./K * torch.randn(K, 256))
47+
# self.embedding.weight.data.copy_(1./K * torch.randn(K, 256))
48+
self.embedding.weight.data.uniform_(-1./K, 1./K)
4949

5050
self.decoder = nn.Sequential(
5151
ResBlock(256),
5252
ResBlock(256),
53+
nn.ReLU(True),
5354
nn.ConvTranspose2d(256, 256, 4, 2, 1),
5455
nn.ReLU(True),
5556
nn.ConvTranspose2d(256, 3, 4, 2, 1),
56-
nn.Sigmoid()
57+
nn.Tanh()
5758
)
5859

5960
self.apply(weights_init)
@@ -145,7 +146,7 @@ def forward(self, x_v, x_h):
145146

146147

147148
class GatedPixelCNN(nn.Module):
148-
def __init__(self, input_dim=256, dim=64, n_layers=7):
149+
def __init__(self, input_dim=256, dim=64, n_layers=15):
149150
super().__init__()
150151
self.dim = 64
151152

pixelcnn.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -10,15 +10,18 @@
1010

1111
BATCH_SIZE = 64
1212
NUM_WORKERS = 4
13-
LR = 1e-3
14-
K = 256
13+
LR = 3e-4
14+
K = 512
15+
DIM = 64
16+
N_LAYERS = 15
1517
PRINT_INTERVAL = 100
1618
N_EPOCHS = 100
19+
ALWAYS_SAVE = True
1720

1821

1922
preproc_transform = transforms.Compose([
2023
transforms.ToTensor(),
21-
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
24+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
2225
])
2326
train_loader = torch.utils.data.DataLoader(
2427
datasets.CIFAR10(
@@ -37,10 +40,12 @@
3740
)
3841

3942
autoencoder = AutoEncoder(K).cuda()
40-
autoencoder.load_state_dict(torch.load('best_autoencoder.pt'))
43+
autoencoder.load_state_dict(
44+
torch.load('best_autoencoder.pt')
45+
)
4146
autoencoder.eval()
4247

43-
model = GatedPixelCNN().cuda()
48+
model = GatedPixelCNN(K, DIM, N_LAYERS).cuda()
4449
criterion = nn.CrossEntropyLoss().cuda()
4550
opt = torch.optim.Adam(model.parameters(), lr=LR)
4651

@@ -104,7 +109,7 @@ def test():
104109
def generate_samples():
105110
latents = model.generate()
106111
x_tilde, _ = autoencoder.decode(latents)
107-
images = x_tilde.cpu().data
112+
images = (x_tilde.cpu().data + 1) / 2
108113
save_image(images, './sample_pixelcnn_cifar.png', nrow=8)
109114

110115

@@ -114,7 +119,7 @@ def generate_reconstructions():
114119
latents, _ = autoencoder.encode(x)
115120
x_tilde, _ = autoencoder.decode(latents)
116121
x_cat = torch.cat([x, x_tilde], 0)
117-
images = x_cat.cpu().data
122+
images = (x_cat.cpu().data + 1) / 2
118123
save_image(images, './sample_cifar.png', nrow=8)
119124

120125

@@ -126,7 +131,7 @@ def generate_reconstructions():
126131
train()
127132
cur_loss = test()
128133

129-
if cur_loss <= BEST_LOSS:
134+
if ALWAYS_SAVE or cur_loss <= BEST_LOSS:
130135
BEST_LOSS = cur_loss
131136
LAST_SAVED = epoch
132137

0 commit comments

Comments
 (0)