Skip to content

Commit 34f787d

Browse files
author
ShoofLLC
committed
Calling rand with float type arguments
Replacing the call to rand! in the _dropout_mask function to account for complex (float) data types.
1 parent d0a5b77 commit 34f787d

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/layers/normalise.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ dropout_mask(rng, x::CuArray, p; kwargs...) =
4949
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
5050
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
5151
function _dropout_mask(rng, x, p; dims=:)
52-
y = rand!(rng, similar(x, _dropout_shape(x, dims)))
52+
fptype = float(real(eltype(x)))
53+
y = rand!(rng, similar(x, fptype, _dropout_shape(x, dims)))
5354
y .= _dropout_kernel.(y, p, 1 - p)
5455
return y
5556
end

test/layers/normalisation.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
55

66
@testset "Dropout" begin
77
@testset for rng_kwargs in ((), (; rng = MersenneTwister()))
8+
x = [1.0+0im,2.0+1im,3.0+3im]
9+
@test x == Dropout(0.1; rng_kwargs...)(x)
10+
@test x == evalwgrad(Dropout(0; rng_kwargs...), x)
11+
@test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x)
12+
813
x = [1.,2.,3.]
914
@test x == Dropout(0.1; rng_kwargs...)(x)
1015
@test x == evalwgrad(Dropout(0; rng_kwargs...), x)

0 commit comments

Comments
 (0)