Commit 02dcfa0

Add support for explicit mode gradients and optimizers

1 parent c7ed5fe commit 02dcfa0

File tree: 5 files changed (+60 / -15 lines)

  Project.toml
  src/optimise/Optimise.jl
  src/optimise/gradients.jl
  src/optimise/optimisers.jl
  src/optimise/train.jl

Project.toml

Lines changed: 3 additions & 1 deletion

@@ -1,8 +1,9 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.6"
+version = "0.13.7-DEV"

 [deps]
+AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -26,6 +27,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
+AbstractDifferentiation = "0.4.3"
 Adapt = "3.0"
 ArrayInterface = "3.1, 4, 5, 6"
 CUDA = "3"

src/optimise/Optimise.jl

Lines changed: 9 additions & 0 deletions

@@ -1,7 +1,15 @@
 module Optimise

+using Flux
+using MacroTools: @forward
+import Zygote
+import Zygote: Params, gradient
+using AbstractDifferentiation
+import Optimisers
+import Optimisers: update, update!
 using LinearAlgebra
 import ArrayInterface
+using ProgressLogging: @progress, @withprogress, @logprogress

 export train!, update!,
   Descent, Adam, Momentum, Nesterov, RMSProp,
@@ -10,6 +18,7 @@ export train!, update!,
   ClipValue, ClipNorm

 include("optimisers.jl")
+include("gradients.jl")
 include("train.jl")

 end
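One consequence of the `import Optimisers: update, update!` line above is that Flux's optimiser methods are now added to Optimisers.jl's generic functions rather than to Flux-owned ones. A minimal check of that effect, not part of the commit and assuming this branch is installed:

```julia
using Flux, Optimisers

# `update!` inside Flux.Optimise is now the same generic function as
# `Optimisers.update!`; the methods defined in train.jl below extend it.
Flux.Optimise.update! === Optimisers.update!   # expected: true
```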

src/optimise/gradients.jl

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+struct ZygoteImplicitBackend{T} <: AD.AbstractReverseMode
+    core_backend::T
+end
+ZygoteImplicitBackend() = ZygoteImplicitBackend(AD.ZygoteBackend())
+
+AD.@primitive pullback_function(ad::ZygoteImplicitBackend, f, x::Zygote.Params) =
+    AD.pullback_function(ad.core_backend, f, x)
+
+# this is a hack to get around
+# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
+AD.gradient(::ZygoteImplicitBackend, f, x::Zygote.Params) = Zygote.gradient(f, x)
+
+struct ZygoteExplicitBackend{T} <: AD.AbstractReverseMode
+    core_backend::T
+end
+ZygoteExplicitBackend() = ZygoteExplicitBackend(AD.ZygoteBackend())
+
+AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
+    AD.pullback_function(ad.core_backend, f, xs...)
+
+# this is a hack to get around
+# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
+AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
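As a rough usage sketch (not part of the diff): both backends go through the common `AD.gradient` entry point, where the explicit backend differentiates with respect to the arguments themselves and the implicit backend expects a `Zygote.Params`. The model `m` and input `x` below are illustrative, and the backend types are assumed to be reachable via `Flux.Optimise` since they do not appear in the export list above.

```julia
using Flux
import AbstractDifferentiation as AD
using Flux.Optimise: ZygoteImplicitBackend, ZygoteExplicitBackend

m = Dense(2 => 1)
x = rand(Float32, 2, 8)

# explicit mode: gradient with respect to the model itself (a 1-tuple of NamedTuples)
gs_explicit = AD.gradient(ZygoteExplicitBackend(), m -> sum(m(x)), m)

# implicit mode: gradient with respect to a Params collection (a Zygote.Grads)
ps = Flux.params(m)
gs_implicit = AD.gradient(ZygoteImplicitBackend(), () -> sum(m(x)), ps)
```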

src/optimise/optimisers.jl

Lines changed: 0 additions & 3 deletions

@@ -1,6 +1,3 @@
-using Flux
-using MacroTools: @forward
-
 abstract type AbstractOptimiser end

 const EPS = 1e-8
src/optimise/train.jl

Lines changed: 25 additions & 11 deletions

@@ -1,7 +1,3 @@
-using ProgressLogging: @progress, @withprogress, @logprogress
-import Zygote: Params, gradient
-
-
 """
     update!(opt, p, g)
     update!(opt, ps::Params, gs)
@@ -12,18 +8,23 @@ according to optimizer `opt` and the gradients `gs` (the gradient `g`).
 As a result, the parameters are mutated and the optimizer's internal state may change.
 The gradient could be mutated as well.
 """
-function update!(opt::AbstractOptimiser, x, x̄)
+function Optimisers.update!(opt::AbstractOptimiser, x, x̄)
   x̄r = ArrayInterface.restructure(x, x̄) # address some cases where Zygote's
                                         # output are not mutable, see #1510
   x .-= apply!(opt, x, x̄r)
+
+  return opt, x
 end

-function update!(opt::AbstractOptimiser, xs::Params, gs)
+function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
   for x in xs
     isnothing(gs[x]) && continue
     update!(opt, x, gs[x])
   end
+
+  return opt, xs
 end
+Optimisers.update(opt::AbstractOptimiser, xs::Params, gs) = update!(opt, xs, gs)

 # Callback niceties
 call(f, xs...) = f(xs...)
@@ -82,6 +83,16 @@ end
 batchmemaybe(x) = tuple(x)
 batchmemaybe(x::Tuple) = x

+_build_loss(::AD.AbstractBackend, loss, data) = function _loss(m)
+  return loss(m, data...)
+end
+_build_loss(::ZygoteImplicitBackend, loss, data) = function _loss()
+  return loss(data...)
+end
+_gradient_only(x::Zygote.Grads) = x
+_gradient_only(x::NTuple{1}) = x[1]
+_gradient_only(x) = error("Expected gradient w.r.t. single argument (or Zygote.Grads) but got $x")
+
 """
     train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb])

@@ -120,16 +131,15 @@ The callback can call [`Flux.stop`](@ref) to interrupt the training loop.

 Multiple callbacks can be passed to `cb` as array.
 """
-function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
+function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () -> ())
   cb = runall(cb)
   itrsz = Base.IteratorSize(typeof(data))
   n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
   @withprogress for (i, d) in enumerate(data)
     try
-      gs = gradient(ps) do
-        loss(batchmemaybe(d)...)
-      end
-      update!(opt, ps, gs)
+      _loss = _build_loss(ad, loss, batchmemaybe(d))
+      gs = _gradient_only(AD.gradient(ad, _loss, model))
+      optstate, model = update(optstate, model, gs)
       cb()
     catch ex
       if ex isa StopException
@@ -142,7 +152,11 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
       end
     @logprogress iszero(n) ? nothing : i / n
   end
+
+  return optstate, model
 end
+train!(loss, model, data, optstate; kwargs...) =
+  train!(loss, ZygoteImplicitBackend(), model, data, optstate; kwargs...)

 """
     @epochs N body
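For orientation, a hedged sketch (not part of the commit) of how the reworked `train!` might be driven in explicit mode: the loss now takes the model as its first argument, `optstate` comes from `Optimisers.setup`, and both the updated state and model are returned. Names such as `model`, `data`, and `loss` are illustrative.

```julia
using Flux, Optimisers
using Flux.Optimise: ZygoteExplicitBackend

model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
data  = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:10]

# explicit-mode loss: the model is the first argument, the batch is splatted after it
loss(m, x, y) = Flux.Losses.mse(m(x), y)

# optimiser state from Optimisers.jl, matched to the model's structure
optstate = Optimisers.setup(Optimisers.Adam(1f-3), model)

# the new train! threads state and model through each batch and returns both
optstate, model = Flux.Optimise.train!(loss, ZygoteExplicitBackend(), model, data, optstate)
```

The four-argument method added at the end of the diff keeps the implicit path as the default, so an existing-style call `train!(loss, ps, data, opt)` with a `Params` collection and an `AbstractOptimiser` still dispatches through `ZygoteImplicitBackend` and the `Optimisers.update` method for `Params` defined above.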
