We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 172c802 commit fa859feCopy full SHA for fa859fe
main_miniimagenet.py
@@ -98,14 +98,19 @@ def main(args):
98
model = AutoEncoder(3, args.hidden_size, args.k).to(args.device)
99
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
100
101
+ # Generate the samples first once
102
+ reconstruction = generate_samples(fixed_images, model, args)
103
+ grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
104
+ writer.add_image('reconstruction', grid, 0)
105
+
106
best_loss = -1.
107
for epoch in range(args.num_epochs):
108
train(train_loader, model, optimizer, args, writer)
109
loss, _ = test(valid_loader, model, args, writer)
110
111
reconstruction = generate_samples(fixed_images, model, args)
112
grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
- writer.add_image('reconstruction', grid, epoch)
113
+ writer.add_image('reconstruction', grid, epoch + 1)
114
115
if (epoch == 0) or (loss < best_loss):
116
best_loss = loss
0 commit comments