
Commit 8d3b8d3

Merge pull request #1849 from darsnack/darsnack/dropout-rng
Add RNG support for Dropout/AlphaDropout
2 parents: 7467e6b + f922c16

7 files changed: +196 −95 lines

NEWS.md
Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,8 @@
 # Flux Release Notes
 
+## v0.12.10
+* `Dropout`/`AlphaDropout` now support [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
+
 ## v0.12.9
 * Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781).
 * Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792).
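
A minimal sketch of what the new entry enables, assuming a CPU-only run and the `Dropout(p; rng)` keyword introduced in the diff below: two identically seeded layers produce identical dropout masks.

```julia
using Flux, Random

# Two layers seeded with the same RNG produce the same dropout mask (CPU only).
d1 = Dropout(0.5; rng = MersenneTwister(123))
d2 = Dropout(0.5; rng = MersenneTwister(123))
Flux.trainmode!(d1); Flux.trainmode!(d2)   # force dropout to be active outside training

x = rand(Float32, 4, 3)
@assert d1(x) == d2(x)                     # identical masks, identical outputs
```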

src/functor.jl
Lines changed: 14 additions & 1 deletion

@@ -96,6 +96,14 @@ end
 struct FluxCUDAAdaptor end
 adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
 adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
+if VERSION >= v"1.7"
+  adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
+else
+  adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
+end
+adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
+adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
+  error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")
 
 # TODO: figure out the correct design for OneElement
 adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))

@@ -109,6 +117,8 @@ adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x
 adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
 adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x
 adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
+adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
+adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
 Zygote.@adjoint function Array(x::CUDA.CuArray)
   Array(x), d -> (CUDA.cu(d),)

@@ -149,6 +159,9 @@ _isbitsarray(::AbstractArray{<:Number}) = true
 _isbitsarray(::AbstractArray{T}) where T = isbitstype(T)
 _isbitsarray(x) = false
 
+_isleaf(::AbstractRNG) = true
+_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
+
 """
     gpu(x)
 

@@ -174,7 +187,7 @@ CuArray{Float32, 2}
 """
 function gpu(x)
   check_use_cuda()
-  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isbitsarray) : x
+  use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
 end
 
 function check_use_cuda()
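
A short sketch of what these adaptor rules mean in practice (assuming a working CUDA setup; the `rng` field on `Dropout` comes from the layer changes further down): `gpu` swaps the default CPU RNG for `CUDA.default_rng()`, `cpu` swaps a `CUDA.RNG` back, and any other RNG hits the `error` method above.

```julia
using Flux, CUDA, Random

d = Dropout(0.2)                  # carries the default CPU RNG (rng_from_array())
d_gpu = gpu(d)
@assert d_gpu.rng isa CUDA.RNG    # default RNG mapped to CUDA.default_rng()

d_cpu = cpu(d_gpu)
@assert d_cpu.rng === Random.default_rng()   # CUDA.RNG mapped back (Julia >= 1.7)

# A custom CPU RNG is refused on the GPU and raises the error defined above:
# gpu(Dropout(0.2; rng = MersenneTwister(1)))
```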

src/layers/normalise.jl
Lines changed: 42 additions & 18 deletions

@@ -10,7 +10,7 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
 
 """
-    dropout(x, p; dims=:, active=true)
+    dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)
 
 The dropout function. If `active` is `true`,
 for each input, either sets that input to `0` (with probability

@@ -20,6 +20,9 @@ This is used as a regularisation, i.e. it reduces overfitting during training.
 
 If `active` is `false`, it just returns the input `x`.
 
+Specify `rng` for custom RNGs instead of the default RNG.
+Note that custom RNGs are only supported on the CPU.
+
 Warning: when using this function, you have to manually manage the activation
 state. Usually in fact, dropout is used while training
 but is deactivated in the inference phase. This can be

@@ -28,49 +31,63 @@ automatically managed using the [`Dropout`](@ref) layer instead of the
 
 The [`Dropout`](@ref) layer is what you should use in most scenarios.
 """
-function dropout(x, p; dims=:, active::Bool=true)
+function dropout(rng, x, p; dims=:, active::Bool=true)
   active || return x
-  y = dropout_mask(x, p, dims=dims)
+  y = dropout_mask(rng, x, p, dims=dims)
   return x .* y
 end
+dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-@adjoint function dropout(x, p; dims=:, active::Bool=true)
+@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
   active || return x, Δ -> (Δ, nothing)
-  y = dropout_mask(x, p, dims=dims)
-  return x .* y, Δ -> (Δ .* y, nothing)
+  y = dropout_mask(rng, x, p, dims=dims)
+  return x .* y, Δ -> (nothing, Δ .* y, nothing)
 end
 
-function dropout_mask(x, p; dims=:)
-  y = rand!(similar(x, _dropout_shape(x, dims)))
+dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
+dropout_mask(rng, x::CuArray, p; kwargs...) =
+  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
+dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
+function _dropout_mask(rng, x, p; dims=:)
+  y = rand!(rng, similar(x, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end
 
 """
-    Dropout(p; dims=:)
+    Dropout(p; dims=:, rng = rng_from_array())
 
 Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.
 
 To apply dropout along certain dimension(s), specify the `dims` keyword.
 e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
 (also called 2D dropout).
 
+Specify `rng` to use a custom RNG instead of the default.
+Custom RNGs are only supported on the CPU.
+
 Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
 """
-mutable struct Dropout{F,D}
+mutable struct Dropout{F,D,R<:AbstractRNG}
   p::F
   dims::D
   active::Union{Bool, Nothing}
+  rng::R
 end
+Dropout(p, dims, active) = Dropout(p, dims, active, rng_from_array())
 
-function Dropout(p; dims=:)
+function Dropout(p; dims=:, rng = rng_from_array())
   @assert 0 ≤ p ≤ 1
-  Dropout(p, dims, nothing)
+  Dropout(p, dims, nothing, rng)
 end
 
+@functor Dropout
+
+trainable(a::Dropout) = ()
+
 function (a::Dropout)(x)
   _isactive(a) || return x
-  return dropout(x, a.p; dims=a.dims, active=true)
+  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
 end
 
 testmode!(m::Dropout, mode=true) =
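
A small CPU sketch of the new functional form (assuming the methods above; copying a `MersenneTwister` is used only to replay the same random stream): the RNG is now an explicit leading argument, and the two-argument form still works via `rng_from_array`.

```julia
using Flux, Random

x = ones(Float32, 5, 5)
rng = MersenneTwister(0)

y1 = Flux.dropout(copy(rng), x, 0.5)   # explicit RNG, mask drawn from it
y2 = Flux.dropout(copy(rng), x, 0.5)
@assert y1 == y2                       # same RNG state => same mask

Flux.dropout(x, 0.5)                   # old signature; falls back to rng_from_array(x)

# For CuArrays only CUDA.RNG is accepted, per the dropout_mask dispatch above:
# Flux.dropout(MersenneTwister(0), CUDA.rand(Float32, 2, 3), 0.5)  # ArgumentError
```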
@@ -83,7 +100,7 @@ function Base.show(io::IO, d::Dropout)
 end
 
 """
-    AlphaDropout(p)
+    AlphaDropout(p; rng = rng_from_array())
 
 A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).

@@ -92,14 +109,21 @@ remain the same as before.
 
 Does nothing to the input once [`testmode!`](@ref) is true.
 """
-mutable struct AlphaDropout{F}
+mutable struct AlphaDropout{F,R<:AbstractRNG}
   p::F
   active::Union{Bool, Nothing}
-  function AlphaDropout(p, active = nothing)
+  rng::R
+  function AlphaDropout(p, active, rng)
     @assert 0 ≤ p ≤ 1
-    new{typeof(p)}(p, active)
+    new{typeof(p), typeof(rng)}(p, active, rng)
   end
 end
+AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
+AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)
+
+@functor AlphaDropout
+
+trainable(a::AlphaDropout) = ()
 
 function (a::AlphaDropout)(x::AbstractArray{T}) where T
   _isactive(a) || return x

@@ -111,7 +135,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
   A = T(inv(sqrt((1 - p) * (1 + p * α′^2))))
   B = T(-A * α′ * p)
 
-  noise = rand!(similar(x))
+  noise = rand!(a.rng, similar(x))
   return A .* ifelse.(noise .> p, x, α′) .+ B
 end
 
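A sketch of the layer-level consequence of `@functor` plus the empty `trainable` methods (CPU example): the stored RNG travels with the layer when it is mapped by `gpu`/`cpu`/`fmap`, but contributes no trainable parameters.

```julia
using Flux, Random

m = AlphaDropout(0.2; rng = MersenneTwister(7))
@assert length(Flux.params(m)) == 0   # trainable(::AlphaDropout) = () keeps params empty

Flux.trainmode!(m)
x = randn(Float32, 10, 4)
y = m(x)                              # noise is drawn from m.rng via rand!(a.rng, ...)
```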

src/utils.jl
Lines changed: 19 additions & 0 deletions

@@ -33,6 +33,25 @@ nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of con
 ofeltype(x, y) = convert(float(eltype(x)), y)
 epseltype(x) = eps(float(eltype(x)))
 
+"""
+    rng_from_array([x])
+
+Create an instance of the RNG most appropriate for `x`.
+The current defaults are:
+- `x isa AbstractArray`
+  - Julia version is < 1.7: `Random.GLOBAL_RNG`
+  - Julia version is >= 1.7: `Random.default_rng()`
+- `x isa CuArray`: `CUDA.default_rng()`
+When `x` is unspecified, it is assumed to be an `AbstractArray`.
+"""
+rng_from_array(::AbstractArray) = rng_from_array()
+rng_from_array(::CuArray) = CUDA.default_rng()
+if VERSION >= v"1.7"
+  rng_from_array() = Random.default_rng()
+else
+  rng_from_array() = Random.GLOBAL_RNG
+end
+
 """
     glorot_uniform([rng=GLOBAL_RNG], dims...)
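A quick sketch of how the new helper dispatches (the `CuArray` call assumes a CUDA-capable machine; `rng_from_array` is not exported, hence the `Flux.` prefix):

```julia
using Flux, Random, CUDA

Flux.rng_from_array()                  # Random.default_rng() on Julia >= 1.7, GLOBAL_RNG before
Flux.rng_from_array(rand(Float32, 3))  # same default: plain arrays use the CPU RNG
@assert Flux.rng_from_array(CUDA.rand(Float32, 3)) isa CUDA.RNG   # CuArrays get CUDA's RNG
```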

test/cuda/layers.jl
Lines changed: 10 additions & 0 deletions

@@ -280,3 +280,13 @@
     end
   end
 end
+
+@testset "Dropout RNGs" begin
+  @test_throws ArgumentError Flux.dropout(MersenneTwister(), CUDA.rand(Float32, 2, 3), 0.1)
+  @testset for layer in (Dropout, AlphaDropout)
+    m = layer(0.1; rng = MersenneTwister(123))
+    @test_throws ErrorException gpu(m)
+    m = layer(0.1; rng = CUDA.default_rng())
+    @test gpu(m).rng isa CUDA.RNG
+  end
+end

test/cuda/runtests.jl
Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 using Flux, Test, CUDA
 using Zygote
 using Zygote: pullback
+using Random
 
 @info "Testing GPU Support"
 CUDA.allowscalar(false)
