Commit 5443900

test with Yota too, and document this
1 parent 7d0c939 commit 5443900

File tree: 4 files changed, +56 −14 lines

Project.toml

Lines changed: 3 additions & 1 deletion

@@ -13,13 +13,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 ChainRulesCore = "1"
 Functors = "0.3"
+Yota = "0.7.3"
 Zygote = "0.6.40"
 julia = "1.6"

 [extras]
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "StaticArrays", "Zygote"]
+test = ["Test", "StaticArrays", "Yota", "Zygote"]
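Although not part of this commit, the effect of the `[extras]`/`[targets]` entries above is that Yota is resolved only for the test environment, never for ordinary users of the package. A quick sketch of how that test environment gets exercised, assuming Optimisers.jl is available in the active environment:

```julia
# Sketch, not from the commit: the test-only deps listed in [extras]/[targets]
# (Test, StaticArrays, Yota, Zygote) are installed just for this call.
using Pkg
Pkg.test("Optimisers")   # runs test/runtests.jl in a temporary test environment
```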

docs/src/index.md

Lines changed: 37 additions & 12 deletions

@@ -38,7 +38,7 @@ to adjust the model:

 ```julia

-using Flux, Metalhead, Optimisers
+using Flux, Metalhead, Zygote, Optimisers

 model = Metalhead.ResNet(18) |> gpu  # define a model to train
 image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@@ -72,14 +72,29 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.

+## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
+
+Yota is another modern automatic differentiation package, an alternative to Zygote.
+
+Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
+but also returns a gradient component for the loss function.
+To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)`
+or, for the Flux model above:
+
+```julia
+loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
+    sum(m(x))
+end;
+```
+
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

 The main design difference of Lux is that the tree of parameters is separate from
 the layer structure. It is these parameters which `setup` and `update` need to know about.

 Lux describes this separation of parameter storage from model description as "explicit" parameters.
 Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
-(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
+(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly
 identical trees of nested `NamedTuple`s.)

 ```julia
@@ -88,27 +103,37 @@ using Lux, Boltz, Zygote, Optimisers

 lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
 images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
-y, _ = Lux.apply(lux_model, images, params, lux_state);  # run the model
+y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model
 @show sum(y)  # initial dummy loss

 rule = Optimisers.Adam()
 opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters

-∇params, _ = gradient(params, images) do p, x  # gradient with respect to parameter tree
-    y, _ = Lux.apply(lux_model, x, p, lux_state)
-    sum(y)
-end;
+(loss, lux_state), back = Zygote.pullback(params, images) do p, x
+    y, st = Lux.apply(lux_model, x, p, lux_state)
+    sum(y), st  # return both the loss, and the updated lux_state
+end
+∇params, _ = back((one.(loss), nothing))  # gradient of only the loss, with respect to parameter tree

-opt_state, params = Optimisers.update!(opt_state, params, ∇params);
+@show sum(loss)

-y, _ = Lux.apply(lux_model, images, params, lux_state);
-@show sum(y)
+opt_state, params = Optimisers.update!(opt_state, params, ∇params);

 ```

 Besides the parameters stored in `params` and gradually optimised, any other model state
-is stored in `lux_state`. For simplicity this example does not show how to propagate the
-updated `lux_state` to the next iteration, see Lux's documentation.
+is stored in `lux_state`, and returned by `Lux.apply`.
+This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
+If you are certain there is no model state, then the gradient calculation can
+be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
+
+```julia
+∇params, _ = gradient(params, images) do p, x
+    y, _ = Lux.apply(lux_model, x, p, lux_state)  # discards new lux_state
+    sum(y)
+end;
+```
+

 ## Non-`trainable` Parameters
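The new documentation section stops at the gradient call. As a sketch only (not text from the diff), reusing `model` and `image` from the Flux example above, the `∇model` returned by `Yota.grad` then feeds into Optimisers.jl exactly as a Zygote gradient would:

```julia
# Sketch, not part of the committed docs: one complete update step using Yota's gradient.
using Optimisers, Yota

state = Optimisers.setup(Optimisers.Adam(), model)      # optimiser state mirroring the model tree
loss, (_, ∇model, _) = Yota.grad(model, image) do m, x  # discard gradients w.r.t. the closure and the data
    sum(m(x))
end
state, model = Optimisers.update!(state, model, ∇model) # apply one Adam step to the model tree
```

The seed `(one.(loss), nothing)` passed to `back` in the Lux hunk above plays a related role: the pullback's output is the tuple `(loss, state)`, so seeding with `one.(loss)` and `nothing` requests the gradient of the scalar loss only, without differentiating through the returned `lux_state`.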

test/rules.jl

Lines changed: 15 additions & 0 deletions

@@ -229,3 +229,18 @@ end
     @test static_loss(static_model) < 1.9
   end
 end
+
+@testset "using Yota" begin
+  @testset "$(name(o))" for o in RULES
+    w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
+    w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
+    st = Optimisers.setup(o, w)
+    loss(x, y) = mean((x.abc.α .* x.abc.β .- y.abc.α .* y.abc.β) .^ 2)  # does not use γ, δ, ε
+    @test loss(w, w′) > 0.5
+    for i = 1:10^4
+      _, (_, g, _) = Yota.grad(loss, w, w′)
+      st, w = Optimisers.update(st, w, g)
+    end
+    @test loss(w, w′) < 0.001
+  end
+end
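For reference, the destructuring `_, (_, g, _) = Yota.grad(loss, w, w′)` in the new testset relies on `Yota.grad` returning the value plus one gradient per argument, with the first slot belonging to the differentiated function itself. A minimal standalone sketch of the same pattern, with hypothetical values not taken from the diff:

```julia
# Sketch, not from the commit: Yota.grad's return convention with nested NamedTuples.
using Optimisers, Statistics, Yota

w  = (abc = (α = 5rand(3, 3), β = rand(3, 3)),)
w′ = (abc = (α = rand(3, 3), β = rand(3, 3)),)
loss(x, y) = mean((x.abc.α .* x.abc.β .- y.abc.α .* y.abc.β) .^ 2)

val, grads = Yota.grad(loss, w, w′)         # grads == (∇function, ∇w, ∇w′), one per argument of loss
st = Optimisers.setup(Optimisers.Descent(0.1), w)
st, w = Optimisers.update(st, w, grads[2])  # grads[2] is the gradient for w, the trainable tree
```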

test/runtests.jl

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote
+using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
 using LinearAlgebra, Statistics, Test, Random
 using Optimisers: @.., @lazy

