@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable


 def to_scalar(arr):
@@ -33,27 +32,27 @@ def forward(self, x):


 class AutoEncoder(nn.Module):
-    def __init__(self, K=512):
+    def __init__(self, input_dim, dim, K=512):
         super(AutoEncoder, self).__init__()
         self.encoder = nn.Sequential(
-            nn.Conv2d(3, 256, 4, 2, 1),
+            nn.Conv2d(input_dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.Conv2d(256, 256, 4, 2, 1),
-            ResBlock(256),
-            ResBlock(256),
+            nn.Conv2d(dim, dim, 4, 2, 1),
+            ResBlock(dim),
+            ResBlock(dim),
         )

-        self.embedding = nn.Embedding(K, 256)
+        self.embedding = nn.Embedding(K, dim)
         # self.embedding.weight.data.copy_(1./K * torch.randn(K, 256))
         self.embedding.weight.data.uniform_(-1./K, 1./K)

         self.decoder = nn.Sequential(
-            ResBlock(256),
-            ResBlock(256),
+            ResBlock(dim),
+            ResBlock(dim),
             nn.ReLU(True),
-            nn.ConvTranspose2d(256, 256, 4, 2, 1),
+            nn.ConvTranspose2d(dim, dim, 4, 2, 1),
             nn.ReLU(True),
-            nn.ConvTranspose2d(256, 3, 4, 2, 1),
+            nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
             nn.Tanh()
         )

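For reference, a minimal usage sketch (not part of this commit) showing that the parameterized constructor reproduces the previously hard-coded sizes when called with input_dim=3 and dim=256; the batch and image sizes below are illustrative:

import torch

model = AutoEncoder(input_dim=3, dim=256, K=512)   # equivalent to the old hard-coded defaults
x = torch.randn(8, 3, 32, 32)                      # assumed batch of 32x32 RGB images
z_e = model.encoder(x)                             # two stride-2 convs downsample 4x: (8, 256, 8, 8)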
@@ -94,12 +93,16 @@ def forward(self, x):


 class GatedMaskedConv2d(nn.Module):
-    def __init__(self, mask_type, dim, kernel, residual=True):
+    def __init__(self, mask_type, dim, kernel, residual=True, n_classes=10):
         super().__init__()
         assert kernel % 2 == 1, print("Kernel size must be odd")
         self.mask_type = mask_type
         self.residual = residual

+        self.class_cond_embedding = nn.Embedding(
+            n_classes, 2 * dim
+        )
+
         kernel_shp = (kernel // 2 + 1, kernel)  # (ceil(n/2), n)
         padding_shp = (kernel // 2, kernel // 2)
         self.vert_stack = nn.Conv2d(
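The new class_cond_embedding maps each label to a vector of width 2 * dim, matching the doubled channel count that the gated activation later splits in half. The gate itself is outside the hunks shown; a sketch of the usual gated PixelCNN nonlinearity it is assumed to implement:

import torch

def gated_activation(x):
    # Split the 2*dim channels into two halves and combine them as
    # tanh(a) * sigmoid(b), the standard gated PixelCNN activation (assumption).
    a, b = x.chunk(2, dim=1)
    return torch.tanh(a) * torch.sigmoid(b)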
@@ -124,19 +127,20 @@ def make_causal(self):
         self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
         self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column

-    def forward(self, x_v, x_h):
+    def forward(self, x_v, x_h, h):
         if self.mask_type == 'A':
             self.make_causal()

+        h = self.class_cond_embedding(h)
         h_vert = self.vert_stack(x_v)
         h_vert = h_vert[:, :, :x_v.size(-1), :]
-        out_v = self.gate(h_vert)
+        out_v = self.gate(h_vert + h[:, :, None, None])

         h_horiz = self.horiz_stack(x_h)
         h_horiz = h_horiz[:, :, :, :x_h.size(-2)]
         v2h = self.vert_to_horiz(h_vert)

-        out = self.gate(v2h + h_horiz)
+        out = self.gate(v2h + h_horiz + h[:, :, None, None])
         if self.residual:
             out_h = self.horiz_resid(out) + x_h
         else:
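Here h arrives as a batch of integer labels, is looked up to a (batch, 2 * dim) embedding, and h[:, :, None, None] reshapes it to (batch, 2 * dim, 1, 1) so it broadcasts across the spatial dimensions of both conv stacks. A standalone sketch of that broadcast, with illustrative sizes:

import torch
import torch.nn as nn

B, dim, H, W = 4, 64, 8, 8
emb = nn.Embedding(10, 2 * dim)            # as added in __init__ above
h = emb(torch.randint(0, 10, (B,)))        # (B, 2*dim) per-sample class embedding
feat = torch.randn(B, 2 * dim, H, W)       # stand-in for h_vert or v2h + h_horiz
out = feat + h[:, :, None, None]           # same vector added at every spatial position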
@@ -174,25 +178,23 @@ def __init__(self, input_dim=256, dim=64, n_layers=15):
             nn.Conv2d(dim, input_dim, 1)
         )

-    def forward(self, x):
+    def forward(self, x, label):
         shp = x.size() + (-1, )
         x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
         x = x.permute(0, 3, 1, 2)  # (B, C, W, W)

         x_v, x_h = (x, x)
         for i, layer in enumerate(self.layers):
-            x_v, x_h = layer(x_v, x_h)
+            x_v, x_h = layer(x_v, x_h, label)

         return self.output_conv(x_h)

-    def generate(self, batch_size=64):
-        x = Variable(
-            torch.zeros(64, 8, 8).long()
-        ).cuda()
+    def generate(self, label, shape=(8, 8), batch_size=64):
+        x = torch.zeros(batch_size, *shape).long().cuda()

-        for i in range(8):
-            for j in range(8):
-                logits = self.forward(x)
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                logits = self.forward(x, label)
                 probs = F.softmax(logits[:, :, i, j], -1)
                 x.data[:, i, j].copy_(
                     probs.multinomial(1).squeeze().data
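A hedged end-to-end sketch of the new conditional sampling call, assuming a CUDA device (generate() calls .cuda() internally), that the class holding these methods is named GatedPixelCNN (the class name lies outside the hunks shown), and that generate() returns the sampled grid; the label values and sizes are illustrative:

import torch

model = GatedPixelCNN(input_dim=256, dim=64, n_layers=15).cuda()
label = torch.randint(0, 10, (64,)).cuda()     # one class label per sample in the batch
x = model.generate(label, shape=(8, 8), batch_size=64)
# x is assumed to be the (64, 8, 8) grid of sampled discrete code indices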