Skip to content

Commit 2ba91d3

Browse files
committed
Add script to train the prior for mini-imagenet
1 parent 41ba39a commit 2ba91d3

File tree

1 file changed

+172
-0
lines changed

1 file changed

+172
-0
lines changed

prior_miniimagenet.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn.functional as F
4+
from torchvision import transforms
5+
from torchvision.utils import save_image, make_grid
6+
7+
from modules import AutoEncoder, GatedPixelCNN
8+
from datasets import MiniImagenet
9+
10+
from tensorboardX import SummaryWriter
11+
12+
def train(data_loader, model, prior, optimizer, args, writer):
13+
for images, labels in data_loader:
14+
with torch.no_grad():
15+
images = images.to(args.device)
16+
latents, _ = model.encode(images)
17+
latents = latents.detach()
18+
19+
labels = labels.to(args.device)
20+
logits = prior(latents, labels)
21+
logits = logits.permute(0, 2, 3, 1).contiguous()
22+
23+
optimizer.zero_grad()
24+
loss = F.cross_entropy(logits.view(-1, args.k),
25+
latents.view(-1))
26+
loss.backward()
27+
28+
# Logs
29+
writer.add_scalar('loss/train', loss.item(), args.steps)
30+
31+
optimizer.step()
32+
args.steps += 1
33+
34+
def test(data_loader, model, prior, args, writer):
35+
with torch.no_grad():
36+
loss = 0.
37+
for images, labels in data_loader:
38+
images = images.to(args.device)
39+
labels = labels.to(args.device)
40+
41+
latents, _ = model.encode(images)
42+
latents = latents.detach()
43+
logits = prior(latents, labels)
44+
logits = logits.permute(0, 2, 3, 1).contiguous()
45+
loss += F.cross_entropy(logits.view(-1, args.k),
46+
latents.view(-1))
47+
48+
loss /= len(data_loader)
49+
50+
# Logs
51+
writer.add_scalar('loss/valid', loss.item(), args.steps)
52+
53+
return loss.item()
54+
55+
# def generate_samples(images, model, args):
56+
# with torch.no_grad():
57+
# images = images.to(args.device)
58+
# x_tilde, _, _ = model(images)
59+
# return x_tilde
60+
61+
def main(args):
62+
writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
63+
save_filename = './models/{0}/prior.pt'.format(args.output_folder)
64+
65+
transform = transforms.Compose([
66+
transforms.RandomResizedCrop(128),
67+
transforms.ToTensor(),
68+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
69+
])
70+
71+
# Define the train, valid & test datasets
72+
train_dataset = MiniImagenet(args.data_folder, train=True,
73+
download=True, transform=transform)
74+
valid_dataset = MiniImagenet(args.data_folder, valid=True,
75+
download=True, transform=transform)
76+
test_dataset = MiniImagenet(args.data_folder, test=True,
77+
download=True, transform=transform)
78+
# Define the data loaders
79+
train_loader = torch.utils.data.DataLoader(train_dataset,
80+
batch_size=args.batch_size, shuffle=False,
81+
num_workers=args.num_workers, pin_memory=True)
82+
valid_loader = torch.utils.data.DataLoader(valid_dataset,
83+
batch_size=args.batch_size, shuffle=False, drop_last=True,
84+
num_workers=args.num_workers, pin_memory=True)
85+
test_loader = torch.utils.data.DataLoader(test_dataset,
86+
batch_size=16, shuffle=True)
87+
88+
# Fixed images for Tensorboard
89+
fixed_images, _ = next(iter(test_loader))
90+
fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
91+
writer.add_image('original', fixed_grid, 0)
92+
93+
model = AutoEncoder(3, args.hidden_size_vae, args.k).to(args.device)
94+
with open(args.model, 'rb') as f:
95+
state_dict = torch.load(f)
96+
model.load_state_dict(state_dict)
97+
model.eval()
98+
99+
prior = GatedPixelCNN(args.k, args.hidden_size_prior, args.num_layers).to(args.device)
100+
optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)
101+
102+
best_loss = -1.
103+
for epoch in range(args.num_epochs):
104+
train(train_loader, model, prior, optimizer, args, writer)
105+
loss, _ = test(valid_loader, model, prior, args, writer)
106+
107+
# reconstruction = generate_samples(fixed_images, model, args)
108+
# grid = make_grid(reconstruction.cpu(), nrow=8, range=(-1, 1), normalize=True)
109+
# writer.add_image('reconstruction', grid, epoch + 1)
110+
111+
if (epoch == 0) or (loss < best_loss):
112+
best_loss = loss
113+
with open(save_filename, 'wb') as f:
114+
torch.save(prior.state_dict(), f)
115+
116+
if __name__ == '__main__':
117+
import argparse
118+
import os
119+
import multiprocessing as mp
120+
121+
parser = argparse.ArgumentParser(description='Prior VQ-VAE')
122+
123+
# General
124+
parser.add_argument('--data-folder', type=str,
125+
help='name of the data folder')
126+
parser.add_argument('--model', type=str,
127+
help='filename containing the model')
128+
129+
# Latent space
130+
parser.add_argument('--hidden-size-vae', type=int, default=256,
131+
help='size of the latent vectors (default: 256)')
132+
parser.add_argument('--hidden-size-prior', type=int, default=64,
133+
help='hidden size for the PixelCNN prior (default: 64)')
134+
parser.add_argument('--k', type=int, default=512,
135+
help='number of latent vectors (default: 512)')
136+
parser.add_argument('--num_layers', type=int, default=15,
137+
help='number of layers for the PixelCNN prior (default: 15)')
138+
139+
# Optimization
140+
parser.add_argument('--batch-size', type=int, default=128,
141+
help='batch size (default: 128)')
142+
parser.add_argument('--num-epochs', type=int, default=100,
143+
help='number of epochs (default: 100)')
144+
parser.add_argument('--lr', type=float, default=3e-4,
145+
help='learning rate for Adam optimizer (default: 3e-4)')
146+
147+
# Miscellaneous
148+
parser.add_argument('--output-folder', type=str, default='prior',
149+
help='name of the output folder (default: prior)')
150+
parser.add_argument('--num-workers', type=int, default=mp.cpu_count() - 1,
151+
help='number of workers for trajectories sampling (default: {0})'.format(mp.cpu_count() - 1))
152+
parser.add_argument('--device', type=str, default='cpu',
153+
help='set the device (cpu or cuda, default: cpu)')
154+
155+
args = parser.parse_args()
156+
157+
# Create logs and models folder if they don't exist
158+
if not os.path.exists('./logs'):
159+
os.makedirs('./logs')
160+
if not os.path.exists('./models'):
161+
os.makedirs('./models')
162+
# Device
163+
args.device = torch.device(args.device
164+
if torch.cuda.is_available() else 'cpu')
165+
# Slurm
166+
if 'SLURM_JOB_ID' in os.environ:
167+
args.output_folder += '-{0}'.format(os.environ['SLURM_JOB_ID'])
168+
if not os.path.exists('./models/{0}'.format(args.output_folder)):
169+
os.makedirs('./models/{0}'.format(args.output_folder))
170+
args.steps = 0
171+
172+
main(args)

0 commit comments

Comments
 (0)