diff --git a/NEWS.md b/NEWS.md
index 856208689c..a9db7cfa58 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -3,10 +3,10 @@
 ## v0.13
 
 * After a deprecations cycle, the datasets in `Flux.Data` have been removed in favour of MLDatasets.jl.
-* `params` is not exported anymore since it is a common name and is also exported by Distributions.jl 
+* `params` is not exported anymore since it is a common name and is also exported by Distributions.jl
 * `flatten` is not exported anymore due to clash with Iterators.flatten.
 * Remove Juno.jl progress bar support as it is now obsolete.
-* Improved compatibility of Dropout with Int and Complex types.
+* `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
 
 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
diff --git a/Project.toml b/Project.toml
index 3eff2752eb..bc7bbec4b2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -26,7 +26,7 @@ ArrayInterface = "3.1, 4"
 CUDA = "3"
 Functors = "0.2.1"
 MacroTools = "0.5"
-NNlib = "0.8"
+NNlib = "0.8.2"
 NNlibCUDA = "0.2"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 29bcc8a6e4..0a18bf3fe1 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -55,6 +55,9 @@ function _dropout_mask(rng, x, p; dims=:)
   return y
 end
 
+# TODO move this to NNlib
+Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
+
 """
     Dropout(p; dims=:, rng = rng_from_array())
 
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index 11f95f023a..ca8e15a643 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -241,7 +241,7 @@ import Flux: activations
     Parallel(f_cnt, sin)(1)
     @test CNT[] == 3
   end
-  
+
   # Ref https://github.com/FluxML/Flux.jl/issues/1673
   @testset "Input domain" begin
     struct Input
@@ -278,7 +278,7 @@ import Flux: activations
     vocab_size, embed_size = 10, 4
     m = Flux.Embedding(vocab_size, embed_size)
     @test size(m.weight) == (embed_size, vocab_size)
-    
+
     x = rand(1:vocab_size, 3)
     y = m(x)
     @test y isa Matrix{Float32}
@@ -315,7 +315,7 @@
   # https://github.com/FluxML/NNlib.jl/issues/362
   m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
   x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
-  @test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
+  @test Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
 end
 
 @testset "gradients of Chain{Vector}" begin
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index c19cb1c088..7ae15aeff9 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -273,7 +273,7 @@ end
    x = reshape(collect(1:prod(sizes)), sizes)

    @test Flux.hasaffine(m) == true
-    @test length(Flux.params(m)) == 2 
+    @test length(Flux.params(m)) == 2
    x = Float64.(x)
    y = m(x)
    μ = mean(x, dims=1)
@@ -287,7 +287,7 @@ end
    x = reshape(collect(1:prod(sizes)), sizes)
    @test Flux.hasaffine(m) == false
    @test length(Flux.params(m)) == 0
-    
+
    x = Float64.(x)
    y = m(x)
    μ = mean(x, dims=1)
@@ -458,3 +458,8 @@ end
    @test BN(x) ≈ GN(x)
  end
 end
+
+@testset "second derivatives" begin
+  m1 = Dropout(0.5)
+  @test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3)
+end
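For context, a minimal sketch of the behaviour this patch enables (not part of the diff; it assumes Flux with this change plus Zygote on the load path): because `dropout_mask` is now marked `@non_differentiable`, the mask is treated as a constant, so `Dropout` becomes twice-differentiable and the reverse-mode Hessian of `sum ∘ Dropout(p)` is exactly zero, which is what the new test in `test/layers/normalisation.jl` asserts.

```julia
using Flux, Zygote

m = Dropout(0.5)        # the layer exercised by the new test
x = [1.0, 2.0, 3.0]

# For any fixed mask, sum(m(x)) is linear in x, so its Hessian vanishes;
# marking dropout_mask as non-differentiable keeps Zygote from trying to
# differentiate through the mask construction on the second pass.
H = Zygote.hessian_reverse(sum ∘ m, x)
@assert H == zeros(3, 3)
```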