Skip to content

Commit ac6284c

Browse files
committed
Change AutoEncoder name to VectorQuantizedVAE
1 parent cde1426 commit ac6284c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

miniimagenet_pixelcnn_prior.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchvision import transforms
66
from torchvision.utils import save_image, make_grid
77

8-
from modules import AutoEncoder, GatedPixelCNN
8+
from modules import VectorQuantizedVAE, GatedPixelCNN
99
from datasets import MiniImagenet
1010

1111
from tensorboardX import SummaryWriter
@@ -89,7 +89,7 @@ def main(args):
8989
fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
9090
writer.add_image('original', fixed_grid, 0)
9191

92-
model = AutoEncoder(3, args.hidden_size_vae, args.k).to(args.device)
92+
model = VectorQuantizedVAE(3, args.hidden_size_vae, args.k).to(args.device)
9393
with open(args.model, 'rb') as f:
9494
state_dict = torch.load(f)
9595
model.load_state_dict(state_dict)

miniimagenet_vqvae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torchvision import transforms
55
from torchvision.utils import save_image, make_grid
66

7-
from modules import AutoEncoder, to_scalar
7+
from modules import VectorQuantizedVAE, to_scalar
88
from datasets import MiniImagenet
99

1010
from tensorboardX import SummaryWriter
@@ -95,7 +95,7 @@ def main(args):
9595
fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
9696
writer.add_image('original', fixed_grid, 0)
9797

98-
model = AutoEncoder(3, args.hidden_size, args.k).to(args.device)
98+
model = VectorQuantizedVAE(3, args.hidden_size, args.k).to(args.device)
9999
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
100100

101101
# Generate the samples first once

0 commit comments

Comments
 (0)