diff --git a/Project.toml b/Project.toml
index 283f5069c9..1f376e96fc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.13.0-DEV"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -26,6 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 Adapt = "3.0"
 ArrayInterface = "3.1, 4"
 CUDA = "3"
+ChainRulesCore = "1.12"
 Functors = "0.2.1"
 MLUtils = "0.1.4"
 MacroTools = "0.5"
@@ -35,7 +37,7 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
-Zygote = "0.6"
+Zygote = "0.6.34"
 julia = "1.6"
 
 [extras]
diff --git a/src/Flux.jl b/src/Flux.jl
index 2b204567d0..5f906a0528 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -12,6 +12,7 @@ using MLUtils
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
 
+using ChainRulesCore
 
 export Chain, Dense, Maxout, SkipConnection, Parallel,
   RNN, LSTM, GRU, GRUv3,
diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl
index 0ef3c65308..6e18a066af 100644
--- a/src/cuda/cuda.jl
+++ b/src/cuda/cuda.jl
@@ -3,8 +3,7 @@ module CUDAint
 using ..CUDA
 
 import ..Flux: Flux
-import Zygote
-using Zygote: @adjoint
+using ChainRulesCore
 import NNlib, NNlibCUDA
 
 include("cudnn.jl")
diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 4a3b2618c8..a15637f77f 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -11,10 +11,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
                               training=Flux._isactive(BN)))
 end
 
-@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
   y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
   function batchnorm_pullback(Δ)
-    ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
+    grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
+    (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
   end
   y, batchnorm_pullback
 end
diff --git a/src/deprecations.jl b/src/deprecations.jl
index e258f41897..c9648c0568 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -17,6 +17,13 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,
 
 # v0.13 deprecations
 
+function Broadcast.broadcasted(f::Recur, args...)
+  # This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
+  Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
+    Re-writing this as a comprehension would be better.""", :broadcasted)
+  map(f, args...)  # map isn't really safe either, but
+end
+
 @deprecate frequencies(xs) group_counts(xs)
 
 # Channel notation: Changed to match Conv, but very softly deprecated!
diff --git a/src/functor.jl b/src/functor.jl
index c030096468..b056ff9574 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -122,12 +122,12 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
 adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
-Zygote.@adjoint function Array(x::CUDA.CuArray)
-  Array(x), d -> (CUDA.cu(d),)
+function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray)
+  Array(x), d -> (NoTangent(), CUDA.cu(d),)
 end
 
-Zygote.@adjoint function Adapt.adapt_storage(to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d),)
+function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
+  adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
 end
 
 # CPU/GPU movement conveniences
@@ -204,7 +204,7 @@ function check_use_cuda()
     end
   end
 end
-Zygote.@nograd check_use_cuda
+ChainRulesCore.@non_differentiable check_use_cuda()
 
 # Precision
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 519618e4be..eb0ea8604e 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -273,8 +273,7 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   )
 end
 
-# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
-@nograd conv_transpose_dims
+ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
 
 function (c::ConvTranspose)(x::AbstractArray)
   b = reshape(c.bias, map(_->1, c.stride)..., :, 1)
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 7553f8b03f..f92daed13a 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -1,6 +1,6 @@
 istraining() = false
 
-@adjoint istraining() = true, _ -> nothing
+ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
 _isactive(m) = isnothing(m.active) ? istraining() : m.active
 
@@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
 end
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x, Δ -> (Δ, nothing)
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y, Δ -> (nothing, Δ .* y, nothing)
-end
-
 dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 dropout_mask(rng, x::CuArray, p; kwargs...) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
@@ -56,7 +50,7 @@ function _dropout_mask(rng, x, p; dims=:)
 end
 
 # TODO move this to NNlib
-Zygote.ChainRulesCore.@non_differentiable dropout_mask(rng, x, p)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
 
 """
     Dropout(p; dims=:, rng = rng_from_array())
@@ -234,7 +228,8 @@ function _track_stats!(
   bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
   return nothing
 end
-Zygote.@nograd _track_stats!
+
+ChainRulesCore.@non_differentiable _track_stats!(::Any...)
""" BatchNorm(channels::Integer, λ=identity; diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 02d9b07089..f899f32ab1 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :) # AD-friendly helper for dividing monolithic RNN params into equally sized gates multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N) -@adjoint function multigate(x::AbstractArray, h, c) +function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c) function multigate_pullback(dy) - dx = Zygote._zero(x, eltype(x)) - map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ - dyᵢ !== nothing && (dxᵢ.= Zygote.accum.(dxᵢ, dyᵢ)); + dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x) + foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ + dyᵢ isa AbstractZero && return + @. dxᵢ += dyᵢ end - return (dx, nothing, nothing) + return (NoTangent(), dx, NoTangent(), NoTangent()) end return multigate(x, h, c), multigate_pullback end @@ -379,8 +380,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10 """ GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...)) Recur(m::GRUv3Cell) = Recur(m, m.state0) - - -@adjoint function Broadcast.broadcasted(f::Recur, args...) - Zygote.∇map(__context__, f, args...) -end diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 413c4ee034..50d94b9d21 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -3,6 +3,7 @@ module Losses using Statistics using Zygote using Zygote: @adjoint +using ChainRulesCore using ..Flux: ofeltype, epseltype using CUDA using NNlib: logsoftmax, logσ diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl index e18a2ad88a..ed0a06101e 100644 --- a/src/losses/ctc.jl +++ b/src/losses/ctc.jl @@ -133,10 +133,10 @@ for mathematical details. """ ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss -@adjoint function ctc_loss(ŷ, y) - out = ctc_alpha(ŷ, y) - ctc_loss_pullback(Δ) = (Δ .* ∇ctc_loss(ŷ, y, out), nothing) - return out.loss, ctc_loss_pullback +function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y) + tmp = ctc_alpha(ŷ, y) + ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent()) + return tmp.loss, ctc_loss_pullback end diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 386cd67166..e13a3e6206 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -23,6 +23,9 @@ end res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y)) end +ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting +ChainRulesCore.@scalar_rule xlogx(x) (log(y) + true) + # This can be made an error in Flux v0.13, for now just a warning function _check_sizes(ŷ::AbstractArray, y::AbstractArray) for d in 1:max(ndims(ŷ), ndims(y)) @@ -33,4 +36,4 @@ function _check_sizes(ŷ::AbstractArray, y::AbstractArray) end _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1 -Zygote.@nograd _check_sizes +ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any) diff --git a/src/onehot.jl b/src/onehot.jl index 36345438a3..86afd513dc 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -230,7 +230,11 @@ function _fast_argmax(x::OneHotLike) end end -@nograd OneHotArray, onecold, onehot, onehotbatch +ChainRulesCore.@non_differentiable onehot(::Any...) +ChainRulesCore.@non_differentiable onehotbatch(::Any...) +ChainRulesCore.@non_differentiable onecold(::Any...) 
+
+ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)
 
 function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
   _isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
diff --git a/src/utils.jl b/src/utils.jl
index b5edbad5e6..57a4f8114b 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -472,7 +472,7 @@ function _restructure(m, xs)
   return m̄
 end
 
-@adjoint function _restructure(m, xs)
+@adjoint function _restructure(m, xs)  # TODO ChainRulesCore.rrule
   m̄, numel = _restructure(m, xs), length(xs)
   function _restructure_pullback(dm)
     xs′ = destructure(dm)[1]
@@ -603,7 +603,10 @@ true
 """
 modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
 
-@nograd modules
+@nograd modules  # TODO: is this correct? might fail with explicit parameters.
+function ChainRulesCore.rrule(::typeof(modules), m)
+  modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
+end
 
 isleaflike(x) = Functors.isleaf(x)
 isleaflike(::Tuple{Vararg{<:Number}}) = true
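
For reference (not part of the diff): a minimal sketch of the rule-migration pattern this PR applies throughout, assuming only ChainRulesCore 1.x and a Zygote recent enough to pick up rrules automatically (hence the 0.6.34 bound above). The names `double` and `helper` below are hypothetical stand-ins, not Flux functions.

using ChainRulesCore

double(x) = 2x  # hypothetical primal function, standing in for e.g. ctc_loss

# Old style:  Zygote.@adjoint double(x) = double(x), Δ -> (2Δ,)
# New style: the pullback returns one extra leading tangent, NoTangent(),
# for the function itself, then one tangent per positional argument.
function ChainRulesCore.rrule(::typeof(double), x)
  double_pullback(Δ) = (NoTangent(), 2Δ)
  return double(x), double_pullback
end

# Old style:  Zygote.@nograd helper
# New style: mark a call signature rather than a bare name.
helper(x) = length(x)  # hypothetical non-differentiable helper
ChainRulesCore.@non_differentiable helper(::Any)

With Zygote at or above 0.6.34, Zygote.gradient(double, 3.0) should return (2.0,) via this rrule, and gradients flowing through helper should come back as nothing.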