Skip to content

Commit 1930966

Browse files
authored
Merge pull request #1867 from ShoofLLC/master
Updated Dropout for more input types.
2 parents 7b56813 + ccb328c commit 1930966

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ been removed in favour of MLDatasets.jl.
66
* `params` is not exported anymore since it is a common name and is also exported by Distributions.jl
77
* `flatten` is not exported anymore due to clash with Iterators.flatten.
88
* Remove Juno.jl progress bar support as it is now obsolete.
9+
* Improved compatibility of Dropout with Int and Complex types.
910

1011
## v0.12.10
1112
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)

src/layers/normalise.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ dropout_mask(rng, x::CuArray, p; kwargs...) =
4949
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
5050
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
5151
# Build a random dropout mask matching `x` along `dims`.
#
# The mask's element type is the real floating-point counterpart of
# `eltype(x)` (`float(real(...))`), so integer and complex inputs still
# receive a valid real-valued uniform sample before being pushed through
# `_dropout_kernel`.
function _dropout_mask(rng, x, p; dims=:)
    # Real float eltype: e.g. Int -> Float64, Complex{Float32} -> Float32.
    T = float(real(eltype(x)))
    mask = similar(x, T, _dropout_shape(x, dims))
    rand!(rng, mask)
    mask .= _dropout_kernel.(mask, p, 1 - p)
    return mask
end

test/layers/normalisation.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
55

66
@testset "Dropout" begin
77
@testset for rng_kwargs in ((), (; rng = MersenneTwister()))
8+
x = [1.0+0im,2.0+1im,3.0+3im]
9+
@test x == Dropout(0.1; rng_kwargs...)(x)
10+
@test x == evalwgrad(Dropout(0; rng_kwargs...), x)
11+
@test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x)
12+
813
x = [1.,2.,3.]
914
@test x == Dropout(0.1; rng_kwargs...)(x)
1015
@test x == evalwgrad(Dropout(0; rng_kwargs...), x)

0 commit comments

Comments (0)