@@ -64,9 +64,12 @@ def __init__(self,
64
64
num_hidden ,
65
65
num_cond_inputs = None ,
66
66
s_act = 'tanh' ,
67
- t_act = 'relu' ):
67
+ t_act = 'relu' ,
68
+ pre_exp_tanh = False ):
68
69
super (MADESplit , self ).__init__ ()
69
70
71
+ self .pre_exp_tanh = pre_exp_tanh
72
+
70
73
activations = {'relu' : nn .ReLU , 'sigmoid' : nn .Sigmoid , 'tanh' : nn .Tanh }
71
74
72
75
input_mask = get_mask (num_inputs , num_hidden , num_inputs ,
@@ -102,6 +105,9 @@ def forward(self, inputs, cond_inputs=None, mode='direct'):
102
105
103
106
h = self .t_joiner (inputs , cond_inputs )
104
107
a = self .t_trunk (h )
108
+
109
+ if self .pre_exp_tanh :
110
+ a = F .tanh (a )
105
111
106
112
u = (inputs - m ) * torch .exp (- a )
107
113
return u , - a .sum (- 1 , keepdim = True )
@@ -115,6 +121,9 @@ def forward(self, inputs, cond_inputs=None, mode='direct'):
115
121
h = self .t_joiner (x , cond_inputs )
116
122
a = self .t_trunk (h )
117
123
124
+ if self .pre_exp_tanh :
125
+ a = F .tanh (a )
126
+
118
127
x [:, i_col ] = inputs [:, i_col ] * torch .exp (
119
128
a [:, i_col ]) + m [:, i_col ]
120
129
return x , - a .sum (- 1 , keepdim = True )
@@ -128,7 +137,8 @@ def __init__(self,
128
137
num_inputs ,
129
138
num_hidden ,
130
139
num_cond_inputs = None ,
131
- act = 'relu' ):
140
+ act = 'relu' ,
141
+ pre_exp_tanh = False ):
132
142
super (MADE , self ).__init__ ()
133
143
134
144
activations = {'relu' : nn .ReLU , 'sigmoid' : nn .Sigmoid , 'tanh' : nn .Tanh }
0 commit comments