@@ -15,29 +15,39 @@ def to_scalar(arr):
 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        nn.init.xavier_uniform_(m.weight.data)
-        m.bias.data.fill_(0)
+        try:
+            nn.init.xavier_uniform_(m.weight.data)
+            m.bias.data.fill_(0)
+        except AttributeError:
+            print("Skipping initialization of ", classname)
 
 
 class VAE(nn.Module):
     def __init__(self, input_dim, dim, z_dim):
         super().__init__()
         self.encoder = nn.Sequential(
             nn.Conv2d(input_dim, dim, 4, 2, 1),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.Conv2d(dim, dim, 4, 2, 1),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.Conv2d(dim, dim, 5, 1, 0),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.Conv2d(dim, z_dim * 2, 3, 1, 0),
+            nn.BatchNorm2d(z_dim * 2)
         )
 
         self.decoder = nn.Sequential(
             nn.ConvTranspose2d(z_dim, dim, 3, 1, 0),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.ConvTranspose2d(dim, dim, 5, 1, 0),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.ConvTranspose2d(dim, dim, 4, 2, 1),
+            nn.BatchNorm2d(dim),
             nn.ReLU(True),
             nn.ConvTranspose2d(dim, input_dim, 4, 2, 1),
             nn.Tanh()
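Reviewer's note on the hunk above: even with the added BatchNorm2d layers, the encoder still ends in 2 * z_dim channels, which a VAE forward pass splits into a mean and a log-variance for the reparameterization trick. The forward method is outside this diff, so the sketch below is an assumption about how those channels are consumed, not code from the commit; the `reparameterize` helper and the 28x28 input size are hypothetical.

import torch

def reparameterize(encoder_out, z_dim):
    # encoder_out: (B, 2 * z_dim, 1, 1) for a 28x28 input (assumed size)
    mu, logvar = torch.chunk(encoder_out, 2, dim=1)  # split channel-wise
    std = torch.exp(0.5 * logvar)                    # log-variance -> std
    eps = torch.randn_like(std)                      # unit Gaussian noise
    return mu + eps * std                            # z ~ N(mu, std^2)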
@@ -204,7 +214,7 @@ def forward(self, x_v, x_h, h):
 class GatedPixelCNN(nn.Module):
     def __init__(self, input_dim=256, dim=64, n_layers=15):
         super().__init__()
-        self.dim = 64
+        self.dim = dim
 
         # Create embedding layer to embed input
         self.embedding = nn.Embedding(input_dim, dim)
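Reviewer's note on the hunk above: `self.dim = 64` hardcoded the width, so constructing the model with any other `dim` left the attribute out of sync with the actual layer sizes. A hypothetical repro of the old bug:

model = GatedPixelCNN(input_dim=256, dim=128)
assert model.dim == 128  # holds after the fix; before it, self.dim stayed 64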
@@ -225,11 +235,13 @@ def __init__(self, input_dim=256, dim=64, n_layers=15):
 
         # Add the output layer
         self.output_conv = nn.Sequential(
-            nn.Conv2d(dim, dim, 1),
+            nn.Conv2d(dim, 512, 1),
             nn.ReLU(True),
-            nn.Conv2d(dim, input_dim, 1)
+            nn.Conv2d(512, input_dim, 1)
         )
 
+        self.apply(weights_init)
+
     def forward(self, x, label):
         shp = x.size() + (-1, )
         x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
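Reviewer's note on the hunk above: `self.apply(weights_init)` makes `nn.Module.apply` visit every submodule recursively, so the `weights_init` from the first hunk now runs over all the Conv layers in this model. The try/except added there matters here: a module whose class name contains 'Conv' but whose bias is None would otherwise crash the walk. A small illustration (the `bias=False` layer is a contrived example, not from this repo):

import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, bias=False)  # conv.bias is None
weights_init(conv)  # prints the skip message instead of raising AttributeError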