@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable


def to_scalar(arr):
@@ -9,13 +11,19 @@ def to_scalar(arr):
    return arr.cpu().data.tolist()[0]


+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        nn.init.xavier_uniform(m.weight.data)
+        m.bias.data.fill_(0)
+
+
class ResBlock(nn.Module):
    def __init__(self, dim):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 3, 1, 1),
-            nn.BatchNorm2d(dim),
            nn.ReLU(True),
            nn.Conv2d(dim, dim, 1)
        )
@@ -25,39 +33,33 @@ def forward(self, x):


class AutoEncoder(nn.Module):
-    def __init__(self):
+    def __init__(self, K=512):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 256, 4, 2, 1),
-            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 4, 2, 1),
-            nn.BatchNorm2d(256),
+            nn.ReLU(True),
            ResBlock(256),
-            nn.BatchNorm2d(256),
            ResBlock(256),
-            nn.BatchNorm2d(256)
        )

-        self.embedding = nn.Embedding(512, 256)
-        self.embedding.weight.data.copy_(1. / 512 * torch.randn(512, 256))
+        self.embedding = nn.Embedding(K, 256)
+        self.embedding.weight.data.copy_(1. / K * torch.randn(K, 256))

        self.decoder = nn.Sequential(
            ResBlock(256),
-            nn.BatchNorm2d(256),
            ResBlock(256),
-            nn.BatchNorm2d(256),
-            nn.ReLU(True),
            nn.ConvTranspose2d(256, 256, 4, 2, 1),
-            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 3, 4, 2, 1),
-            nn.Tanh()
+            nn.Sigmoid()
        )

-    def forward(self, x):
+        self.apply(weights_init)
+
+    def encode(self, x):
        z_e_x = self.encoder(x)
-        B, C, H, W = z_e_x.size()

        z_e_x_transp = z_e_x.permute(0, 2, 3, 1)  # (B, H, W, C)
        emb = self.embedding.weight.transpose(0, 1)  # (C, K)
@@ -66,8 +68,78 @@ def forward(self, x):
            2
        ).sum(-2)
        latents = dists.min(-1)[1]
+        return latents, z_e_x

-        z_q_x = self.embedding(latents.view(latents.size(0), -1))
-        z_q_x = z_q_x.view(B, H, W, C).permute(0, 3, 1, 2)
+    def decode(self, latents):
+        shp = latents.size() + (-1, )
+        z_q_x = self.embedding(latents.view(latents.size(0), -1))  # (B * H * W, C)
+        z_q_x = z_q_x.view(*shp).permute(0, 3, 1, 2)  # (B, C, H, W)
        x_tilde = self.decoder(z_q_x)
+        return x_tilde, z_q_x
+
+    def forward(self, x):
+        latents, z_e_x = self.encode(x)
+        x_tilde, z_q_x = self.decode(latents)
        return x_tilde, z_e_x, z_q_x
+
+
+class MaskedConv2d(nn.Conv2d):
+    def __init__(self, mask_type, *args, **kwargs):
+        super(MaskedConv2d, self).__init__(*args, **kwargs)
+        assert mask_type in {'A', 'B'}
+        self.register_buffer('mask', self.weight.data.clone())
+        _, _, kH, kW = self.weight.size()
+        self.mask.fill_(1)
+        self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
+        self.mask[:, :, kH // 2 + 1:] = 0
+
+    def forward(self, x):
+        self.weight.data *= self.mask
+        return super(MaskedConv2d, self).forward(x)
+
+
+class PixelCNN(nn.Module):
+    def __init__(self, dim=64, n_layers=4):
+        super().__init__()
+        self.dim = dim
+
+        # Create embedding layer to embed input
+        self.embedding = nn.Embedding(256, dim)
+
+        # Building the PixelCNN layer by layer
+        net = []
+
+        # Initial block with Mask-A convolution
+        # Rest with Mask-B convolutions
+        for i in range(n_layers):
+            mask_type = 'A' if i == 0 else 'B'
+            net.extend([
+                MaskedConv2d(mask_type, dim, dim, 7, 1, 3, bias=False),
+                nn.BatchNorm2d(dim),
+                nn.ReLU(True)
+            ])
+
+        # Add the output layer
+        net.append(nn.Conv2d(dim, 256, 1))
+
+        self.net = nn.Sequential(*net)
+
+    def forward(self, x):
+        shp = x.size() + (-1, )
+        x = self.embedding(x.view(-1)).view(shp)  # (B, H, W, C)
+        x = x.permute(0, 3, 1, 2)  # (B, C, H, W)
+        return self.net(x)
+
+    def generate(self, batch_size=64):
+        x = Variable(
+            torch.zeros(batch_size, 8, 8).long()
+        ).cuda()
+
+        for i in range(8):
+            for j in range(8):
+                logits = self.forward(x)
+                probs = F.softmax(logits[:, :, i, j], -1)
+                x.data[:, i, j].copy_(
+                    probs.multinomial(1).squeeze().data
+                )
+        return x
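
As an illustration of what the two masking rules in MaskedConv2d.__init__ produce (a sketch, not part of the commit; the 3x3 kernel is chosen only to keep the printout small):

import torch

# Rebuild the raster-scan masks from MaskedConv2d.__init__ for a 3x3 kernel.
# Mask 'A' (first layer) also hides the centre pixel, so a position never
# conditions on its own value; mask 'B' (later layers) keeps the centre.
for mask_type in ('A', 'B'):
    kH, kW = 3, 3
    mask = torch.ones(1, 1, kH, kW)
    mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0  # centre row, from (A) or just after (B) the centre
    mask[:, :, kH // 2 + 1:] = 0                             # every row below the centre
    print(mask_type, mask[0, 0].tolist())
# A -> [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
# B -> [[1.0, 1.0, 1.0], [1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]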
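For context, a minimal end-to-end sketch of how these modules might be wired together after this change. Everything below is an assumption rather than part of the commit: 32x32 RGB inputs (so the encoder yields the 8x8 latent grid that PixelCNN.generate hard-codes), K=256 so the codebook size matches the 256 indices the PixelCNN embeds and predicts, a CUDA device (generate calls .cuda()), and the names ae, prior and images are purely illustrative.

import torch
from torch.autograd import Variable

ae = AutoEncoder(K=256).cuda()
prior = PixelCNN(dim=64, n_layers=4).cuda()

images = Variable(torch.rand(16, 3, 32, 32).cuda())  # values in [0, 1], matching the Sigmoid output range

# Reconstruction path: encode to discrete codes, look up embeddings, decode.
x_tilde, z_e_x, z_q_x = ae(images)       # x_tilde: (16, 3, 32, 32)

# The PixelCNN prior is fit on the discrete latent maps.
latents, _ = ae.encode(images)           # (16, 8, 8) LongTensor of code indices
logits = prior(latents)                  # (16, 256, 8, 8) per-position scores over codes

# Sampling: draw latent maps autoregressively, then decode them to images.
sampled = prior.generate()               # (64, 8, 8)
new_images, _ = ae.decode(sampled)       # (64, 3, 32, 32)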