diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 146b7dba56..73b625c616 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -30,14 +30,14 @@ The [`Dropout`](@ref) layer is what you should use in most scenarios. """ function dropout(x, p; dims=:, active::Bool=true) active || return x - y = dropout_mask(x, p, dims=dims) - return x .* y + y = rand!(similar(x, _dropout_shape(x, dims))) + x .* _dropout_kernel.(y, p, 1-p) end @adjoint function dropout(x, p; dims=:, active::Bool=true) active || return x, Δ -> (Δ, nothing) - y = dropout_mask(x, p, dims=dims) - return x .* y, Δ -> (Δ .* y, nothing) + y = rand!(similar(x, _dropout_shape(x, dims))) + return x .* _dropout_kernel.(y, p, 1-p), Δ -> (Δ .* _dropout_kernel.(y, p, 1-p), nothing) end function dropout_mask(x, p; dims=:)