Commit cde1426
Merge pull request #3 from ritheshkumar95/tristan/miniimagenet
Mini-Imagenet experiments
2 parents 3be148b + bad9cb4 commit cde1426

5 files changed: +465, -2 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -107,6 +107,7 @@ run.sh
 # Logs & Saves
 logs/
 saves/
+models/

 # Slurm
 *.out

datasets.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import os
import csv
import torch.utils.data as data
from PIL import Image

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    # Borrowed from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

class MiniImagenet(data.Dataset):

    base_folder = '/data/lisa/data/miniimagenet'
    filename = 'miniimagenet.zip'
    splits = {
        'train': 'train.csv',
        'valid': 'val.csv',
        'test': 'test.csv'
    }

    def __init__(self, root, train=False, valid=False, test=False,
                 transform=None, target_transform=None, download=False):
        super(MiniImagenet, self).__init__()
        self.root = root
        self.train = train
        self.valid = valid
        self.test = test
        self.transform = transform
        self.target_transform = target_transform

        if not (((train ^ valid ^ test) ^ (train & valid & test))):
            raise ValueError('One and only one of `train`, `valid` or `test` '
                'must be True (train={0}, valid={1}, test={2}).'.format(train,
                valid, test))

        self.image_folder = os.path.join(os.path.expanduser(root), 'images')
        if train:
            split = self.splits['train']
        elif valid:
            split = self.splits['valid']
        elif test:
            split = self.splits['test']
        else:
            raise ValueError('Unknown split.')
        self.split_filename = os.path.join(os.path.expanduser(root), split)
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use `download=True` '
                'to download it')

        # Extract filenames and labels
        self._data = []
        with open(self.split_filename, 'r') as f:
            reader = csv.reader(f)
            next(reader)  # Skip the header
            for line in reader:
                self._data.append(tuple(line))
        self._fit_label_encoding()

    def __getitem__(self, index):
        filename, label = self._data[index]
        image = pil_loader(os.path.join(self.image_folder, filename))
        label = self._label_encoder[label]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label

    def _fit_label_encoding(self):
        _, labels = zip(*self._data)
        unique_labels = set(labels)
        self._label_encoder = dict((label, idx)
            for (idx, label) in enumerate(unique_labels))

    def _check_exists(self):
        return (os.path.exists(self.image_folder)
            and os.path.exists(self.split_filename))

    def download(self):
        from shutil import copyfile
        from zipfile import ZipFile

        # If the image folder already exists, break
        if self._check_exists():
            return True

        # Create folder if it does not exist
        root = os.path.expanduser(self.root)
        if not os.path.exists(root):
            os.makedirs(root)

        # Copy the file to root
        path_source = os.path.join(self.base_folder, self.filename)
        path_dest = os.path.join(root, self.filename)
        print('Copy file `{0}` to `{1}`...'.format(path_source, path_dest))
        copyfile(path_source, path_dest)

        # Extract the dataset
        print('Extract files from `{0}`...'.format(path_dest))
        with ZipFile(path_dest, 'r') as f:
            f.extractall(root)

        # Copy CSV files
        for split in self.splits:
            path_source = os.path.join(self.base_folder, self.splits[split])
            path_dest = os.path.join(root, self.splits[split])
            print('Copy file `{0}` to `{1}`...'.format(path_source, path_dest))
            copyfile(path_source, path_dest)
        print('Done!')

    def __len__(self):
        return len(self._data)
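
For reference, a minimal usage sketch (not part of the commit): it wraps the dataset in a standard DataLoader. The data root `./data/miniimagenet`, the 128-pixel resize, and the batch size are placeholders, and `download=True` only copies and extracts the archive from `base_folder` if the files are not already present.

from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import MiniImagenet

# Illustrative transform; the resize value is an assumption, not taken from the commit.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

train_set = MiniImagenet('./data/miniimagenet', train=True,
                         download=True, transform=transform)
loader = DataLoader(train_set, batch_size=32, shuffle=True)

images, labels = next(iter(loader))  # images: (32, 3, 128, 128), labels: (32,)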

miniimagenet_pixelcnn_prior.py

Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
import numpy as np
import torch
import torch.nn.functional as F
import json
from torchvision import transforms
from torchvision.utils import save_image, make_grid

from modules import AutoEncoder, GatedPixelCNN
from datasets import MiniImagenet

from tensorboardX import SummaryWriter

def train(data_loader, model, prior, optimizer, args, writer):
    for images, labels in data_loader:
        with torch.no_grad():
            images = images.to(args.device)
            latents, _ = model.encode(images)
            latents = latents.detach()

        labels = labels.to(args.device)
        logits = prior(latents, labels)
        logits = logits.permute(0, 2, 3, 1).contiguous()

        optimizer.zero_grad()
        loss = F.cross_entropy(logits.view(-1, args.k),
                               latents.view(-1))
        loss.backward()

        # Logs
        writer.add_scalar('loss/train', loss.item(), args.steps)

        optimizer.step()
        args.steps += 1

def test(data_loader, model, prior, args, writer):
    with torch.no_grad():
        loss = 0.
        for images, labels in data_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)

            latents, _ = model.encode(images)
            latents = latents.detach()
            logits = prior(latents, labels)
            logits = logits.permute(0, 2, 3, 1).contiguous()
            loss += F.cross_entropy(logits.view(-1, args.k),
                                    latents.view(-1))

        loss /= len(data_loader)

    # Logs
    writer.add_scalar('loss/valid', loss.item(), args.steps)

    return loss.item()

def main(args):
    writer = SummaryWriter('./logs/{0}'.format(args.output_folder))
    save_filename = './models/{0}/prior.pt'.format(args.output_folder)

    transform = transforms.Compose([
        transforms.RandomResizedCrop(128),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Define the train, valid & test datasets
    train_dataset = MiniImagenet(args.data_folder, train=True,
        download=True, transform=transform)
    valid_dataset = MiniImagenet(args.data_folder, valid=True,
        download=True, transform=transform)
    test_dataset = MiniImagenet(args.data_folder, test=True,
        download=True, transform=transform)
    # Define the data loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
        batch_size=args.batch_size, shuffle=False, drop_last=True,
        num_workers=args.num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
        batch_size=16, shuffle=True)

    # Save the label encoder
    with open('./models/{0}/labels.json'.format(args.output_folder), 'w') as f:
        json.dump(train_dataset._label_encoder, f)

    # Fixed images for Tensorboard
    fixed_images, _ = next(iter(test_loader))
    fixed_grid = make_grid(fixed_images, nrow=8, range=(-1, 1), normalize=True)
    writer.add_image('original', fixed_grid, 0)

    model = AutoEncoder(3, args.hidden_size_vae, args.k).to(args.device)
    with open(args.model, 'rb') as f:
        state_dict = torch.load(f)
        model.load_state_dict(state_dict)
    model.eval()

    prior = GatedPixelCNN(args.k, args.hidden_size_prior,
        args.num_layers, n_classes=len(train_dataset._label_encoder)).to(args.device)
    optimizer = torch.optim.Adam(prior.parameters(), lr=args.lr)

    best_loss = -1.
    for epoch in range(args.num_epochs):
        train(train_loader, model, prior, optimizer, args, writer)
        # The validation loss is not properly computed since
        # the classes in the train and valid splits of Mini-Imagenet
        # do not overlap.
        loss = test(valid_loader, model, prior, args, writer)

        if (epoch == 0) or (loss < best_loss):
            best_loss = loss
            with open(save_filename, 'wb') as f:
                torch.save(prior.state_dict(), f)

if __name__ == '__main__':
    import argparse
    import os
    import multiprocessing as mp

    parser = argparse.ArgumentParser(description='PixelCNN Prior for VQ-VAE')

    # General
    parser.add_argument('--data-folder', type=str,
        help='name of the data folder')
    parser.add_argument('--model', type=str,
        help='filename containing the model')

    # Latent space
    parser.add_argument('--hidden-size-vae', type=int, default=256,
        help='size of the latent vectors (default: 256)')
    parser.add_argument('--hidden-size-prior', type=int, default=64,
        help='hidden size for the PixelCNN prior (default: 64)')
    parser.add_argument('--k', type=int, default=512,
        help='number of latent vectors (default: 512)')
    parser.add_argument('--num-layers', type=int, default=15,
        help='number of layers for the PixelCNN prior (default: 15)')

    # Optimization
    parser.add_argument('--batch-size', type=int, default=128,
        help='batch size (default: 128)')
    parser.add_argument('--num-epochs', type=int, default=100,
        help='number of epochs (default: 100)')
    parser.add_argument('--lr', type=float, default=3e-4,
        help='learning rate for Adam optimizer (default: 3e-4)')

    # Miscellaneous
    parser.add_argument('--output-folder', type=str, default='prior',
        help='name of the output folder (default: prior)')
    parser.add_argument('--num-workers', type=int, default=mp.cpu_count() - 1,
        help='number of workers for data loading (default: {0})'.format(mp.cpu_count() - 1))
    parser.add_argument('--device', type=str, default='cpu',
        help='set the device (cpu or cuda, default: cpu)')

    args = parser.parse_args()

    # Create logs and models folder if they don't exist
    if not os.path.exists('./logs'):
        os.makedirs('./logs')
    if not os.path.exists('./models'):
        os.makedirs('./models')
    # Device
    args.device = torch.device(args.device
        if torch.cuda.is_available() else 'cpu')
    # Slurm
    if 'SLURM_JOB_ID' in os.environ:
        args.output_folder += '-{0}'.format(os.environ['SLURM_JOB_ID'])
    if not os.path.exists('./models/{0}'.format(args.output_folder)):
        os.makedirs('./models/{0}'.format(args.output_folder))
    args.steps = 0

    main(args)
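
Two illustrative sketches, neither part of the commit. First, the shape contract behind the cross-entropy in `train`/`test`, reproduced with dummy tensors (`B`, `K`, `H`, `W` are made-up sizes): the VQ-VAE encoder yields a grid of discrete codebook indices, the PixelCNN prior outputs `K` logits per spatial position, and the loss is a per-position classification over the codebook.

import torch
import torch.nn.functional as F

B, K, H, W = 4, 512, 32, 32               # illustrative sizes only
latents = torch.randint(0, K, (B, H, W))  # stands in for model.encode(images)[0]
logits = torch.randn(B, K, H, W)          # stands in for prior(latents, labels)

logits = logits.permute(0, 2, 3, 1).contiguous()              # (B, H, W, K)
loss = F.cross_entropy(logits.view(-1, K), latents.view(-1))  # one class per position

Second, a hypothetical invocation; both paths are placeholders, and the `--model` checkpoint is assumed to come from a previously trained VQ-VAE (AutoEncoder) state dict:

python miniimagenet_pixelcnn_prior.py \
    --data-folder ./data/miniimagenet \
    --model ./models/vqvae/best.pt \
    --batch-size 128 --num-epochs 100 --device cuda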

0 commit comments