Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit aeba8c9

Browse files
authored
Wrapper functions for NNlib (#615)
1 parent a1a4dc0 commit aeba8c9

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed

src/nnlib.jl

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,26 @@
11
using NNlib
22

33
# Activation functions
4+
@cufunc softplus(x::Real) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
45

5-
@cufunc σ(x) = ifelse(x < -80, zero(x), one(x) / (one(x) + exp(-x)))
6+
@cufunc logσ(x::Real) = -softplus(-x)
67

7-
@cufunc function logσ(x)
8-
max_v = max(zero(x), -x)
9-
z = exp(-max_v) + exp(-x-max_v)
10-
-(max_v + log(z))
8+
@cufunc function gelu(x::Real)
9+
p = oftype(x / 1, π)
10+
λ = oftype(x / 1, (2 / p))
11+
α = oftype(x / 1, 0.044715)
12+
h = oftype(x / 1, 0.5)
13+
h * x * (one(x) + tanh(λ * (x + α * x^3)))
1114
end
1215

13-
@cufunc elu(x, α = one(x)) =
14-
ifelse(x ≥ 0, x/1, α * (exp(x) - one(x)))
16+
@cufunc lisht(x::Real) = x * tanh(x)
1517

16-
@cufunc swish(x) = x * σ(x)
18+
@cufunc logcosh(x::Real) = x + softplus(-2x) - log(oftype(x, 2))
1719

18-
@cufunc function gelu(x)
19-
λ = oftype(x/1, (2/π))
20-
α = oftype(x/1, 0.044715)
21-
h = oftype(x/1, 0.5)
22-
h * x * (one(x) + tanh(λ * (x + α * x^3)))
23-
end
20+
@cufunc mish(x::Real) = x * tanh(softplus(x))
2421

25-
@cufunc function selu(x)
26-
λ = oftype(x/1, 1.0507009873554804934193349852946)
27-
α = oftype(x/1, 1.6732632423543772848170429916717)
28-
λ * ifelse(x > 0, x/1, α * (exp(x) - 1))
29-
end
22+
@cufunc tanhshrink(x::Real) = x - tanh(x)
3023

31-
@cufunc softplus(x) = ifelse(x > 0, x + log1p(exp(-x)), log1p(exp(x)))
3224

3325

3426
# Batched matrix multiplication

test/dnn.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ end
8080
@test testf(CUDNN.cudnnActivationBackward, cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)), cu(rand(Float64, 10, 10, 3, 1)))
8181

8282
# activations defined in src/nnlib.jl
83+
ACTIVATION_FUNCTIONS = [σ, logσ, hardσ, hardtanh, relu, leakyrelu, relu6, rrelu,
84+
elu, gelu, celu, swish, lisht, selu, trelu, softplus,
85+
softsign, logcosh, mish, tanhshrink, softshrink];
8386
for dims in ((5,5), (5,))
84-
for f in (σ, logσ, elu, swish, gelu, selu, softplus)
87+
for f in filter(x -> x != rrelu, ACTIVATION_FUNCTIONS)
8588
@test testf(x -> f.(x), rand(Float64, dims))
8689
end
8790
end

0 commit comments

Comments
 (0)