@@ -83,52 +83,106 @@ def forward(self, x):
        return x_tilde, z_e_x, z_q_x


-class MaskedConv2d(nn.Conv2d):
-    def __init__(self, mask_type, *args, **kwargs):
-        super(MaskedConv2d, self).__init__(*args, **kwargs)
-        assert mask_type in {'A', 'B'}
-        self.register_buffer('mask', self.weight.data.clone())
-        _, _, kH, kW = self.weight.size()
-        self.mask.fill_(1)
-        self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
-        self.mask[:, :, kH // 2 + 1:] = 0
+class GatedActivation(nn.Module):
+    def __init__(self):
+        super().__init__()

    def forward(self, x):
-        self.weight.data *= self.mask
-        return super(MaskedConv2d, self).forward(x)
+        x, y = x.chunk(2, dim=1)
+        return F.tanh(x) * F.sigmoid(y)
+
+
+class GatedMaskedConv2d(nn.Module):
+    def __init__(self, mask_type, dim, kernel, residual=True):
+        super().__init__()
+        assert kernel % 2 == 1, "Kernel size must be odd"
+        self.mask_type = mask_type
+        self.residual = residual
+
+        kernel_shp = (kernel // 2 + 1, kernel)  # (ceil(n/2), n)
+        padding_shp = (kernel // 2, kernel // 2)
+        self.vert_stack = nn.Conv2d(
+            dim, dim * 2,
+            kernel_shp, 1, padding_shp
+        )
+
+        self.vert_to_horiz = nn.Conv2d(2 * dim, 2 * dim, 1)
+
+        kernel_shp = (1, kernel // 2 + 1)
+        padding_shp = (0, kernel // 2)
+        self.horiz_stack = nn.Conv2d(
+            dim, dim * 2,
+            kernel_shp, 1, padding_shp
+        )
+
+        self.horiz_resid = nn.Conv2d(dim, dim, 1)
+
+        self.gate = GatedActivation()
+
+    def make_causal(self):
+        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask final row
+        self.horiz_stack.weight.data[:, :, :, -1].zero_()  # Mask final column

+    def forward(self, x_v, x_h):
+        if self.mask_type == 'A':
+            self.make_causal()

-class PixelCNN(nn.Module):
-    def __init__(self, dim=64, n_layers=4):
+        h_vert = self.vert_stack(x_v)
+        h_vert = h_vert[:, :, :x_v.size(-1), :]
+        out_v = self.gate(h_vert)
+
+        h_horiz = self.horiz_stack(x_h)
+        h_horiz = h_horiz[:, :, :, :x_h.size(-2)]
+        v2h = self.vert_to_horiz(h_vert)
+
+        out = self.gate(v2h + h_horiz)
+        if self.residual:
+            out_h = self.horiz_resid(out) + x_h
+        else:
+            out_h = self.horiz_resid(out)
+
+        return out_v, out_h
+
+
+class GatedPixelCNN(nn.Module):
+    def __init__(self, input_dim=256, dim=64, n_layers=7):
        super().__init__()
        self.dim = 64

        # Create embedding layer to embed input
-        self.embedding = nn.Embedding(256, dim)
+        self.embedding = nn.Embedding(input_dim, dim)

        # Building the PixelCNN layer by layer
-        net = []
+        self.layers = nn.ModuleList()

        # Initial block with Mask-A convolution
        # Rest with Mask-B convolutions
        for i in range(n_layers):
            mask_type = 'A' if i == 0 else 'B'
-            net.extend([
-                MaskedConv2d(mask_type, dim, dim, 7, 1, 3, bias=False),
-                nn.BatchNorm2d(dim),
-                nn.ReLU(True)
-            ])
+            kernel = 7 if i == 0 else 3
+            residual = False if i == 0 else True

-        # Add the output layer
-        net.append(nn.Conv2d(dim, 256, 1))
+            self.layers.append(
+                GatedMaskedConv2d(mask_type, dim, kernel, residual)
+            )

-        self.net = nn.Sequential(*net)
+        # Add the output layer
+        self.output_conv = nn.Sequential(
+            nn.Conv2d(dim, dim, 1),
+            nn.ReLU(True),
+            nn.Conv2d(dim, input_dim, 1)
+        )

    def forward(self, x):
        shp = x.size() + (-1, )
        x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
-        return self.net(x)
+
+        x_v, x_h = (x, x)
+        for i, layer in enumerate(self.layers):
+            x_v, x_h = layer(x_v, x_h)
+
+        return self.output_conv(x_h)

    def generate(self, batch_size=64):
        x = Variable(
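For reference, a minimal usage sketch of the classes introduced by this diff; it is illustrative only and not part of the commit. It assumes the file above is importable (the module name `pixelcnn` is a placeholder) and that the model's input is a `(B, H, W)` tensor of integers in `[0, input_dim)`, e.g. 8-bit pixel values or VQ-VAE codebook indices.

```python
# Illustrative sketch, not part of the diff above: a forward pass of the new
# GatedPixelCNN on dummy discrete inputs.
import torch

from pixelcnn import GatedPixelCNN  # hypothetical import path; adjust to the actual file

model = GatedPixelCNN(input_dim=256, dim=64, n_layers=7)

# (B, H, W) grid of integers in [0, 256): 8-bit pixels or VQ-VAE code indices.
x = torch.randint(0, 256, (4, 28, 28))

logits = model(x)  # (B, 256, 28, 28): unnormalized per-position class scores
print(logits.shape)
```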