Commit faf572f: add train_autodiff macro

1 parent 7aa74e0

4 files changed, +100 -4 lines changed

Project.toml

Lines changed: 3 additions & 1 deletion

@@ -39,6 +39,7 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
+Yota = "0.7.4"
 Zygote = "0.6.34"
 julia = "1.6"

@@ -49,6 +50,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Yota = "cd998857-8626-517d-b929-70ad188a48f0"

 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Yota"]

src/train/Train.jl

Lines changed: 47 additions & 1 deletion

@@ -4,7 +4,7 @@ using LinearAlgebra
 using Optimisers: Optimisers
 using Functors: fmap

-export train!, update!, adjust!, FluxState,
+export train!, update!, adjust!, FluxState, @train_autodiff,
   Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
   # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,

@@ -108,6 +108,52 @@ include("implicit_train.jl") # Params etc, Zygote only

 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor

+"""
+    Flux.@train_autodiff Zygote
+    Flux.@train_autodiff Yota
+    Flux.@train_autodiff Diffractor
+
+This macro allows the use of `train!` with various automatic differentiation packages,
+instead of the default Zygote.jl.
+You should load the package, then call this macro.
+
+Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
+since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
+
+Only works with [Yota.jl](https://github.com/dfdx/Yota.jl) and [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl),
+and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl).
+
+!!! note
+    This is experimental!
+"""
+macro train_autodiff(pkg)
+  if pkg == :Diffractor
+    return quote
+      Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if not loaded
+      function Flux.Train.explicit_withgradient(f, args...)
+        y, back = Diffractor.∂⃖¹(f, args...)
+        dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
+        return (; value = y, gradient = Base.tail(back(dy1)))
+      end
+    end |> esc
+  elseif pkg == :Yota
+    return quote
+      Yota.grad(sin, 0.0)  # [2][1] ≈ 1.0
+      function Flux.Train.explicit_withgradient(f, args...)
+        value, (_, gradient...) = Yota.grad(f, args...)
+        return (; value, gradient)
+      end
+    end |> esc
+  elseif pkg == :Zygote
+    return quote
+      Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
+    end |> esc
+  else
+    throw("@train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood.")
+  end
+end
+
 ### Misc. related utilities

 """

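Putting the docstring's instructions together, usage on this branch would look roughly like the sketch below. The model, data, and optimiser here are placeholders chosen for illustration (they are not part of the commit), and whether Yota or Diffractor can differentiate a given model is a separate question:

```julia
using Flux, Yota                     # load the AD package first, as the docstring requires

Flux.@train_autodiff Yota            # overload Flux.Train.explicit_withgradient to call Yota.grad

# Illustrative model, data and optimiser (assumed, not from the commit):
model = Dense(10 => 1)
data  = [(rand(Float32, 10), rand(Float32, 1)) for _ in 1:100]
opt   = Flux.Adam()                  # on this branch, Adam() returns a FluxState

loss(m, x, y) = Flux.Losses.mse(m(x), y)
Flux.train!(loss, model, data, opt)  # explicit-mode train!; gradients now come from Yota

Flux.@train_autodiff Zygote          # switch back to the default backend
```

For the Yota branch of the macro, note that `Yota.grad(f, args...)` returns the loss value together with a gradient tuple whose first entry corresponds to `f` itself, which is why the macro destructures it as `value, (_, gradient...)`.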
src/train/explicit_train.jl

Lines changed: 6 additions & 0 deletions

@@ -25,6 +25,9 @@ end
 which evaluates the gradient of `loss(model, x1, y1)` with respect to `model`,
 to know how to update the parameters stored within `model`.

+To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
+to use [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
+
 It is often convenient to provide the function `loss` using `do` block syntax,
 instead of defining a named function:
 ```

@@ -100,6 +103,9 @@ This calls `Zygote.withgradient(m -> Flux.crossentropy(m(x1), y1), model)`.
 Then it updates the parameters contained within `model` according
 to the chosen `opt`imiser.
 Finally it returns the value of the loss function.
+
+To change the package used to calculate gradients, enter `using Yota; Flux.@train_autodiff Yota`
+to use [Yota.jl](https://github.com/dfdx/Yota.jl). The same command works with [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl).
 """
 function train!(loss::Function, model, opt::FluxState)
     _initialise!(opt, model)
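The `do`-block form that this docstring mentions would look something like the following sketch; the model, data, and loss body here are illustrative assumptions, not the docstring's own (elided) example:

```julia
using Flux

model = Chain(Dense(2 => 4, relu), Dense(4 => 1))             # illustrative model
data  = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:16]  # illustrative data
opt   = Flux.Adam()

# The do-block is passed as the first (loss) argument of train!;
# each tuple from `data` is splatted after the model, i.e. loss(model, x, y).
Flux.train!(model, data, opt) do m, x, y
  Flux.Losses.mse(m(x), y)
end
```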

test/train.jl

Lines changed: 44 additions & 2 deletions

@@ -1,7 +1,7 @@
 using Flux.Train
 using Zygote: Params, gradient

-import Optimisers, FillArrays, ComponentArrays
+import Optimisers, FillArrays, ComponentArrays, Yota

 using Test
 using Random

@@ -22,7 +22,7 @@ using Random
   end
 end

-@testset "Explicit train!" begin
+@testset "Explicit train! with Zygote" begin
   Random.seed!(84)
   w = randn(10, 10)
   w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.

@@ -70,6 +70,48 @@ end
   end
 end

+using Yota
+using Flux: Descent, Adam, AdamW, FluxState
+Flux.@train_autodiff Yota
+
+@testset "Explicit train! with Yota" begin
+  Random.seed!(84)
+  w = randn(10, 10)
+  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
+  @testset for opt in [Descent(0.1), Adam(), AdamW()]
+    @test opt isa FluxState
+    @test opt.state isa Missing
+
+    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+    @test loss(model, rand(10, 10)) > 1
+
+    train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
+    @test loss(model, rand(10, 10)) < 0.01
+    @test opt.state isa NamedTuple
+  end
+
+  # Test 3-arg `train!` method:
+  @testset for opt in [Descent(0.1), Adam(), AdamW()]
+    @test opt isa FluxState
+    @test opt.state isa Missing
+
+    loss(m) = let x = rand(10)
+      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
+    end
+    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
+    @test loss(model) > 1
+
+    for i in 1:10^5
+      train!(loss, model, opt)
+    end
+    @test loss(model) < 0.01
+    @test opt.state isa NamedTuple
+  end
+end
+
+Flux.@train_autodiff Zygote
+
 #=

 @testset "update!: handle Fills from Zygote" begin
