10
10
import numpy as np
11
11
12
12
13
class ReShape(nn.Module):
    r"""Reshape a tensor to ``tensor_size`` while preserving the batch dim.

    Args:
        tensor_size: target shape in ``(batch, *dims)`` form; only
            ``tensor_size[1:]`` is used — the batch dimension is taken
            from the input tensor at runtime.
    """

    def __init__(self, tensor_size):
        # Python 3 zero-argument super() instead of legacy
        # super(ReShape, self).__init__() — identical behavior.
        super().__init__()
        self.tensor_size = tensor_size

    def forward(self, tensor):
        # Keep the runtime batch size; view the rest to tensor_size[1:].
        # NOTE(review): view() requires a contiguous input — presumably
        # guaranteed by the upstream Linear layer; confirm at call sites.
        return tensor.view(tensor.size(0), *self.tensor_size[1:])
22
13
class ConvolutionalVAE (nn .Module ):
23
- """
24
- Convolutional Variational Auto Encoder
25
-
26
- Parameters
27
- tensor_size :: expected size of input tensor
28
- embedding_layers :: a list of (filter_size, out_channels, strides)
29
- in each intermediate layer of the encoder.
30
- A flip is used for decoder
31
- n_latent :: length of latent vecotr Z
32
- decoder_final_activation :: tanh/sigm
33
-
34
- activation, normalization, pre_nm, weight_nm, equalized, bias ::
35
- refer to core.NeuralLayers
14
+ r""" Example Convolutional Variational Auto Encoder
15
+
16
+ Args:
17
+ tensor_size: shape of tensor in BCHW
18
+ (None/any integer >0, channels, height, width)
19
+ embedding_layers: a list of (filter_size, out_channels, strides)
20
+ in each intermediate layer of the encoder. A flip is used for
21
+ decoder.
22
+ n_latent: length of latent vecotr Z
23
+ decoder_final_activation: tanh/sigm
24
+ activation, normalization, pre_nm, weight_nm, equalized, bias:
25
+ refer to core.NeuralLayers.Convolution
26
+
27
+ Return:
28
+ encoded, mu, log_var, latent, decoded, kld, mse
36
29
"""
37
30
def __init__ (self ,
38
- tensor_size = (6 , 1 , 28 , 28 ),
39
- embedding_layers = [(3 , 32 , 2 ), (3 , 64 , 2 )],
40
- n_latent = 128 ,
41
- decoder_final_activation = "tanh" ,
42
- pad = True ,
43
- activation = "relu" ,
44
- normalization = None ,
45
- pre_nm = False ,
46
- groups = 1 ,
47
- weight_nm = False ,
48
- equalized = False ,
49
- bias = False ,
31
+ tensor_size : tuple = (6 , 1 , 28 , 28 ),
32
+ embedding_layers : list = [(3 , 32 , 2 ), (3 , 64 , 2 )],
33
+ n_latent : int = 128 ,
34
+ decoder_final_activation : str = "tanh" ,
35
+ pad : bool = True ,
36
+ activation : str = "relu" ,
37
+ normalization : str = None ,
38
+ pre_nm : bool = False ,
39
+ groups : int = 1 ,
40
+ weight_nm : bool = False ,
41
+ equalized : bool = False ,
42
+ bias : bool = False ,
50
43
* args , ** kwargs ):
51
44
super (ConvolutionalVAE , self ).__init__ ()
52
45
@@ -65,23 +58,22 @@ def __init__(self,
65
58
kwargs ["equalized" ] = equalized
66
59
# encoder with Convolution layers
67
60
encoder = []
68
- _tensor_size = tensor_size
61
+ t_size = tensor_size
69
62
for f , c , s in embedding_layers :
70
- encoder .append (Convolution (_tensor_size , f , c , s , ** kwargs ))
71
- _tensor_size = encoder [- 1 ].tensor_size
63
+ encoder .append (Convolution (t_size , f , c , s , ** kwargs ))
64
+ t_size = encoder [- 1 ].tensor_size
72
65
self .encoder = nn .Sequential (* encoder )
73
66
74
67
# mu and log_var to synthesize Z
75
- self .mu = Linear (_tensor_size , n_latent , "" , 0. , bias = bias )
76
- self .log_var = Linear (_tensor_size , n_latent , "" , 0. , bias = bias )
68
+ self .mu = Linear (t_size , n_latent , "" , 0. , bias = bias )
69
+ self .log_var = Linear (t_size , n_latent , "" , 0. , bias = bias )
77
70
78
71
# decoder - (Linear layer + ReShape) to generate encoder last output
79
72
# shape, followed by inverse of encoder
80
73
decoder = []
81
74
decoder .append (Linear (self .mu .tensor_size ,
82
- int (np .prod (_tensor_size [1 :])),
83
- activation , 0. , bias = bias ))
84
- decoder .append (ReShape (_tensor_size ))
75
+ int (np .prod (t_size [1 :])), activation , 0. ,
76
+ bias = bias , out_shape = t_size [1 :]))
85
77
86
78
decoder_layers = []
87
79
for i , x in enumerate (embedding_layers [::- 1 ]):
@@ -94,10 +86,9 @@ def __init__(self,
94
86
for i , (f , c , s , o ) in enumerate (decoder_layers ):
95
87
if i == len (decoder_layers )- 1 :
96
88
kwargs ["activation" ] = None
97
- decoder .append (Convolution (_tensor_size , f , c , s ,
98
- transpose = True , ** kwargs ))
99
- decoder [- 1 ].tensor_size = o # adjusting the output tensor size
100
- _tensor_size = decoder [- 1 ].tensor_size
89
+ decoder .append (Convolution (t_size , f , c , s , transpose = True ,
90
+ maintain_out_size = True , ** kwargs ))
91
+ t_size = decoder [- 1 ].tensor_size
101
92
self .decoder = nn .Sequential (* decoder )
102
93
103
94
# Final normalization
@@ -123,6 +114,7 @@ def forward(self, tensor, noisy_tensor=None):
123
114
mse = F .mse_loss (decoded , tensor )
124
115
return encoded , mu , log_var , latent , decoded , kld , mse
125
116
117
+
126
118
# from core.NeuralLayers import Convolution, Linear
127
119
# tensor_size = (1, 1, 28, 28)
128
120
# tensor = torch.rand(*tensor_size)
0 commit comments