Skip to content

Commit dfacc9c

Browse files
committed
Fix AlphaDropout implementation and add tests
Behaviour and outputs are adapted from the PyTorch and TF implementations
1 parent 66a84ef commit dfacc9c

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

src/layers/normalise.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,20 @@ mutable struct AlphaDropout{F}
101101
end
102102
end
103103

104-
function (a::AlphaDropout)(x)
104+
function (a::AlphaDropout)(x::AbstractArray{T}) where T
105105
_isactive(a) || return x
106-
λ = eltype(x)(1.0507009873554804934193349852946)
107-
α = eltype(x)(1.6732632423543772848170429916717)
108-
α1 = eltype(x)(-λ*α)
109-
noise = randn(eltype(x), size(x))
110-
x = @. x*(noise > (1 - a.p)) + α1 * (noise < (1 - a.p))
111-
A = sqrt(a.p + a.p * (1 - a.p) * α1^2)
112-
B = -A * α1 * (1 - a.p)
113-
x = @. A * x + B
114-
return x
106+
p = a.p
107+
iszero(p) && return x
108+
isone(p) && return sign.(x) .* T(0)
109+
110+
λ = T(1.0507009873554804934193349852946)
111+
α = T(1.6732632423543772848170429916717)
112+
α1 = T(-λ * α)
113+
A = inv(sqrt((1 - p) * (1 + p * α1^2)))
114+
B = -A * α1 * p
115+
116+
noise = rand!(similar(x))
117+
return A .* ifelse.(noise .> p, x, α1) .+ B
115118
end
116119

117120
testmode!(m::AlphaDropout, mode=true) =

test/layers/normalisation.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,34 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
5757
@test count(a->a == 0, y) == 0
5858
end
5959

60+
@testset "AlphaDropout" begin
61+
x = [1., 2., 3.]
62+
@test x == AlphaDropout(0.1)(x)
63+
@test x == evalwgrad(AlphaDropout(0), x)
64+
@test zero(x) == evalwgrad(AlphaDropout(1), x)
65+
66+
x = randn(1000) # large enough to prevent flaky test
67+
m = AlphaDropout(0.5)
68+
69+
y = evalwgrad(m, x)
70+
# Should preserve unit mean and variance
71+
@test mean(y) 0 atol=0.1
72+
@test var(y) 1 atol=0.1
73+
74+
testmode!(m, true) # should override istraining
75+
@test evalwgrad(m, x) == x
76+
77+
testmode!(m, false)
78+
y = evalwgrad(m, x)
79+
@test mean(y) 0 atol=0.1
80+
@test var(y) 1 atol=0.1
81+
82+
# Known good value ranges
83+
# Values taken from https://github.com/pytorch/pytorch/blob/v1.10.0/test/cpp/api/modules.cpp#L1337-L1338
84+
x = ones(100)
85+
@test 40 < sum(evalwgrad(m, x)) < 130
86+
end
87+
6088
@testset "BatchNorm" begin
6189
let m = BatchNorm(2), x = [1.0 3.0 5.0;
6290
2.0 4.0 6.0]

0 commit comments

Comments
 (0)