Skip to content

Commit e679e05

Browse files
committed
Add tanh pre exp
1 parent dcf2953 commit e679e05

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

flows.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,12 @@ def __init__(self,
6464
num_hidden,
6565
num_cond_inputs=None,
6666
s_act='tanh',
67-
t_act='relu'):
67+
t_act='relu',
68+
pre_exp_tanh=False):
6869
super(MADESplit, self).__init__()
6970

71+
self.pre_exp_tanh = pre_exp_tanh
72+
7073
activations = {'relu': nn.ReLU, 'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh}
7174

7275
input_mask = get_mask(num_inputs, num_hidden, num_inputs,
@@ -102,6 +105,9 @@ def forward(self, inputs, cond_inputs=None, mode='direct'):
102105

103106
h = self.t_joiner(inputs, cond_inputs)
104107
a = self.t_trunk(h)
108+
109+
if self.pre_exp_tanh:
110+
a = F.tanh(a)
105111

106112
u = (inputs - m) * torch.exp(-a)
107113
return u, -a.sum(-1, keepdim=True)
@@ -115,6 +121,9 @@ def forward(self, inputs, cond_inputs=None, mode='direct'):
115121
h = self.t_joiner(x, cond_inputs)
116122
a = self.t_trunk(h)
117123

124+
if self.pre_exp_tanh:
125+
a = F.tanh(a)
126+
118127
x[:, i_col] = inputs[:, i_col] * torch.exp(
119128
a[:, i_col]) + m[:, i_col]
120129
return x, -a.sum(-1, keepdim=True)
@@ -128,7 +137,8 @@ def __init__(self,
128137
num_inputs,
129138
num_hidden,
130139
num_cond_inputs=None,
131-
act='relu'):
140+
act='relu',
141+
pre_exp_tanh=False):
132142
super(MADE, self).__init__()
133143

134144
activations = {'relu': nn.ReLU, 'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh}

0 commit comments

Comments
 (0)