From cd1aeea142e1e84fa53d91a80559661b6682ebca Mon Sep 17 00:00:00 2001 From: Joost van Amersfoort Date: Tue, 23 Jul 2019 12:22:30 +0100 Subject: [PATCH] Set train/eval correctly on model I think you forgot to set eval/train mode in your train/test function. Since `VectorQuantizedVAE` contains batchnorm that is required to correct results. --- vqvae.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vqvae.py b/vqvae.py index 25221ec..cdaaa77 100644 --- a/vqvae.py +++ b/vqvae.py @@ -10,6 +10,8 @@ from tensorboardX import SummaryWriter def train(data_loader, model, optimizer, args, writer): + model.train() + for images, _ in data_loader: images = images.to(args.device) @@ -34,6 +36,8 @@ def train(data_loader, model, optimizer, args, writer): args.steps += 1 def test(data_loader, model, args, writer): + model.eval() + with torch.no_grad(): loss_recons, loss_vq = 0., 0. for images, _ in data_loader: