@@ -53,12 +53,6 @@ def test(data_loader, model, prior, args, writer):
53
53
54
54
return loss .item ()
55
55
56
- # def generate_samples(images, model, args):
57
- # with torch.no_grad():
58
- # images = images.to(args.device)
59
- # x_tilde, _, _ = model(images)
60
- # return x_tilde
61
-
62
56
def main (args ):
63
57
writer = SummaryWriter ('./logs/{0}' .format (args .output_folder ))
64
58
save_filename = './models/{0}/prior.pt' .format (args .output_folder )
@@ -101,18 +95,18 @@ def main(args):
101
95
model .load_state_dict (state_dict )
102
96
model .eval ()
103
97
104
- prior = GatedPixelCNN (args .k , args .hidden_size_prior , args .num_layers ).to (args .device )
98
+ prior = GatedPixelCNN (args .k , args .hidden_size_prior ,
99
+ args .num_layers , n_classes = len (train_dataset ._label_encoder )).to (args .device )
105
100
optimizer = torch .optim .Adam (prior .parameters (), lr = args .lr )
106
101
107
102
best_loss = - 1.
108
103
for epoch in range (args .num_epochs ):
109
104
train (train_loader , model , prior , optimizer , args , writer )
105
+ # The validation loss is not properly computed since
106
+ # the classes in the train and valid splits of Mini-Imagenet
107
+ # do not overlap.
110
108
loss = test (valid_loader , model , prior , args , writer )
111
109
112
- # reconstruction = generate_samples(fixed_images, model, args)
113
- # grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
114
- # writer.add_image('reconstruction', grid, epoch + 1)
115
-
116
110
if (epoch == 0 ) or (loss < best_loss ):
117
111
best_loss = loss
118
112
with open (save_filename , 'wb' ) as f :
0 commit comments