
Commit 79269be

Test with Yota, too (#105)
* test with Yota too, and document this
* also test destructure
* actually try out the doc examples
* tidy, add summarysize
* add again changes made on website which got lost in a local rebase without checking first because I forgot about this for ages
* Yota 0.8.2, etc
* skip Yota tests on 1.9 & later
* skip more tests
1 parent acc9b16 commit 79269be

File tree

- Project.toml
- docs/src/index.md
- test/destructure.jl
- test/rules.jl
- test/runtests.jl

5 files changed: +142 −17 lines

Project.toml

Lines changed: 3 additions & 1 deletion

```diff
@@ -13,13 +13,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 [compat]
 ChainRulesCore = "1"
 Functors = "0.3, 0.4"
+Yota = "0.8.2"
 Zygote = "0.6.40"
 julia = "1.6"
 
 [extras]
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "StaticArrays", "Zygote"]
+test = ["Test", "StaticArrays", "Yota", "Zygote"]
```

docs/src/index.md

Lines changed: 61 additions & 15 deletions

````diff
@@ -38,7 +38,7 @@ to adjust the model:
 
 ```julia
 
-using Flux, Metalhead, Optimisers
+using Flux, Metalhead, Zygote, Optimisers
 
 model = Metalhead.ResNet(18) |> gpu  # define a model to train
 image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model);  # initialise this optimiser's momentum e
 end;
 
 state, model = Optimisers.update(state, model, ∇model);
-@show sum(model(image));
+@show sum(model(image));  # reduced
 
 ```
 
@@ -62,8 +62,14 @@ tree formed by the model and update the parameters using the gradients.
 
 There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
 but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` for each rule is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
+(The method of `apply!` above is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.)
+For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
+
+```julia
+Base.summarysize(model) / 1024^2  # about 45MB
+Base.summarysize(state) / 1024^2  # about 90MB
+```
 
 Optimisers.jl does not depend on any one automatic differentiation package,
 but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
@@ -72,14 +78,34 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
+
+## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
+
+Yota is another modern automatic differentiation package, an alternative to Zygote.
+
+Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
+but also returns a gradient component for the loss function.
+To extract what Optimisers.jl needs, you can write (for the Flux model above):
+
+```julia
+using Yota
+
+loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
+    sum(m(x))
+end;
+
+# Or else, this may save computing ∇image:
+loss, (_, ∇model) = grad(m -> sum(m(image)), model);
+```
+
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
-The main design difference of Lux is that the tree of parameters is separate from
+The main design difference of Lux from Flux is that the tree of parameters is separate from
 the layer structure. It is these parameters which `setup` and `update` need to know about.
 
 Lux describes this separation of parameter storage from model description as "explicit" parameters.
 Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
-(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
+(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly
 identical trees of nested `NamedTuple`s.)
 
 ```julia
@@ -88,27 +114,47 @@ using Lux, Boltz, Zygote, Optimisers
 
 lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
 images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
-y, _ = Lux.apply(lux_model, images, params, lux_state);  # run the model
-@show sum(y)  # initial dummy loss
+y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model
+@show sum(y);  # initial dummy loss
 
 rule = Optimisers.Adam()
 opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters
 
-∇params, _ = gradient(params, images) do p, x  # gradient with respect to parameter tree
-    y, _ = Lux.apply(lux_model, x, p, lux_state)
-    sum(y)
+(loss, lux_state), back = Zygote.pullback(params, images) do p, x
+    y, st = Lux.apply(lux_model, x, p, lux_state)
+    sum(y), st  # return both the loss, and the updated lux_state
 end;
+∇params, _ = back((one.(loss), nothing));  # gradient of only the loss, with respect to parameter tree
+loss == sum(y)  # not yet changed
 
 opt_state, params = Optimisers.update!(opt_state, params, ∇params);
 
-y, _ = Lux.apply(lux_model, images, params, lux_state);
-@show sum(y)
+y, lux_state = Lux.apply(lux_model, images, params, lux_state);
+@show sum(y);  # now reduced
 
 ```
 
 Besides the parameters stored in `params` and gradually optimised, any other model state
-is stored in `lux_state`. For simplicity this example does not show how to propagate the
-updated `lux_state` to the next iteration, see Lux's documentation.
+is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
+This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
+
+```julia
+Base.summarysize(lux_model) / 1024   # just 2KB
+Base.summarysize(params) / 1024^2    # about 45MB, same as Flux model
+Base.summarysize(lux_state) / 1024   # 40KB
+Base.summarysize(opt_state) / 1024^2 # about 90MB, with Adam
+```
+
+If you are certain there is no model state, then the gradient calculation can
+be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
+
+```julia
+∇params, _ = gradient(params, images) do p, x
+    y, _ = Lux.apply(lux_model, x, p, lux_state)  # discards new lux_state
+    sum(y)
+end;
+```
+
 
 ## Non-`trainable` Parameters
````
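For context, the Yota section added above stops at extracting `∇model`; the remaining step is the same `Optimisers.setup`/`update` call shown earlier for Zygote. A minimal sketch of the combined loop, assuming a small `Flux.Dense` layer and random data rather than the ResNet used in the docs:

```julia
# Sketch only: combines the documented pieces (Yota.grad + Optimisers.update).
# The Dense(3 => 2) layer and random batch are assumptions for a quick CPU run,
# not part of the diff above.
using Flux, Yota, Optimisers

model = Flux.Dense(3 => 2)
x = rand(Float32, 3, 8)

state = Optimisers.setup(Optimisers.Adam(), model)

loss, (_, ∇model, _) = Yota.grad(model, x) do m, xb
    sum(abs2, m(xb))    # any scalar loss
end
state, model = Optimisers.update(state, model, ∇model)
```

Apart from how the gradient tuple is destructured, the update step is identical to the Zygote workflow shown earlier on this page.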

test/destructure.jl

Lines changed: 55 additions & 0 deletions

```diff
@@ -104,6 +104,31 @@ end
   # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
   # Diffractor error in perform_optic_transform
 end
+
+VERSION < v"1.9-" && @testset "using Yota" begin
+  @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
+  @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
+  @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
+  @test Yota_gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
+  @test Yota_gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])
+
+  g5 = Yota_gradient(m -> destructure(m)[1][3], m5)[1]
+  @test g5.a[1].x == [0,0,1]
+  @test g5.a[2] === nothing
+
+  g6 = Yota_gradient(m -> imag(destructure(m)[1][4]), m6)[1]
+  @test g6.a == [0,0,0]
+  @test g6.a isa Vector{Float64}
+  @test g6.b == [0+im]
+
+  g8 = Yota_gradient(m -> sum(abs2, destructure(m)[1]), m8)[1]
+  @test g8[1].x == [2,4,6]
+  @test g8[2].b.x == [8]
+  @test g8[3] == [[10.0]]
+
+  g9 = Yota_gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
+  @test g9.c === nothing
+end
 end
 
 @testset "gradient of rebuild" begin
@@ -149,6 +174,36 @@ end
   # Not fixed by this:
   # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
 end
+
+VERSION < v"1.9-" && @testset "using Yota" begin
+  re1 = destructure(m1)[2]
+  @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
+  re2 = destructure(m2)[2]
+  @test Yota_gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]
+  re3 = destructure(m3)[2]
+  @test Yota_gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
+  @test Yota_gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]
+
+  re4 = destructure(m4)[2]
+  @test Yota_gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
+  @test Yota_gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
+  @test Yota_gradient(rand(6)) do x
+    m = re4(x)
+    m.x[1] + 2*m.y[2] + 3*m.z[3]
+  end[1] == [1,2,0, 0,0,3]
+
+  re7 = destructure(m7)[2]
+  @test Yota_gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1]
+  @test Yota_gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
+  @test Yota_gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]
+
+  v8, re8 = destructure(m8)
+  @test Yota_gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
+  @test Yota_gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
+
+  re9 = destructure(m9)[2]
+  @test Yota_gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
+end
 end
 
 @testset "Flux issue 1826" begin
```
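Both new testsets are guarded by `VERSION < v"1.9-"`, matching the commit note "skip Yota tests on 1.9 & later". The trailing `-` in the version literal makes the bound exclude 1.9 prereleases as well; a small illustration of how these comparisons behave:

```julia
# `v"1.9-"` is the lowest possible 1.9 prerelease, so the gate excludes
# every 1.9 build, including alphas, betas and release candidates.
@assert v"1.8.5"       < v"1.9-"   # gate is true: Yota tests run
@assert v"1.9.0-beta3" >= v"1.9-"  # gate is false: tests skipped
@assert v"1.9.0"       >= v"1.9-"  # gate is false: tests skipped
```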

test/rules.jl

Lines changed: 15 additions & 0 deletions

```diff
@@ -229,3 +229,18 @@ end
     @test static_loss(static_model) < 1.9
   end
 end
+
+VERSION < v"1.9-" && @testset "using Yota" begin
+  @testset "$(name(o))" for o in RULES
+    w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
+    w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
+    st = Optimisers.setup(o, w)
+    loss(x, y) = mean((x.abc.α .* x.abc.β .- y.abc.α .* y.abc.β) .^ 2)  # does not use γ, δ, ε
+    @test loss(w, w′) > 0.5
+    for i = 1:10^4
+      _, (_, g, _) = Yota.grad(loss, w, w′)
+      st, w = Optimisers.update(st, w, g)
+    end
+    @test loss(w, w′) < 0.001
+  end
+end
```

test/runtests.jl

Lines changed: 8 additions & 1 deletion

```diff
@@ -1,5 +1,5 @@
 using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote
+using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
 using LinearAlgebra, Statistics, Test, Random
 using Optimisers: @.., @lazy
 
@@ -37,6 +37,13 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
   return state, dx
 end
 
+# Make Yota's output look like Zygote's:
+
+Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
+y2z(::AbstractZero) = nothing  # we don't care about different flavours of zero
+y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))  # namedtuples!
+y2z(x) = x
+
 @testset verbose=true "Optimisers.jl" begin
   @testset verbose=true "Features" begin
```
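To make the role of the new `Yota_gradient` helper concrete, here is a self-contained sketch: it drops the gradient with respect to the function itself (`Base.tail`) and maps ChainRulesCore's zero and `Tangent` types onto the `nothing`/`NamedTuple` form that the Zygote-based tests compare against. The toy call at the end is an illustrative assumption, not taken from the test suite:

```julia
using Yota, ChainRulesCore

# Same helper as in the diff above, repeated so this sketch is self-contained:
Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
y2z(::AbstractZero) = nothing                                        # collapse all zero flavours
y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t)))  # Tangent -> NamedTuple
y2z(x) = x

# Toy example: the result looks like Zygote.gradient's output, a tuple of gradients.
Yota_gradient(x -> sum(abs2, x), [1.0, 2.0, 3.0])  # ([2.0, 4.0, 6.0],)
```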
