 import torch
 import torch.nn as nn
+import torch.nn.functional as F


 def to_scalar(arr):
     if type(arr) == list:
-        return [x.cpu().data.tolist()[0] for x in arr]
+        return [x.item() for x in arr]
     else:
-        return arr.cpu().data.tolist()[0]
+        return arr.item()


-def euclidean_distance(z_e_x, emb):
-    dists = torch.pow(
-        z_e_x.unsqueeze(1) - emb[None, :, :, None, None],
-        2
-    ).sum(2)
-    return dists
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.xavier_uniform_(m.weight.data)
+        m.bias.data.fill_(0)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, dim):
+        super(ResBlock, self).__init__()
+        self.block = nn.Sequential(
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 3, 1, 1),
+            nn.ReLU(True),
+            nn.Conv2d(dim, dim, 1),
+        )
+
+    def forward(self, x):
+        return x + self.block(x)


 class AutoEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, input_dim, dim, K=512):
         super(AutoEncoder, self).__init__()
         self.encoder = nn.Sequential(
-            nn.Conv2d(1, 16, 4, 2, 1),
-            nn.BatchNorm2d(16),
-            nn.ReLU(True),
-            nn.Conv2d(16, 32, 4, 2, 1),
-            nn.BatchNorm2d(32),
+            nn.Conv2d(input_dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.Conv2d(32, 64, 1, 1, 0),
-            nn.BatchNorm2d(64),
+            nn.Conv2d(dim, dim, 4, 2, 1),
+            ResBlock(dim),
+            ResBlock(dim),
         )

-        self.embedding = nn.Embedding(512, 64)
+        self.embedding = nn.Embedding(K, dim)
+        # self.embedding.weight.data.copy_(1./K * torch.randn(K, 256))
+        self.embedding.weight.data.uniform_(-1./K, 1./K)

         self.decoder = nn.Sequential(
-            nn.Conv2d(64, 32, 1, 1, 0),
-            nn.BatchNorm2d(32),
+            ResBlock(dim),
+            ResBlock(dim),
             nn.ReLU(True),
-            nn.ConvTranspose2d(32, 16, 4, 2, 1),
-            nn.BatchNorm2d(16),
+            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.ConvTranspose2d(16, 1, 4, 2, 1),
-            nn.Sigmoid()
+            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
+            nn.Tanh()
         )

-    def forward(self, x):
+        self.apply(weights_init)
+
+    def encode(self, x):
         z_e_x = self.encoder(x)
-        B, C, H, W = z_e_x.size()

-        dists = euclidean_distance(z_e_x, self.embedding.weight)
-        latents = dists.min(1)[1]
+        z_e_x_transp = z_e_x.permute(0, 2, 3, 1)  # (B, H, W, C)
+        emb = self.embedding.weight.transpose(0, 1)  # (C, K)
+        dists = torch.pow(
+            z_e_x_transp.unsqueeze(4) - emb[None, None, None],
+            2
+        ).sum(-2)
+        latents = dists.min(-1)[1]
+        return latents, z_e_x

+    def decode(self, latents):
         shp = latents.size() + (-1, )
-        z_q_x = self.embedding(latents.view(-1)).view(*shp)
-        z_q_x = z_q_x.permute(0, 3, 1, 2)
-
+        z_q_x = self.embedding(latents.view(latents.size(0), -1))  # (B, H * W, C)
+        z_q_x = z_q_x.view(*shp).permute(0, 3, 1, 2)  # (B, C, H, W)
         x_tilde = self.decoder(z_q_x)
+        return x_tilde, z_q_x
+
+    def forward(self, x):
+        latents, z_e_x = self.encode(x)
+        x_tilde, z_q_x = self.decode(latents)
         return x_tilde, z_e_x, z_q_x
+
+
+class GatedActivation(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x, y = x.chunk(2, dim=1)
+        return F.tanh(x) * F.sigmoid(y)
+
+
+class GatedMaskedConv2d(nn.Module):
+    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10):
+        super().__init__()
+        assert kernel % 2 == 1, "Kernel size must be odd"
+        self.mask_type = mask_type
+        self.residual = residual
+
+        self.class_cond_embedding = nn.Embedding(
+            n_classes, 2 * dim
+        )
+
+        kernel_shp = (kernel // 2 + 1, kernel)  # (ceil(n/2), n)
+        padding_shp = (kernel // 2, kernel // 2)
+        self.vert_stack = nn.Conv2d(
+            dim, dim * 2,
+            kernel_shp, 1, padding_shp
+        )
+
+        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
+
+        kernel_shp = (1, kernel // 2 + 1)
+        padding_shp = (0, kernel // 2)
+        self.horiz_stack = nn.Conv2d(
+            dim, dim * 2,
+            kernel_shp, 1, padding_shp
+        )
+
+        self.horiz_resid = nn.Conv2d(dim, dim, 1)
+
+        self.gate = GatedActivation()
+
+    def make_causal(self):
+        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
+        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column
+
+    def forward(self, x_v, x_h, h):
+        if self.mask_type == 'A':
+            self.make_causal()
+
+        h = self.class_cond_embedding(h)
+        h_vert = self.vert_stack(x_v)
+        h_vert = h_vert[:, :, :x_v.size(-1), :]
+        out_v = self.gate(h_vert + h[:, :, None, None])
+
+        h_horiz = self.horiz_stack(x_h)
+        h_horiz = h_horiz[:, :, :, :x_h.size(-2)]
+        v2h = self.vert_to_horiz(h_vert)
+
+        out = self.gate(v2h + h_horiz + h[:, :, None, None])
+        if self.residual:
+            out_h = self.horiz_resid(out) + x_h
+        else:
+            out_h = self.horiz_resid(out)
+
+        return out_v, out_h
+
+
+class GatedPixelCNN(nn.Module):
+    def __init__(self, input_dim=256, dim=64, n_layers=15):
+        super().__init__()
+        self.dim = dim
+
+        # Create embedding layer to embed input
+        self.embedding = nn.Embedding(input_dim, dim)
+
+        # Building the PixelCNN layer by layer
+        self.layers = nn.ModuleList()
+
+        # Initial block with Mask-A convolution
+        # Rest with Mask-B convolutions
+        for i in range(n_layers):
+            mask_type = 'A' if i == 0 else 'B'
+            kernel = 7 if i == 0 else 3
+            residual = False if i == 0 else True
+
+            self.layers.append(
+                GatedMaskedConv2d(mask_type, dim, kernel, residual)
+            )
+
+        # Add the output layer
+        self.output_conv = nn.Sequential(
+            nn.Conv2d(dim, dim, 1),
+            nn.ReLU(True),
+            nn.Conv2d(dim, input_dim, 1)
+        )
+
+    def forward(self, x, label):
+        shp = x.size() + (-1, )
+        x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
+        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
+
+        x_v, x_h = (x, x)
+        for i, layer in enumerate(self.layers):
+            x_v, x_h = layer(x_v, x_h, label)
+
+        return self.output_conv(x_h)
+
+    def generate(self, label, shape=(8, 8), batch_size=64):
+        x = torch.zeros(batch_size, *shape).long().cuda()
+
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                logits = self.forward(x, label)
+                probs = F.softmax(logits[:, :, i, j], -1)
+                x.data[:, i, j].copy_(
+                    probs.multinomial(1).squeeze().data
+                )
+        return x
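
The diff only touches the model definitions; the training objective is not part of this commit. For orientation, here is a minimal, hypothetical sketch of how the three tensors returned by AutoEncoder.forward are commonly combined into a VQ-VAE-style objective. The MSE reconstruction term, the 0.25 commitment weight, and inputs scaled to [-1, 1] to match the Tanh decoder are assumptions, and the straight-through gradient from the reconstruction loss back into the encoder would still need to be handled by the training loop.

# Hypothetical usage sketch, not part of this commit.
import torch
import torch.nn.functional as F

model = AutoEncoder(input_dim=3, dim=256, K=512)
x = torch.randn(8, 3, 32, 32)                    # dummy batch of 32x32 RGB images in [-1, 1]

x_tilde, z_e_x, z_q_x = model(x)                 # reconstruction, encoder output, quantized codes

loss_recons = F.mse_loss(x_tilde, x)             # reconstruction term (trains decoder and codebook)
loss_vq = F.mse_loss(z_q_x, z_e_x.detach())      # pulls codebook vectors toward encoder outputs
loss_commit = F.mse_loss(z_e_x, z_q_x.detach())  # commits encoder outputs to their nearest codes
loss = loss_recons + loss_vq + 0.25 * loss_commit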
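
Continuing the same hypothetical sketch, GatedPixelCNN acts as a prior over the grid of discrete latent indices produced by AutoEncoder.encode, so its input_dim should equal the codebook size K. The batch size, image size, and class labels below are illustrative assumptions, and a GPU is assumed because generate() calls .cuda() internally.

# Hypothetical sketch, not part of this commit: fit the prior on the discrete
# latents, then sample new latents and decode them back into images.
vqvae = AutoEncoder(input_dim=3, dim=256, K=512).cuda()
prior = GatedPixelCNN(input_dim=512, dim=64, n_layers=15).cuda()  # input_dim == K

x = torch.randn(16, 3, 32, 32).cuda()            # dummy images
label = torch.randint(0, 10, (16,)).cuda()       # class labels in [0, n_classes)

with torch.no_grad():
    latents, _ = vqvae.encode(x)                 # (16, 8, 8) long tensor of code indices

logits = prior(latents, label)                   # (16, 512, 8, 8)
loss_prior = F.cross_entropy(logits, latents)    # autoregressive cross-entropy over codes

samples = prior.generate(label, shape=(8, 8), batch_size=16)  # (16, 8, 8) sampled codes
x_gen, _ = vqvae.decode(samples)                              # (16, 3, 32, 32) decoded images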