@@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model); # initialise this optimiser's momentum e
end;

state, model = Optimisers.update(state, model, ∇model);
- @show sum(model(image));
+ @show sum(model(image));  # reduced

```

@@ -82,14 +82,51 @@ To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad
or, for the Flux model above:

```julia
+ using Yota
+
loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
    sum(m(x))
end;
```

+ Unfortunately this example doesn't actually run right now. This is the error:
+ ```
+ julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
+            sum(m(x))
+        end;
+ ┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
+ │   (f, args) = Yota.RRULE_VIA_AD_STATE[]
+ └ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
+ ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4}, try defining it using
+
+   ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
+
+ Stacktrace:
+  [1] error(s::String)
+    @ Base ./error.jl:35
+  [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
+    @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
+  [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
+    @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
+  [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
+    @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
+  [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
+    @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
+  [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
+    @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
+  [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
+    @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
+  ...
+
+ (jl_GWa2lX) pkg> st
+ Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
+ ⌃ [587475ba] Flux v0.13.4
+   [cd998857] Yota v0.7.4
+ ```
+
## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

- The main design difference of Lux is that the tree of parameters is separate from
+ The main design difference of Lux from Flux is that the tree of parameters is separate from
the layer structure. It is these parameters which `setup` and `update` need to know about.

Lux describes this separation of parameter storage from model description as "explicit" parameters.
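A minimal sketch of that split, using a hypothetical toy `Dense` layer rather than the ResNet example below: `Lux.setup` returns the parameter tree and the state tree separately, and only the parameter tree is ever handed to Optimisers.jl.

```julia
using Lux, Random, Optimisers

toy_model = Lux.Dense(3, 2)                           # the layer object itself holds no trainable numbers
ps, st = Lux.setup(Random.default_rng(), toy_model)   # "explicit" parameters and state, as separate trees
opt_state = Optimisers.setup(Optimisers.Adam(), ps)   # Optimisers.jl only ever sees the parameter tree `ps`
```

The ResNet example below follows the same pattern, with `Boltz.resnet` returning the model, parameters, and state as three pieces directly.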
@@ -104,25 +141,27 @@ using Lux, Boltz, Zygote, Optimisers
lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model
- @show sum(y)  # initial dummy loss
+ @show sum(y);  # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters

(loss, lux_state), back = Zygote.pullback(params, images) do p, x
    y, st = Lux.apply(lux_model, x, p, lux_state)
    sum(y), st  # return both the loss, and the updated lux_state
- end
- ∇params, _ = back((one.(loss), nothing))  # gradient of only the loss, with respect to parameter tree
-
- @show sum(loss)
+ end;
+ ∇params, _ = back((one.(loss), nothing));  # gradient of only the loss, with respect to parameter tree
+ loss == sum(y)  # not yet changed

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

+ y, lux_state = Lux.apply(lux_model, images, params, lux_state);
+ @show sum(y);  # now reduced
+
```

Besides the parameters stored in `params` and gradually optimised, any other model state
- is stored in `lux_state`, and returned by `Lux.apply`.
+ is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
If you are certain there is no model state, then the gradient calculation can
be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
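One way that simplification might look, as a sketch reusing the names from the example above (it assumes `lux_state` really is constant, so the state returned by `Lux.apply` can be discarded):

```julia
∇params, _ = Zygote.gradient(params, images) do p, x
    y, _ = Lux.apply(lux_model, x, p, lux_state)  # ignore the returned state
    sum(y)
end;

opt_state, params = Optimisers.update!(opt_state, params, ∇params);
```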