Skip to content

Commit df4019d

Browse files
committed
replace at-adjoint with rrule
1 parent 9b21e2c commit df4019d

File tree

13 files changed

+43
-33
lines changed

13 files changed

+43
-33
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
10+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1011
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
1112
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
1213
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
@@ -33,6 +34,7 @@ AbstractTrees = "0.3"
3334
Adapt = "3.0"
3435
ArrayInterface = "3.1, 4"
3536
CUDA = "3"
37+
ChainRulesCore = "1.12"
3638
CodecZlib = "0.7"
3739
Colors = "0.12"
3840
Functors = "0.2.1"
@@ -43,7 +45,7 @@ ProgressLogging = "0.1"
4345
Reexport = "0.2, 1.0"
4446
StatsBase = "0.33"
4547
ZipFile = "0.9"
46-
Zygote = "0.6"
48+
Zygote = "0.6.34"
4749
julia = "1.6"
4850

4951
[extras]

src/Flux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using MacroTools: @forward
99
@reexport using NNlib
1010
using Zygote: Params, @adjoint, gradient, pullback, @nograd
1111
export gradient
12+
using ChainRulesCore
1213

1314
export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
1415
RNN, LSTM, GRU, GRUv3,

src/cuda/cuda.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module CUDAint
33
using ..CUDA
44

55
import ..Flux: Flux
6-
import Zygote
7-
using Zygote: @adjoint
6+
# import Zygote
7+
# using Zygote: @adjoint
8+
using ChainRulesCore
89
import NNlib, NNlibCUDA
910

1011
include("cudnn.jl")

src/cuda/cudnn.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
1111
training=Flux._isactive(BN)))
1212
end
1313

14-
@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
14+
function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
1515
y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
1616
function batchnorm_pullback(Δ)
17-
∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
17+
grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
18+
(NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
1819
end
1920
y, batchnorm_pullback
2021
end

src/functor.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
120120
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
121121
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
122122

123-
Zygote.@adjoint function Array(x::CUDA.CuArray)
124-
Array(x), d -> (CUDA.cu(d),)
123+
function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray)
124+
Array(x), d -> (NoTangent(), CUDA.cu(d),)
125125
end
126126

127-
Zygote.@adjoint function Adapt.adapt_storage(to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
128-
adapt_storage(to, x), d -> (nothing, adapt_storage(FluxCUDAAdaptor(), d),)
127+
function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
128+
adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
129129
end
130130

131131
# CPU/GPU movement conveniences
@@ -202,7 +202,7 @@ function check_use_cuda()
202202
end
203203
end
204204
end
205-
Zygote.@nograd check_use_cuda
205+
ChainRulesCore.@non_differentiable check_use_cuda()
206206

207207
# Precision
208208

src/layers/conv.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,7 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
275275
)
276276
end
277277

278-
# TODO: Find proper fix for https://github.com/FluxML/Flux.jl/issues/900
279-
@nograd conv_transpose_dims
278+
ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
280279

281280
function (c::ConvTranspose)(x::AbstractArray)
282281
b = reshape(c.bias, map(_->1, c.stride)..., :, 1)

src/layers/normalise.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
istraining() = false
22

3-
@adjoint istraining() = true, _ -> nothing
3+
ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
44

55
_isactive(m) = isnothing(m.active) ? istraining() : m.active
66

@@ -38,12 +38,6 @@ function dropout(rng, x, p; dims=:, active::Bool=true)
3838
end
3939
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
4040

41-
@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
42-
active || return x, Δ -> (Δ, nothing)
43-
y = dropout_mask(rng, x, p, dims=dims)
44-
return x .* y, Δ -> (nothing, Δ .* y, nothing)
45-
end
46-
4741
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
4842
dropout_mask(rng, x::CuArray, p; kwargs...) =
4943
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
@@ -54,6 +48,8 @@ function _dropout_mask(rng, x, p; dims=:)
5448
return y
5549
end
5650

51+
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
52+
5753
"""
5854
Dropout(p; dims=:, rng = rng_from_array())
5955
@@ -230,7 +226,8 @@ function _track_stats!(
230226
bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
231227
return nothing
232228
end
233-
Zygote.@nograd _track_stats!
229+
230+
ChainRulesCore.@non_differentiable _track_stats!(::Any...)
234231

235232
"""
236233
BatchNorm(channels::Integer, λ=identity;

src/layers/recurrent.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)
66
# AD-friendly helper for dividing monolithic RNN params into equally sized gates
77
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
88

9-
@adjoint function multigate(x::AbstractArray, h, c)
9+
function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
1010
function multigate_pullback(dy)
11-
dx = Zygote._zero(x, eltype(x))
12-
map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
13-
dyᵢ !== nothing && (dxᵢ.= Zygote.accum.(dxᵢ, dyᵢ));
11+
dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
12+
foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
13+
dyᵢ isa AbstractZero && return
14+
@. dxᵢ += dyᵢ
1415
end
15-
return (dx, nothing, nothing)
16+
return (NoTangent(), dx, NoTangent(), NoTangent())
1617
end
1718
return multigate(x, h, c), multigate_pullback
1819
end
@@ -435,7 +436,7 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
435436
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
436437
Recur(m::GRUv3Cell) = Recur(m, m.state0)
437438

438-
439+
# TODO move to ChainRulesCore?
439440
@adjoint function Broadcast.broadcasted(f::Recur, args...)
440441
Zygote.∇map(__context__, f, args...)
441442
end

src/losses/Losses.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module Losses
33
using Statistics
44
using Zygote
55
using Zygote: @adjoint
6+
using ChainRulesCore
67
using ..Flux: ofeltype, epseltype
78
using CUDA
89
using NNlib: logsoftmax, logσ

src/losses/ctc.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,9 @@ for mathematical details.
133133
"""
134134
ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss
135135

136-
@adjoint function ctc_loss(ŷ, y)
137-
out = ctc_alpha(ŷ, y)
138-
ctc_loss_pullback(Δ) =.* ∇ctc_loss(ŷ, y, out), nothing)
139-
return out.loss, ctc_loss_pullback
136+
function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
137+
ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, out), NoTangent())
138+
return ctc_loss(ŷ, y), ctc_loss_pullback
140139
end
141140

142141

0 commit comments

Comments
 (0)