Commit a6e85fc

actually try out the doc examples
1 parent 4244b83 commit a6e85fc


docs/src/index.md

Lines changed: 47 additions & 8 deletions
@@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model); # initialise this optimiser's momentum e
 end;

 state, model = Optimisers.update(state, model, ∇model);
-@show sum(model(image));
+@show sum(model(image)); # reduced

 ```

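For orientation, the hunk above is the tail of the Flux example in these docs (the `@@` context shows its `Optimisers.setup` call). Below is a minimal, self-contained sketch of that whole setup/gradient/update loop; the tiny `Dense` chain, the random `image` batch, and the name `toy_model` are placeholders invented here, not the model from the docs.

```julia
using Flux, Zygote, Optimisers

toy_model = Flux.Chain(Flux.Dense(4 => 4, relu), Flux.Dense(4 => 1))  # stand-in for the docs' conv model
image = rand(Float32, 4, 8)                                           # stand-in input batch

rule  = Optimisers.Adam()
state = Optimisers.setup(rule, toy_model)        # initialise this optimiser's momentum etc.

∇model = Zygote.gradient(m -> sum(m(image)), toy_model)[1]  # gradient with respect to the model

state, toy_model = Optimisers.update(state, toy_model, ∇model)
@show sum(toy_model(image));                     # should now be reduced, as in the hunk above
```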
@@ -82,14 +82,51 @@ To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad
 or, for the Flux model above:

 ```julia
+using Yota
+
 loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
     sum(m(x))
 end;
 ```

+Unfortunately this example doesn't actually run right now. This is the error:
+```
+julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
+           sum(m(x))
+       end;
+┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
+│   (f, args) = Yota.RRULE_VIA_AD_STATE[]
+└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
+ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4} , try defining it using
+
+ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
+
+Stacktrace:
+ [1] error(s::String)
+   @ Base ./error.jl:35
+ [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
+   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
+ [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
+   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
+ [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
+   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
+ [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
+   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
+ [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
+   @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
+ [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
+   @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
+...
+
+(jl_GWa2lX) pkg> st
+Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
+⌃ [587475ba] Flux v0.13.4
+  [cd998857] Yota v0.7.4
+```
+
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

-The main design difference of Lux is that the tree of parameters is separate from
+The main design difference of Lux from Flux is that the tree of parameters is separate from
 the layer structure. It is these parameters which `setup` and `update` need to know about.

 Lux describes this separation of parameter storage from model description as "explicit" parameters.
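To make the "explicit parameters" idea concrete, here is a small illustrative sketch (not part of this patch): a hand-built Lux model yields its parameter tree and its layer state separately via `Lux.setup`, and it is only that parameter tree which `Optimisers.setup` tracks. The toy chain and `rng` below are placeholders, not the ResNet used in the next hunk.

```julia
using Lux, Random, Optimisers

rng = Random.default_rng()
toy_lux = Lux.Chain(Lux.Dense(3 => 4, tanh), Lux.Dense(4 => 1))

params, lux_state = Lux.setup(rng, toy_lux)      # parameter tree and layer state, kept separate

opt_state = Optimisers.setup(Optimisers.Adam(), params)  # optimiser state mirrors only the parameters

x = rand(rng, Float32, 3, 8)
y, lux_state = Lux.apply(toy_lux, x, params, lux_state)  # apply takes params and state explicitly
```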
@@ -104,25 +141,27 @@ using Lux, Boltz, Zygote, Optimisers
 lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu; # define and initialise model
 images = rand(Float32, 224, 224, 3, 4) |> gpu; # batch of dummy data
 y, lux_state = Lux.apply(lux_model, images, params, lux_state); # run the model
-@show sum(y) # initial dummy loss
+@show sum(y); # initial dummy loss

 rule = Optimisers.Adam()
 opt_state = Optimisers.setup(rule, params); # optimiser state based on model parameters

 (loss, lux_state), back = Zygote.pullback(params, images) do p, x
     y, st = Lux.apply(lux_model, x, p, lux_state)
     sum(y), st # return both the loss, and the updated lux_state
-end
-∇params, _ = back((one.(loss), nothing)) # gradient of only the loss, with respect to parameter tree
-
-@show sum(loss)
+end;
+∇params, _ = back((one.(loss), nothing)); # gradient of only the loss, with respect to parameter tree
+loss == sum(y) # not yet changed

 opt_state, params = Optimisers.update!(opt_state, params, ∇params);

+y, lux_state = Lux.apply(lux_model, images, params, lux_state);
+@show sum(y); # now reduced
+
 ```

 Besides the parameters stored in `params` and gradually optimised, any other model state
-is stored in `lux_state`, and returned by `Lux.apply`.
+is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
 This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
 If you are certain there is no model state, then the gradient calculation can
 be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
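The final context line ends with a colon because the docs go on to show that simplification; the patch does not touch it. As a rough, untested sketch of its shape, continuing with the names from the hunk above and discarding the (unchanged) Lux state:

```julia
# Hedged sketch, not from the patch: with no model state to carry along,
# a plain gradient call can replace the pullback above.
∇params = Zygote.gradient(p -> sum(first(Lux.apply(lux_model, images, p, lux_state))), params)[1];

opt_state, params = Optimisers.update!(opt_state, params, ∇params);
```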
