Skip to content

Commit 3a2ad71

Browse files
pushing initial code
0 parents  commit 3a2ad71

File tree

3 files changed

+138
-0
lines changed

3 files changed

+138
-0
lines changed

main.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from torchvision import datasets, transforms
4+
from modules import AutoEncoder, to_scalar
5+
from torch.autograd import Variable
6+
import numpy as np
7+
from torchvision.utils import save_image
8+
import time
9+
10+
11+
kwargs = {'num_workers': 2, 'pin_memory': True}
12+
train_loader = torch.utils.data.DataLoader(
13+
datasets.FashionMNIST(
14+
'data/FashionMNIST/', train=True, download=True,
15+
transform=transforms.ToTensor()
16+
), batch_size=64, shuffle=False, **kwargs
17+
)
18+
19+
test_loader = torch.utils.data.DataLoader(
20+
datasets.FashionMNIST(
21+
'data/FashionMNIST/', train=False,
22+
transform=transforms.ToTensor()
23+
), batch_size=32, shuffle=False, **kwargs
24+
)
25+
test_data = list(test_loader)
26+
27+
model = AutoEncoder().cuda()
28+
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
29+
30+
31+
def train(epoch):
32+
train_loss = []
33+
for batch_idx, (data, _) in enumerate(train_loader):
34+
start_time = time.time()
35+
x = Variable(data, requires_grad=False).cuda()
36+
37+
opt.zero_grad()
38+
39+
x_tilde, z_e_x, z_q_x = model(x)
40+
z_q_x.retain_grad()
41+
42+
loss_recons = F.binary_cross_entropy(x_tilde, x)
43+
loss_recons.backward(retain_graph=True)
44+
45+
# Straight-through estimator
46+
z_e_x.backward(z_q_x.grad, retain_graph=True)
47+
48+
# Vector quantization objective
49+
loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
50+
loss_vq.backward(retain_graph=True)
51+
52+
# Commitment objective
53+
loss_commit = 0.25 * F.mse_loss(z_e_x, z_q_x.detach())
54+
loss_commit.backward()
55+
opt.step()
56+
57+
train_loss.append(to_scalar([loss_recons, loss_vq]))
58+
59+
if (batch_idx + 1) % 100 == 0:
60+
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
61+
epoch, batch_idx * len(data), len(train_loader.dataset),
62+
100. * batch_idx / len(train_loader),
63+
np.asarray(train_loss)[-100:].mean(0),
64+
time.time() - start_time
65+
))
66+
67+
68+
def test():
69+
x = Variable(test_data[0][0]).cuda()
70+
x_tilde, _, _ = model(x)
71+
72+
x_cat = torch.cat([x, x_tilde], 0)
73+
images = x_cat.cpu().data
74+
save_image(images, './sample_fashion_mnist.png', nrow=8)
75+
76+
77+
for i in range(100):
78+
train(i)
79+
test()

modules.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
def to_scalar(arr):
6+
if type(arr) == list:
7+
return [x.cpu().data.tolist()[0] for x in arr]
8+
else:
9+
return arr.cpu().data.tolist()[0]
10+
11+
12+
def euclidean_distance(z_e_x, emb):
13+
dists = torch.pow(
14+
z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
15+
2
16+
).sum(2)
17+
return dists
18+
19+
20+
class AutoEncoder(nn.Module):
21+
def __init__(self):
22+
super(AutoEncoder, self).__init__()
23+
self.encoder = nn.Sequential(
24+
nn.Conv2d(1, 16, 4, 2, 1),
25+
nn.BatchNorm2d(16),
26+
nn.ReLU(True),
27+
nn.Conv2d(16, 32, 4, 2, 1),
28+
nn.BatchNorm2d(32),
29+
nn.ReLU(True),
30+
nn.Conv2d(32, 64, 1, 1, 0),
31+
nn.BatchNorm2d(64),
32+
)
33+
34+
self.embedding = nn.Embedding(512, 64)
35+
36+
self.decoder = nn.Sequential(
37+
nn.Conv2d(64, 32, 1, 1, 0),
38+
nn.BatchNorm2d(32),
39+
nn.ReLU(True),
40+
nn.ConvTranspose2d(32, 16, 4, 2, 1),
41+
nn.BatchNorm2d(16),
42+
nn.ReLU(True),
43+
nn.ConvTranspose2d(16, 1, 4, 2, 1),
44+
nn.Sigmoid()
45+
)
46+
47+
def forward(self, x):
48+
z_e_x = self.encoder(x)
49+
B, C, H, W = z_e_x.size()
50+
51+
dists = euclidean_distance(z_e_x, self.embedding.weight)
52+
latents = dists.min(1)[1]
53+
54+
shp = latents.size() + (-1, )
55+
z_q_x = self.embedding(latents.view(-1)).view(*shp)
56+
z_q_x = z_q_x.permute(0, 3, 1, 2)
57+
58+
x_tilde = self.decoder(z_q_x)
59+
return x_tilde, z_e_x, z_q_x

sample_fashion_mnist.png

49.8 KB
Loading

0 commit comments

Comments
 (0)