|
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.utils import save_image

from modules import GatedPixelCNN
| 8 | + |
| 9 | + |
# --- Training loop settings -------------------------------------------------
BATCH_SIZE = 32          # samples per mini-batch for both data loaders
N_EPOCHS = 100           # upper bound handed to the epoch loop
PRINT_INTERVAL = 100     # training progress is logged every this many batches
ALWAYS_SAVE = True       # when True, a checkpoint is written after every epoch
DATASET = 'FashionMNIST'  # CIFAR10 | MNIST | FashionMNIST
NUM_WORKERS = 4          # worker processes used by each DataLoader

# --- Model / sampling settings ----------------------------------------------
# Passed unchanged as ``shape`` to ``model.generate`` when sampling.
# NOTE(review): the earlier comment here read
# "(8, 8) -> 32x32 images, (7, 7) -> 28x28 images"; that mapping does not
# match how the value is used in this script -- confirm and drop if stale.
LATENT_SHAPE = (28, 28)
# 3 (RGB) | 1 (Grayscale).
# NOTE(review): not referenced by the code visible in this script, which
# always trains on channel 0 only -- confirm whether it is still needed.
INPUT_DIM = 3
# Number of discrete pixel levels: inputs are quantised to [0, K - 1] and the
# logits are reshaped to K classes per pixel for the cross-entropy loss.
K = 256
DIM = 64                 # second positional argument of GatedPixelCNN
N_LAYERS = 15            # third positional argument of GatedPixelCNN
LR = 3e-4                # Adam learning rate
| 23 | + |
| 24 | + |
def _build_loader(train, download=False):
    """Create a DataLoader over the train or test split of ``DATASET``.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.
            The training split is reshuffled every epoch; the test split is
            iterated in a fixed order.
        download: forwarded to the torchvision dataset constructor.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(image, label)`` batches.
    """
    # getattr performs the same attribute lookup the previous eval() did, but
    # cannot execute arbitrary code and raises a clear AttributeError when
    # DATASET is misspelled.
    dataset_cls = getattr(datasets, DATASET)
    dataset = dataset_cls(
        '../data/{}/'.format(DATASET), train=train, download=download,
        transform=transforms.ToTensor(),
    )
    return torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE,
        # Shuffle only the training split: SGD-style optimisation should not
        # see the samples in the same order every epoch.
        shuffle=train,
        num_workers=NUM_WORKERS, pin_memory=True
    )


# The training loader is built first and triggers the download, so the test
# loader can rely on the files already being present.
train_loader = _build_loader(train=True, download=True)
test_loader = _build_loader(train=False)
| 39 | + |
# Network, loss and optimiser. Everything is moved to the default CUDA
# device, so this script needs a GPU to run.
model = GatedPixelCNN(K, DIM, N_LAYERS)
model = model.cuda()

# Per-pixel classification over K intensity levels.
criterion = nn.CrossEntropyLoss()
criterion = criterion.cuda()

opt = torch.optim.Adam(model.parameters(), lr=LR)
| 43 | + |
| 44 | + |
def train():
    """Run one training epoch over ``train_loader``.

    Each batch is quantised to ``K`` integer levels, fed to the
    class-conditional model, and optimised with a per-pixel cross-entropy
    loss against the quantised input itself. Progress is printed every
    ``PRINT_INTERVAL`` batches.
    """
    # Make sure the model is in training mode even if something else left it
    # in eval mode.
    model.train()

    train_loss = []
    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    seen = 0  # samples processed so far, including the current batch

    for batch_idx, (x, label) in enumerate(train_loader):
        # Times only this iteration's compute (data loading excluded), so the
        # "Time" field below is the duration of the most recent batch.
        start_time = time.time()

        # Keep channel 0 only and map ToTensor's [0, 1] floats to integer
        # levels in [0, K - 1]; these serve as both input and target.
        x = (x[:, 0] * (K - 1)).long().cuda()
        label = label.cuda()

        # Train PixelCNN with images
        logits = model(x, label)
        # Move the K-way class dimension last so it can be flattened to
        # (num_pixels, K) for the loss.
        logits = logits.permute(0, 2, 3, 1).contiguous()

        loss = criterion(logits.view(-1, K), x.view(-1))

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append(loss.item())
        seen += len(x)

        if (batch_idx + 1) % PRINT_INTERVAL == 0:
            print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                seen, n_samples,
                # Percentage of batches completed. This was previously scaled
                # by PRINT_INTERVAL, which is only right when it equals 100.
                100.0 * (batch_idx + 1) / n_batches,
                np.mean(train_loss[-PRINT_INTERVAL:]),
                time.time() - start_time
            ))
