diff --git a/vqvae.py b/vqvae.py index 25221ec..cdaaa77 100644 --- a/vqvae.py +++ b/vqvae.py @@ -10,6 +10,8 @@ from tensorboardX import SummaryWriter def train(data_loader, model, optimizer, args, writer): + model.train() + for images, _ in data_loader: images = images.to(args.device) @@ -34,6 +36,8 @@ def train(data_loader, model, optimizer, args, writer): args.steps += 1 def test(data_loader, model, args, writer): + model.eval() + with torch.no_grad(): loss_recons, loss_vq = 0., 0. for images, _ in data_loader: