Skip to content

Commit 09895b2

Browse files
committed
Add n_classes to PixelCNN
1 parent 9c27d3f commit 09895b2

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def forward(self, x_v, x_h, h):
212212

213213

214214
class GatedPixelCNN(nn.Module):
215-
def __init__(self, input_dim=256, dim=64, n_layers=15):
215+
def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10):
216216
super().__init__()
217217
self.dim = dim
218218

@@ -230,7 +230,7 @@ def __init__(self, input_dim=256, dim=64, n_layers=15):
230230
residual = False if i == 0 else True
231231

232232
self.layers.append(
233-
GatedMaskedConv2d(mask_type, dim, kernel, residual)
233+
GatedMaskedConv2d(mask_type, dim, kernel, residual, n_classes)
234234
)
235235

236236
# Add the output layer

prior_miniimagenet.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,6 @@ def test(data_loader, model, prior, args, writer):
5353

5454
return loss.item()
5555

56-
# def generate_samples(images, model, args):
57-
# with torch.no_grad():
58-
# images = images.to(args.device)
59-
# x_tilde, _, _ = model(images)
60-
# return x_tilde
61-
6256
def main(args):
6357
writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
6458
save_filename = './models/{0}/prior.pt'.format(args.output_folder)
@@ -101,18 +95,18 @@ def main(args):
10195
model.load_state_dict(state_dict)
10296
model.eval()
10397

104-
prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers).to(args.device)
98+
prior = GatedPixelCNN(args.k, args.hidden_size_prior,
99+
args.num_layers, n_classes=len(train_dataset._label_encoder)).to(args.device)
105100
optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)
106101

107102
best_loss = -1.
108103
for epoch in range(args.num_epochs):
109104
train(train_loader, model, prior, optimizer, args, writer)
105+
# The validation loss is not properly computed since
106+
# the classes in the train and valid splits of Mini-Imagenet
107+
# do not overlap.
110108
loss = test(valid_loader, model, prior, args, writer)
111109

112-
# reconstruction = generate_samples(fixed_images, model, args)
113-
# grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
114-
# writer.add_image('reconstruction', grid, epoch + 1)
115-
116110
if (epoch == 0) or (loss < best_loss):
117111
best_loss = loss
118112
with open(save_filename, 'wb') as f:

0 commit comments

Comments
 (0)