Skip to content

Commit 9c27d3f

Browse files
committed
Save label encoder for PixelCNN
1 parent 50e35d1 commit 9c27d3f

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

prior_miniimagenet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import torch
33
import torch.nn.functional as F
4+
import json
45
from torchvision import transforms
56
from torchvision.utils import save_image, make_grid
67

@@ -85,6 +86,10 @@ def main(args):
8586
test_loader = torch.utils.data.DataLoader(test_dataset,
8687
batch_size=16, shuffle=True)
8788

89+
# Save the label encoder
90+
with open('./models/{0}/labels.json'.format(args.output_folder), 'w') as f:
91+
json.dump(train_dataset._label_encoder, f)
92+
8893
# Fixed images for Tensorboard
8994
fixed_images, _ = next(iter(test_loader))
9095
fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)

0 commit comments

Comments
 (0)