| 74 | + |
| 75 | + |
def test():
    """Evaluate the model on ``test_loader`` and return the mean loss.

    The model is switched to eval mode for the duration of the pass and its
    previous mode is restored afterwards, so calling this between training
    epochs does not affect training.

    Returns:
        The cross-entropy loss averaged over all test samples (``nan`` if the
        loader yields no batches).
    """
    start_time = time.time()
    total_loss = 0.0
    n_samples = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, label in test_loader:
                # Same quantisation as in training: channel 0, K levels.
                x = (x[:, 0] * (K - 1)).long().cuda()
                label = label.cuda()

                logits = model(x, label)
                logits = logits.permute(0, 2, 3, 1).contiguous()
                loss = criterion(logits.view(-1, K), x.view(-1))

                # Weight each batch mean by its size so a shorter final batch
                # does not count as much as a full one.
                batch = x.size(0)
                total_loss += loss.item() * batch
                n_samples += batch
    finally:
        model.train(was_training)

    mean_loss = total_loss / n_samples if n_samples else float('nan')

    print('Validation Completed!\tLoss: {} Time: {}'.format(
        mean_loss,
        time.time() - start_time
    ))
    return mean_loss
| 97 | + |
| 98 | + |
def generate_samples():
    """Sample a 10x10 grid of class-conditional images and save it as a PNG.

    Labels are ``0..9`` repeated ten times, so with ``nrow=10`` every column
    of the saved grid holds one class. The image is written to
    ``samples/pixelcnn_baseline_samples_<DATASET>.png``; the directory is
    created if it does not exist.
    """
    label = torch.arange(10).expand(10, 10).contiguous().view(-1)
    label = label.long().cuda()

    # Sampling needs no gradients; run it in eval mode and restore the
    # model's previous mode so a following training epoch is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x_tilde = model.generate(
                label, shape=LATENT_SHAPE, batch_size=100
            )
    finally:
        model.train(was_training)

    # Map integer levels in [0, K - 1] back to floats in [0, 1] for saving.
    images = x_tilde.detach().cpu().float() / (K - 1)

    # save_image does not create missing directories.
    os.makedirs('samples', exist_ok=True)
    save_image(
        images[:, None],  # add a channel axis: (N, H, W) -> (N, 1, H, W)
        'samples/pixelcnn_baseline_samples_{}.png'.format(DATASET),
        nrow=10
    )
| 111 | + |
| 112 | + |
# Main loop: train, validate, checkpoint and sample once per epoch.
BEST_LOSS = float('inf')  # lowest validation loss seen so far
LAST_SAVED = -1           # epoch of the most recent checkpoint (-1: none yet)

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)

# Epochs are numbered from 1, so the upper bound must be N_EPOCHS + 1 to run
# N_EPOCHS epochs in total.
for epoch in range(1, N_EPOCHS + 1):
    print("\nEpoch {}:".format(epoch))
    train()
    cur_loss = test()

    # Track the best loss independently of ALWAYS_SAVE so BEST_LOSS is never
    # overwritten by a worse value.
    improved = cur_loss <= BEST_LOSS
    if improved:
        BEST_LOSS = cur_loss

    if ALWAYS_SAVE or improved:
        LAST_SAVED = epoch

        print("Saving model!")
        torch.save(model.state_dict(), 'models/{}_pixelcnn.pt'.format(DATASET))
    else:
        print("Not saving model! Last saved: {}".format(LAST_SAVED))

    # NOTE(review): indentation was lost in the source; sampling is assumed
    # to run once per epoch (inside the loop) -- confirm.
    generate_samples()
0 commit comments