Skip to content

Commit 0ee160e

Browse files
shifting to GatedPixelCNN
1 parent 4ee173d commit 0ee160e

File tree

3 files changed

+97
-35
lines changed

3 files changed

+97
-35
lines changed

main.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def generate_samples():
100100
x, _ = test_loader.__iter__().next()
101101
x = Variable(x[:32]).cuda()
102102
x_tilde, _, _ = model(x)
103-
# x_tilde = (x_tilde + 1)/2
104-
# x = (x + 1)/2
105103

106104
x_cat = torch.cat([x, x_tilde], 0)
107105
images = x_cat.cpu().data

modules.py

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -83,52 +83,106 @@ def forward(self, x):
8383
return x_tilde, z_e_x, z_q_x
8484

8585

86-
class MaskedConv2d(nn.Conv2d):
87-
def __init__(self, mask_type, *args, **kwargs):
88-
super(MaskedConv2d, self).__init__(*args, **kwargs)
89-
assert mask_type in {'A', 'B'}
90-
self.register_buffer('mask', self.weight.data.clone())
91-
_, _, kH, kW = self.weight.size()
92-
self.mask.fill_(1)
93-
self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
94-
self.mask[:, :, kH // 2 + 1:] = 0
86+
class GatedActivation(nn.Module):
    """Gated activation unit from the PixelCNN paper: tanh(a) * sigmoid(b).

    Expects an input with 2*C channels; the first half is passed through
    tanh and the second half through sigmoid, and the two are multiplied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Split the channel dimension in half: first half -> tanh gate,
        # second half -> sigmoid gate.
        x, y = x.chunk(2, dim=1)
        # torch.tanh / torch.sigmoid instead of the deprecated
        # F.tanh / F.sigmoid (removed from torch.nn.functional).
        return torch.tanh(x) * torch.sigmoid(y)


class GatedMaskedConv2d(nn.Module):
    """One gated, masked convolution layer (vertical + horizontal stacks).

    Args:
        mask_type: 'A' for the first layer (excludes the centre pixel),
            'B' for subsequent layers (includes the centre pixel).
        dim: number of input/output channels.
        kernel: convolution kernel size; must be odd.
        residual: if True, add a residual connection on the horizontal stack.
    """

    def __init__(self, mask_type, dim, kernel, residual=True):
        super().__init__()
        # BUGFIX: the assert message must be a string, not print(...) --
        # print(...) evaluates to None, so the original showed no message.
        assert kernel % 2 == 1, "Kernel size must be odd"
        self.mask_type = mask_type
        self.residual = residual

        # Vertical stack: a (ceil(k/2), k) kernel sees rows at and above
        # the current one; padding adds k//2 extra rows that are cropped
        # off in forward() to keep the map causal.
        kernel_shp = (kernel // 2 + 1, kernel)  # (ceil(n/2), n)
        padding_shp = (kernel // 2, kernel // 2)
        self.vert_stack = nn.Conv2d(
            dim, dim * 2,
            kernel_shp, 1, padding_shp
        )

        # 1x1 conv feeding vertical-stack features into the horizontal stack.
        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)

        # Horizontal stack: a (1, ceil(k/2)) kernel sees pixels to the left
        # in the current row; extra columns from padding are cropped off.
        kernel_shp = (1, kernel // 2 + 1)
        padding_shp = (0, kernel // 2)
        self.horiz_stack = nn.Conv2d(
            dim, dim * 2,
            kernel_shp, 1, padding_shp
        )

        self.horiz_resid = nn.Conv2d(dim, dim, 1)

        self.gate = GatedActivation()

    def make_causal(self):
        """Zero the weights that would let the convs see the current/future pixel."""
        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column

    def forward(self, x_v, x_h):
        # Mask-A layers must not see the centre pixel at all; re-apply the
        # causal mask each call since optimizer steps may repopulate weights.
        if self.mask_type == 'A':
            self.make_causal()

        h_vert = self.vert_stack(x_v)
        # BUGFIX: crop the height axis to the input *height* (size(-2));
        # the original cropped by size(-1) (width), which was only correct
        # for square feature maps.
        h_vert = h_vert[:, :, :x_v.size(-2), :]
        out_v = self.gate(h_vert)

        h_horiz = self.horiz_stack(x_h)
        # BUGFIX: likewise crop the width axis to the input *width* (size(-1)).
        h_horiz = h_horiz[:, :, :, :x_h.size(-1)]
        v2h = self.vert_to_horiz(h_vert)

        # Horizontal stack is conditioned on the vertical stack's features.
        out = self.gate(v2h + h_horiz)
        if self.residual:
            out_h = self.horiz_resid(out) + x_h
        else:
            out_h = self.horiz_resid(out)

        return out_v, out_h
145+
146+
147+
class GatedPixelCNN(nn.Module):
    """Gated PixelCNN autoregressive prior over a grid of discrete codes.

    Args:
        input_dim: size of the discrete vocabulary (embedding rows, and
            per-pixel output logits).
        dim: hidden channel width of every gated layer.
        n_layers: number of GatedMaskedConv2d layers.
    """

    def __init__(self, input_dim=256, dim=64, n_layers=7):
        super().__init__()
        # BUGFIX: store the actual `dim` argument; the original hard-coded
        # 64 here, silently ignoring the parameter.
        self.dim = dim

        # Create embedding layer to embed input
        self.embedding = nn.Embedding(input_dim, dim)

        # Building the PixelCNN layer by layer
        self.layers = nn.ModuleList()

        # Initial block with Mask-A convolution
        # Rest with Mask-B convolutions
        for i in range(n_layers):
            mask_type = 'A' if i == 0 else 'B'
            kernel = 7 if i == 0 else 3
            # No residual on the first (Mask-A) layer: its input still
            # contains the centre pixel, which must not leak through.
            residual = i != 0

            self.layers.append(
                GatedMaskedConv2d(mask_type, dim, kernel, residual)
            )

        # Add the output layer
        self.output_conv = nn.Sequential(
            nn.Conv2d(dim, dim, 1),
            nn.ReLU(True),
            nn.Conv2d(dim, input_dim, 1)
        )

    def forward(self, x):
        """Map a (B, H, W) LongTensor of code indices to (B, input_dim, H, W) logits."""
        shp = x.size() + (-1, )
        x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)

        # Both stacks start from the same embedded input.
        x_v, x_h = (x, x)
        for layer in self.layers:
            x_v, x_h = layer(x_v, x_h)

        # Logits come from the horizontal stack only.
        return self.output_conv(x_h)
132186

133187
def generate(self, batch_size=64):
134188
x = Variable(

pixelcnn.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn as nn
33
from torchvision import datasets, transforms
4-
from modules import AutoEncoder, PixelCNN, to_scalar
4+
from modules import AutoEncoder, GatedPixelCNN, to_scalar
55
from torch.autograd import Variable
66
import numpy as np
77
from torchvision.utils import save_image
@@ -16,22 +16,22 @@
1616
N_EPOCHS = 100
1717

1818

19+
preproc_transform = transforms.Compose([
20+
transforms.ToTensor(),
21+
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
22+
])
1923
train_loader = torch.utils.data.DataLoader(
2024
datasets.CIFAR10(
2125
'../data/cifar10/', train=True, download=True,
22-
transform=transforms.Compose(
23-
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
24-
)
26+
transform=preproc_transform,
2527
), batch_size=BATCH_SIZE, shuffle=False,
2628
num_workers=NUM_WORKERS, pin_memory=True
2729
)
2830

2931
test_loader = torch.utils.data.DataLoader(
3032
datasets.CIFAR10(
3133
'../data/cifar10/', train=False,
32-
transform=transforms.Compose(
33-
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
34-
)
34+
transform=preproc_transform
3535
), batch_size=BATCH_SIZE, shuffle=False,
3636
num_workers=NUM_WORKERS, pin_memory=True
3737
)
@@ -40,7 +40,7 @@
4040
autoencoder.load_state_dict(torch.load('best_autoencoder.pt'))
4141
autoencoder.eval()
4242

43-
model = PixelCNN().cuda()
43+
model = GatedPixelCNN().cuda()
4444
criterion = nn.CrossEntropyLoss().cuda()
4545
opt = torch.optim.Adam(model.parameters(), lr=LR)
4646

@@ -104,13 +104,23 @@ def test():
104104
def generate_samples():
105105
latents = model.generate()
106106
x_tilde, _ = autoencoder.decode(latents)
107-
# images = ((x_tilde + 1) / 2).cpu().data
108107
images = x_tilde.cpu().data
109108
save_image(images, './sample_pixelcnn_cifar.png', nrow=8)
110109

111110

111+
def generate_reconstructions():
    """Save a grid of test images next to their autoencoder reconstructions.

    Writes 32 originals followed by their reconstructions to
    './sample_cifar.png' (8 images per row).
    """
    # next(iter(...)) works on Python 2 and 3; iterator.next() is Py2-only
    # and raises AttributeError on Python 3.
    x, _ = next(iter(test_loader))
    x = Variable(x[:32]).cuda()
    latents, _ = autoencoder.encode(x)
    x_tilde, _ = autoencoder.decode(latents)
    x_cat = torch.cat([x, x_tilde], 0)
    images = x_cat.cpu().data
    save_image(images, './sample_cifar.png', nrow=8)
119+
120+
112121
BEST_LOSS = 999
113122
LAST_SAVED = -1
123+
generate_reconstructions()
114124
for epoch in range(1, N_EPOCHS):
115125
print("\nEpoch {}:".format(epoch))
116126
train()

0 commit comments

Comments
 (0)