
Commit f5d25e5

Add more Duplicated methods for Enzyme.jl support (#2471)
* add more Duplicated methods
* update macro to zero, show
* make informative errors if you use Duplicated without loading Enzyme
* note on macro
* fix some tests
* add an Enzyme docs page
* tweaks & tests
* typos
* news, docs
* let Flux own the function update! to avoid piracy
* Revert "let Flux own the function update! to avoid piracy"

  This reverts commit ca5a20f.
* demand Optimisers PR
* fixup
* force depwarns
* allow aux in withgradient
* disallow Active
* disallow trivial Duplicated
* don't use ReverseWithPrimal in gradient
* tweak
* giant post-rebase fixup after everything was moved around... all earlier commits are a mess now, probably
* clean up more rebase mess
* fix docs
* try out Ref for withgradient
* don't own `_make_zero!`
* add explicit errors for 2nd order
* more rebase problems

  Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
* teach Flux.state about Duplicated
* another explicit error for Zygote mistake
* ahem
* don't use Enzyme's make_zero!, fix some bugs
* maybe this works?
* see if CI likes these
* turns out train! does have tests
* enzyme tests
* fix tests?
* minor comments

---------

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
1 parent cb76e9d commit f5d25e5

15 files changed, +633 −36 lines changed

NEWS.md

Lines changed: 4 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
 
-## v0.15.0
+## v0.15.0 (December 2024)
 * Recurrent layers have undergone a complete redesign in [PR 2500](https://github.com/FluxML/Flux.jl/pull/2500).
 * `RNNCell`, `LSTMCell`, and `GRUCell` are now exported and provide functionality for single time-step processing: `rnncell(x_t, h_t) -> h_{t+1}`.
 * `RNN`, `LSTM`, and `GRU` no longer store the hidden state internally; it has to be explicitly passed to the layer. Moreover, they now process entire sequences at once, rather than one element at a time: `rnn(x, h) -> h′`.
@@ -12,6 +12,8 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
   Now Flux re-exports the optimisers from Optimisers.jl. Most users will be unaffected by this change.
   The module is still available for now, but will be removed in a future release.
 * Most Flux layers will [re-use memory via `NNlib.bias_act!`](https://github.com/FluxML/Flux.jl/pull/2327), when possible.
+* Further support for Enzyme.jl, via methods of `Flux.gradient(loss, Duplicated(model))`.
+  Flux now owns & exports `gradient`, but without `Duplicated` this still defaults to calling Zygote.jl.
 * `Flux.params` has been deprecated. Use Zygote's explicit differentiation instead,
   `gradient(m -> loss(m, x, y), model)`, or use `Flux.trainables(model)` to get the trainable parameters.
 * Flux now requires Functors.jl v0.5. This new release of Functors assumes all types to be functors by default. Therefore, applying `@layer` or `@functor` to a type is no longer strictly necessary for Flux's models. However, it is still recommended to use `@layer Model` for additional functionality like pretty printing.
@@ -40,7 +42,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
 * After a deprecation cycle, the macro `@epochs` and the functions `Flux.stop`, `Flux.skip`, `Flux.zeros`, `Flux.ones` have been removed.
 
 ## v0.13.17
-* Apple's Metal GPU acceleration preliminary support via the extension mechanism.
+* Apple's Metal GPU acceleration preliminary support via the extension mechanism.
 
 ## v0.13.16
 * Most greek-letter keyword arguments are deprecated in favour of ascii.

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -6,6 +6,7 @@ version = "0.15.0-DEV"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
@@ -48,14 +49,15 @@ ChainRulesCore = "1.12"
 Compat = "4.10.0"
 Enzyme = "0.13"
 Functors = "0.5"
+EnzymeCore = "0.7.7, 0.8.4"
 MLDataDevices = "1.4.2"
 MLUtils = "0.4"
 MPI = "0.20.19"
 MacroTools = "0.5"
 NCCL = "0.1.1"
 NNlib = "0.9.22"
 OneHotArrays = "0.2.4"
-Optimisers = "0.4"
+Optimisers = "0.4.1"
 Preferences = "1"
 ProgressLogging = "0.1"
 Reexport = "1.0"

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ makedocs(
     "Flat vs. Nested" => "reference/destructure.md",
     "Callback Helpers" => "reference/training/callbacks.md",
     "Gradients -- Zygote.jl" => "reference/training/zygote.md",
+    "Gradients -- Enzyme.jl" => "reference/training/enzyme.md",
     "Transfer Data to GPU -- MLDataDevices.jl" => "reference/data/mldatadevices.md",
     "Batching Data -- MLUtils.jl" => "reference/data/mlutils.md",
     "OneHotArrays.jl" => "reference/data/onehot.md",

docs/src/reference/training/enzyme.md

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# [Automatic Differentiation using Enzyme.jl](@id autodiff-enzyme)

[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is a new package for automatic differentiation.
Like Zygote.jl, calling `gradient(f, x)` causes it to hook into the compiler and transform the code that is executed while calculating `f(x)`, in order to produce code for `∂f/∂x`.
But it does so much later in the optimisation process (on LLVM instead of Julia's untyped IR), which you can [read about here](https://proceedings.nips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf).
It needs far fewer custom rules than Zygote/ChainRules, and in particular is able to support mutation of arrays.

Flux now builds in support for this, using Enzyme's own `Duplicated` type.
Calling `Duplicated` on any Flux model which was defined using `@layer` will allocate space for the gradient,
and passing that to `gradient` (or `withgradient`, or `train!`) will then use Enzyme instead of Zygote.
The gradient functions still return the gradient as usual, which can then be passed to `update!`:

```julia
julia> using Flux, Enzyme

julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax);  # from model zoo

julia> dup_model = Enzyme.Duplicated(model)  # this allocates space for the gradient
Duplicated(
  Chain(
    Dense(784 => 32, σ),                # 25_120 parameters
    Dense(32 => 10),                    # 330 parameters
    NNlib.softmax,
  ),
  # norm(∇) ≈ 0.0f0
)                  # Total: 4 arrays, 25_450 parameters, 199.391 KiB.

julia> x1 = randn32(28*28, 1);  # fake image

julia> y1 = [i==3 for i in 0:9];  # fake label

julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1))  # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857
-0.0014538406], σ = nothing), nothing),), nothing, nothing)
```

The gradient returned here is also stored within `dup_model`.
Both share the same arrays -- what is returned is not a copy, just a view of the same memory (wrapped in `NamedTuple`s instead of `struct`s).
They will all be set to zero when you call `gradient` again, then replaced with the new values.
Alternatively, `gradient(f, args...; zero=false)` will add the new gradient to what's already stored.

Writing `Const(x1)` is optional; just plain `x1` is implicitly constant.
Any set of `Duplicated` and `Const` arguments may appear in any order, so long as there is at least one `Duplicated`.

The gradient `grads_f[1]` can be passed to `update!` as usual.
But for convenience, you may also use what is stored within `Duplicated`.
These are equivalent ways to perform an update step:

```julia
julia> opt_state = Flux.setup(Adam(), model)

julia> ans == Flux.setup(Adam(), dup_model)

julia> Flux.update!(opt_state, model, grads_f[1])  # exactly as for Zygote gradients

julia> Flux.update!(opt_state, dup_model)  # equivalent new path, Enzyme only
```
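
For example, here is a minimal sketch of combining the `zero=false` keyword with this in-place update. It assumes a second fake batch `x2`, `y2`, which is not defined above:

```julia
julia> x2, y2 = randn32(28*28, 1), [i==7 for i in 0:9];  # a second fake batch, for illustration only

julia> Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1));  # zeroes dup_model.dval, then stores the gradient there

julia> Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x2), Const(y2); zero=false);  # adds the second gradient on top

julia> Flux.update!(opt_state, dup_model);  # one step using the accumulated gradient
```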

Instead of using these Flux functions, you can also use Enzyme's own functions directly.
`Enzyme.gradient` works like this:

```julia
julia> grads_e = Enzyme.gradient(Reverse, (m,x,y) -> sum(abs2, m(x) .- y), model, Const(x1), Const(y1))
(Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax), nothing, nothing)

julia> grads_f[1].layers[2].bias ≈ grads_e[1].layers[2].bias
true
```

Note that what `Enzyme.gradient` returns is an object like `deepcopy(model)` of the same type, `grads_e[1] isa Chain`.
But its fields contain the same gradient.

There is also a method of `train!` which similarly takes `Duplicated(model)`:

```julia
julia> opt_state = Flux.setup(Adam(0), model);

julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
```
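
This is roughly equivalent to the following explicit loop (a sketch, which omits `train!`'s logging and its check that the loss is finite):

```julia
julia> for (x, y) in [(x1, y1)]
           loss_and_grad = Flux.withgradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x), Const(y))
           @show loss_and_grad.val  # the loss; the gradient is stored within dup_model
           Flux.update!(opt_state, dup_model)
       end
```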

## Second-order AD

If you calculate a gradient within the loss function, then training will involve 2nd derivatives.
While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
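
For instance, a loss with an input-gradient penalty has this shape. This is only a sketch to show where the inner gradient appears; the helper names `input_grad` and `penalised_loss` are illustrative, not part of Flux:

```julia
julia> input_grad(m, x) = Enzyme.gradient(Reverse, x -> sum(m(x)), x)[1];  # ∂ sum(m(x)) / ∂x, computed with Enzyme

julia> penalised_loss(m, x, y) = sum(abs2, m(x) .- y) + 0.1f0 * sum(abs2, input_grad(m, x));

julia> penalised_loss(model, x1, y1);  # evaluating this already takes one gradient
```

Training on `penalised_loss` then differentiates `input_grad` itself, which is where second derivatives arise.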

## Listing

```@docs
Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
```

Enzyme.jl has [its own extensive documentation](https://enzymead.github.io/Enzyme.jl/stable/).

docs/src/reference/training/zygote.md

Lines changed: 3 additions & 1 deletion
@@ -4,8 +4,10 @@ CollapsedDocStrings = true
 
 # [Automatic Differentiation using Zygote.jl](@id autodiff-zygote)
 
-Flux re-exports the `gradient` from [Zygote](https://github.com/FluxML/Zygote.jl), and uses this function within [`train!`](@ref Flux.train!) to differentiate the model. Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
+Flux's `gradient` function uses [Zygote](https://github.com/FluxML/Zygote.jl) by default, and also uses this function within [`train!`](@ref Flux.train!) to differentiate the model.
+Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
 
+Flux also has support for Enzyme.jl, documented [on its own page](@ref autodiff-enzyme).
 
 ## Explicit style
 
ext/FluxEnzymeExt/FluxEnzymeExt.jl

Lines changed: 97 additions & 12 deletions
@@ -1,32 +1,117 @@
 module FluxEnzymeExt
 
 using Flux
-import Flux.Train: train!, _rule_to_state
+import Flux.Train: _enzyme_train!
+
 import Optimisers
+import Functors
 import Enzyme
-using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
+using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
+using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
 using ProgressLogging: @withprogress, @logprogress
 
-_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
-_make_zero_internal!(x) = x
-_make_zero!(model) = fmap(_make_zero_internal!, model)
+EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
 
-_applyloss(loss, model, d...) = loss(model, d...)
+### gradient & withgradient
 
-EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
+# We can't use Enzyme.make_zero! to reset Duplicated, as it complains about e.g. LayerNorm having immutable differentiable fields
+# After https://github.com/EnzymeAD/Enzyme.jl/pull/1961 probably this can be `make_zero!(Ref(dup.dval))`
+_make_zero!(model) = Functors.fmapstructure(_make_zero_inner!, model)
+function _make_zero_inner!(x::AbstractArray{<:Number})
+    Optimisers.isnumeric(x) || return
+    Optimisers.maywrite(x) || error("can't handle this")
+    fill!(x, zero(eltype(x)))
+    nothing
+end
+_make_zero_inner!(x) = nothing  # any other Functors leaf type
+
+#= # This _make_zero! matches what Flux allows elsewhere:
+julia> Flux.setup(Adam(), (1:3.)')
+ERROR: model must be fully mutable for `train!` to work, got `x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}`.
+If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) = true`
+=#
+# Perhaps canonical way for Enzyme is more like this:
+# function _make_zero!(x::AbstractArray{<:Number})
+#     if Enzyme.guess_activity(typeof(x), Reverse) <: Duplicated
+#         fill!(x, zero(eltype(x)))
+#     elseif Enzyme.guess_activity(typeof(x), Reverse) <: Const
+#         # that's OK
+#     else
+#         error("not sure what it should do for Active?")
+#     end
+# end
 
-function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
-    train!(loss, model, data, _rule_to_state(model, rule); cb)
+function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
+    for x in args
+        zero && x isa Duplicated && _make_zero!(x.dval)
+        _check_mutable(x)
+    end
+    Enzyme.autodiff(Reverse, Const(f), Active, args...)
+    map(_grad_or_nothing, args)
 end
 
-function train!(loss, model::Duplicated, data, opt; cb = nothing)
+_check_mutable(x::Const) = nothing
+_check_mutable(x::Duplicated) = Functors.anymutable(x) || error(
+    """`Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays."""
+)
+
+# This function strips the returned gradient to be Zygote-like:
+_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
+_grad_or_nothing(::Const) = nothing
+_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing
+
+function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
+    for x in args
+        zero && x isa Duplicated && _make_zero!(x.dval)
+        _check_mutable(x)
+    end
+
+    # Take I, doesn't allow for aux at all.
+    # _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
+
+    # Take II, using split mode.
+    forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
+    tape, result, shadow_result = forward(Const(f), args...)
+    reverse(Const(f), args..., _sensitivity(result), tape)
+
+    # Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
+    # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
+    # # result = autodiff(Reverse, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
+    # _, result = autodiff(ReverseWithPrimal, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
+
+    (; val = result, grad = map(_grad_or_nothing, args))
+end
+
+@inline _sensitivity(y::Real) = one(y)
+@inline _sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
+@inline _sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
+_sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+                           or else a Tuple or NamedTuple whose first element is a real number.""")
+
+function _ref_loss!(out::Ref, f, args...)  # for Take III above
+    val = f(args...)
+    out[] = _get_loss(val)  # saves loss by mutation
+    val  # returns the whole thing
+end
+
+@inline _get_loss(y::Real) = y
+@inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1]
+@inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1]
+_get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
+                        or else a Tuple or NamedTuple whose first element is a real number.""")
+
+### Flux.Train, for train!
+
+_applyloss(loss, model, d...) = loss(model, d...)
+
+function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
     isnothing(cb) || error("""train! does not support callback functions.
                               For more control use a loop with `gradient` and `update!`.""")
     @withprogress for (i,d) in enumerate(data)
        d_splat = d isa Tuple ? d : (d,)
 
        _make_zero!(model.dval)
-       _, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
+       _, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
                               Active, Const(loss), model, map(Const, d_splat)...)
 
@@ -39,4 +124,4 @@ function train!(loss, model::Duplicated, data, opt; cb = nothing)
     end
 end
 
-end # FluxEnzymeExt
+end # FluxEnzymeExt
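
For reference, the Zygote-like gradient which `_grad_or_nothing` strips out of a `Duplicated` mirrors the model's structure, with `nothing` at non-numeric leaves such as activation functions. A small sketch of what that looks like for a single layer:

```julia
using Flux, Enzyme

dup = Duplicated(Dense(2 => 3, relu))   # one-argument Duplicated, as defined by Flux's @layer

Flux.gradient(m -> sum(abs2, m(ones32(2))), dup)[1]
# returns something of the form (weight = Float32[...], bias = Float32[...], σ = nothing)
```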

src/Flux.jl

Lines changed: 10 additions & 6 deletions
@@ -10,16 +10,19 @@ using MacroTools: @forward
 @reexport using NNlib
 using NNlib: conv, ∇conv_data, depthwiseconv, output_size
 using MLUtils
+using Adapt, OneHotArrays
+using Functors: Functors, fmap, fmapstructure
 
 using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update!
 import Optimisers: trainable
 @reexport using Optimisers
 
 using Random: default_rng
+
 using Zygote, ChainRulesCore
-using Zygote: @adjoint, gradient, pullback
+using Zygote: @adjoint, pullback
 using Zygote.ForwardDiff: value
-export gradient
+using EnzymeCore: EnzymeCore
 
 @reexport using MLDataDevices: MLDataDevices, supported_gpu_backends, reset_gpu_device!,
                                default_device_rng,
@@ -53,11 +56,12 @@ export Chain, Dense, Embedding, EmbeddingBag,
   # utils
   outputsize, state, create_bias, @layer,
   # from OneHotArrays.jl
-  onehot, onehotbatch, onecold,
+  onehot, onehotbatch, onecold,
   # from Train
   setup, train!,
   # from Optimisers.jl
   destructure, freeze!, thaw!, adjust!, trainables, update!, trainable,
+  withgradient,
   # init
   glorot_uniform,
   glorot_normal,
@@ -89,13 +93,13 @@ export Chain, Dense, Embedding, EmbeddingBag,
     tversky_loss,
   ))
 
+include("gradient.jl")
+export gradient
+
 include("train.jl")
 using .Train
 using .Train: setup
 
-using Adapt, OneHotArrays
-using Functors: Functors, fmap, fmapstructure
-
 include("utils.jl")
 include("functor.jl")