
Commit 090a0ec

Merge pull request #261 from FluxML/cl/tests
fix hardsigmoid and use float(x) instead of x/1
2 parents c08258b + 440ed3b commit 090a0ec
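
The second half of the title refers to how these functions promote their argument to floating point. A minimal sketch (mine, not part of the diff) of the idea: float(x) gives the same result as the old x/1 trick while stating the intent directly and avoiding a needless division.

# float(x) and x/1 agree: an Int is promoted to Float64, a Float32 stays Float32.
@assert float(3) === 3.0
@assert 3 / 1 === 3.0
@assert float(3.0f0) === 3.0f0
@assert 3.0f0 / 1 === 3.0f0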

File tree

2 files changed: +41 / -35 lines


src/activations.jl

Lines changed: 35 additions & 27 deletions
@@ -17,7 +17,8 @@ end
 # Aliases
 export sigmoid, hardsigmoid, logsigmoid, thresholdrelu
 
-
+# of type float
+oftf(x, y) = oftype(float(x), y)
 
 """
     σ(x) = 1 / (1 + exp(-x))
@@ -33,13 +34,14 @@ end
 const sigmoid = σ
 
 """
-    hardσ(x, a=0.2) = max(0, min(1, a * x + 0.5))
+    hardσ(x) = max(0, min(1, (x + 3) / 6))
 
-Segment-wise linear approximation of sigmoid.
-See [BinaryConnect: Training Deep Neural Networks with binary weights during propagations](https://arxiv.org/abs/1511.00363).
+Piecewise linear approximation of sigmoid.
 """
-hardσ(x, a=0.2) = oftype(x/1, max(zero(x/1), min(one(x/1), oftype(x/1,a) * x + oftype(x/1,0.5))))
-
+hardσ(x) = max(0, min(1, (x + 3) / 6))
+
+# https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
+
 const hardsigmoid = hardσ
 
 """
@@ -56,7 +58,7 @@ const logsigmoid = logσ
 Segment-wise linear approximation of tanh. Cheaper and more computationally efficient version of tanh.
 See [Large Scale Machine Learning](https://ronan.collobert.com/pub/matos/2004_phdthesis_lip6.pdf).
 """
-hardtanh(x) = max(-one(x), min( one(x), x))
+hardtanh(x) = max(-one(x), min(one(x), x))
 
 """
     relu(x) = max(0, x)
@@ -73,7 +75,7 @@ Leaky [Rectified Linear Unit](https://en.wikipedia.org/wiki/Rectifier_(neural_ne
 activation function.
 You can also specify the coefficient explicitly, e.g. `leakyrelu(x, 0.01)`.
 """
-leakyrelu(x, a = oftype(x/1, 0.01)) = max(a * x, x/1)
+leakyrelu(x, a=oftf(x, 0.01)) = max(a * x, x)
 
 """
     relu6(x) = min(max(0, x), 6)
@@ -93,8 +95,8 @@ Randomized Leaky [Rectified Linear Unit](https://arxiv.org/abs/1505.00853)
 activation function.
 You can also specify the bound explicitly, e.g. `rrelu(x, 0.0, 1.0)`.
 """
-function rrelu(x, l = 1 / 8.0, u = 1 / 3.0)
-    a = oftype(x / 1, (u - l) * rand() + l)
+function rrelu(x::T, l=1//8, u=1//3) where T<:Number
+    a = (u - l) * rand(float(T)) + l
     return leakyrelu(x, a)
 end
@@ -105,10 +107,9 @@ Exponential Linear Unit activation function.
 See [Fast and Accurate Deep Network Learning by Exponential Linear Units](https://arxiv.org/abs/1511.07289).
 You can also specify the coefficient explicitly, e.g. `elu(x, 1)`.
 """
-elu(x, α=1) = ifelse(x ≥ 0, x/1, α * (exp(x) - 1))
-
-deriv_elu(x, Ω, α=1) = ifelse(x ≥ 0, one(x), Ω + α)
+elu(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x) - 1))
 
+deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, 1, Ω + α)
 
 """
     gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3)))
@@ -117,11 +118,13 @@ deriv_elu(x, Ω, α=1) = ifelse(x ≥ 0, one(x), Ω + α)
 activation function.
 """
 function gelu(x)
-    λ = oftype(x / 1, √(2 / π))
-    α = oftype(x / 1, 0.044715)
+    α = oftf(x, 0.044715)
+    λ = oftf(x, gelu_λ)
     x/2 * (1 + tanh(λ * (x + α * x^3)))
 end
 
+const gelu_λ = √(2 / π)
+
 """
     swish(x) = x * σ(x)
@@ -148,15 +151,18 @@ Scaled exponential linear units.
 See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
 """
 function selu(x)
-    λ = oftype(x/1, 1.0507009873554804934193349852946)
-    α = oftype(x/1, 1.6732632423543772848170429916717)
-    λ * ifelse(x > 0, x/1, α * (exp(x) - 1))
+    λ = oftf(x, selu_λ)
+    α = oftf(x, selu_α)
+    λ * ifelse(x > 0, x, α * (exp(x) - 1))
 end
 
+const selu_λ = 1.0507009873554804934193349852946
+const selu_α = 1.6732632423543772848170429916717
+
 function deriv_selu(Ω)
-    λ = oftype(Ω/1, 1.0507009873554804934193349852946)
-    α = oftype(Ω/1, 1.6732632423543772848170429916717)
-    return ifelse(Ω > 0, λ, Ω + α*λ)
+    λ = oftf(Ω, selu_λ)
+    α = oftf(Ω, selu_α)
+    ifelse(Ω > 0, λ, Ω + α * λ)
 end
 
 """
@@ -165,7 +171,7 @@ end
 Continuously Differentiable Exponential Linear Units
 See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/abs/1704.07483).
 """
-celu(x, α=1) = ifelse(x ≥ 0, x/1, α * (exp(x/α) - 1))
+celu(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x/α) - 1))
 
 """
     trelu(x, theta=1) = x > theta ? x : 0
@@ -174,14 +180,15 @@ Threshold Gated Rectified Linear.
 See [ThresholdRelu](https://arxiv.org/abs/1402.3337)
 """
 trelu(x, theta=1) = ifelse(x > theta, x, zero(x))
+
 const thresholdrelu = trelu
 
 """
     softsign(x) = x / (1 + |x|)
 
 See [Quadratic Polynomials Learn Better Image Features](http://www.iro.umontreal.ca/~lisa/publications2/index.php/attachments/single/205).
 """
-softsign(x) = x / (one(x) + abs(x))
+softsign(x) = x / (1 + abs(x))
 
 """
     softplus(x) = log(exp(x) + 1)
@@ -195,8 +202,9 @@ softplus(x) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
 
 Return `log(cosh(x))` which is computed in a numerically stable way.
 """
-logcosh(x) = x + softplus(-2x) - log(oftype(x, 2))
+logcosh(x) = x + softplus(-2x) - oftf(x, log2)
 
+const log2 = log(2)
 
 """
     mish(x) = x * tanh(softplus(x))
@@ -219,7 +227,7 @@ tanhshrink(x) = x - tanh(x)
 
 See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_Activation_Function).
 """
-softshrink(x, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ)
+softshrink(x, λ=oftf(x, 0.5)) = min(max(0, x - λ), x + λ)
 
 # Provide an informative error message if activation functions are called with an array
 for f in ACTIVATIONS
@@ -241,7 +249,7 @@ UNARY_ACTS = [ # f, df
     (:hardtanh, :(-1 < x < 1)),
     (:selu, :(deriv_selu(Ω))),
     (:σ, :(conj(Ω * (1 - Ω)))),
-    (:elu, :(deriv_elu(x, Ω))),
+    (:elu, :(deriv_elu(Ω))),
 ]
 
 for (f, df) in UNARY_ACTS
@@ -260,7 +268,7 @@ end
 
 
 BINARY_ACTS = [ # f, df1, df2
-    (:elu, :(deriv_elu(x1, Ω, x2)), :(DoesNotExist())), # TODO use real deriv instead of DNE
+    (:elu, :(deriv_elu(Ω, x2)), :(DoesNotExist())), # TODO use real deriv instead of DNE
 ]
 
 for (f, df1, df2) in BINARY_ACTS
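
As a quick sanity check of the definitions introduced above, here is a small sketch (mine, based only on the lines in this diff, not taken from the package) of the new PyTorch-style hard sigmoid and of the float(x) promotion in elu:

hardσ(x) = max(0, min(1, (x + 3) / 6))                   # new definition from this diff
elu(x, α=1) = ifelse(x ≥ 0, float(x), α * (exp(x) - 1))

@assert hardσ(-3.0) == 0.0        # saturates at x ≤ -3
@assert hardσ(0.0)  == 0.5        # centred like σ
@assert hardσ(3.0)  == 1.0        # saturates at x ≥ 3
@assert hardσ(1.0f0) isa Float32  # input precision is preserved

@assert elu(2)     === 2.0        # Int input promoted to Float64 by float(x)
@assert elu(2.0f0) === 2.0f0      # Float32 input stays Float32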

test/activations.jl

Lines changed: 6 additions & 8 deletions
@@ -6,7 +6,7 @@ ACTIVATION_FUNCTIONS =
 
 function test_value_float_precision_preserving(a)
     @testset "$(a): " begin
-        for T in [Float32, Float64]
+        for T in [Float16, Float32, Float64]
             for val in [-10, -1, 0, 1, 10]
                 val = @inferred a(T(val))
                 @test typeof(val) == T
@@ -28,7 +28,7 @@ end
 
 function test_gradient_float_precision_preserving(a)
     @testset "$(a): " begin
-        for T in [Float32, Float64]
+        for T in [Float16, Float32, Float64]
             for val in [-10, -1, 0, 1, 10]
                 val = @inferred a'(T(val))
                 @test typeof(val) == T
@@ -61,7 +61,7 @@ end
     @test softshrink(0.0) == 0.0
 
     @test sigmoid(1.0) == 1.0 / (1.0 + exp(-1.0))
-    @test hardsigmoid(1.0) == max(0,min(1,0.2*1.0 + 0.5))
+    @test hardsigmoid(1.0) == max(0,min(1, (1 + 3)/6))
     @test hardtanh(1.0) == 1.0
     @test relu(1.0) == 1.0
     @test leakyrelu(1.0) == 1.0
@@ -82,7 +82,7 @@ end
     @test softshrink(1.0) == 0.5
 
     @test sigmoid(-1.0) == exp(-1.0) / (1.0 + exp(-1.0))
-    @test hardsigmoid(-1.0) == max(0,min(1,0.2*-1.0 + 0.5))
+    @test hardsigmoid(-1.0) == max(0,min(1,(-1+3)/6))
     @test hardtanh(-1.0) == -1.0
     @test relu(-1.0) == 0.0
     @test leakyrelu(-1.0) == -0.01
@@ -189,9 +189,8 @@ end
     @test logcosh(1_000.0) + log(2) == 1_000.0
 
     @testset "hardsigmoid" begin
-        @test hardsigmoid(0.3) == 0.56
-        @test hardsigmoid(-0.3) == 0.44
-        @test hardsigmoid(0.1,0.5) == 0.55
+        @test hardsigmoid(0.3) == max(0,min(1,(0.3+3)/6))
+        @test hardsigmoid(-0.3) == max(0,min(1,(-0.3+3)/6))
         for T in [:Float32, :Float64]
            @eval @test hardsigmoid.($T[-100_000, 100_000.]) ≈ $T[0., 1.]
        end
@@ -260,4 +259,3 @@ end
     gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
     gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
 end
-
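
The test loops above now include Float16 alongside Float32 and Float64. A self-contained check in the same spirit (my sketch, reusing two definitions from the src diff rather than importing the package) would be:

using Test

oftf(x, y) = oftype(float(x), y)                 # helper added in the src diff
leakyrelu(x, a=oftf(x, 0.01)) = max(a * x, x)    # default slope now built via oftf

@testset "float precision is preserved" begin
    for T in [Float16, Float32, Float64], val in [-10, -1, 0, 1, 10]
        @test typeof(leakyrelu(T(val))) == T
    end
end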