@@ -38,7 +38,7 @@ to adjust the model:

```julia
- using Flux, Metalhead, Optimisers
+ using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu  # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@@ -72,14 +72,29 @@ but is free to mutate arrays within the old one for efficiency.
The method of `apply!` you write is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.

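For illustration, here is a sketch of a momentum-like rule whose `apply!` overwrites its state
array in place. The name `ToyMomentum` is invented for this example (it closely mirrors the
package's own `Momentum`) and is not one of the rules shipped with Optimisers.jl:

```julia
using Optimisers

struct ToyMomentum <: Optimisers.AbstractRule
    eta::Float64
    rho::Float64
end

# the state for each parameter array is one velocity array of the same shape
Optimisers.init(o::ToyMomentum, x::AbstractArray) = zero(x)

function Optimisers.apply!(o::ToyMomentum, state, x, dx)
    @. state = o.rho * state + dx   # mutates the state array in place
    return state, o.eta .* state    # (new state, change to be subtracted from the parameters)
end
```
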
+ ## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
+
+ Yota is another modern automatic differentiation package, an alternative to Zygote.
+
+ Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`),
+ but the gradient it returns also includes a component for the loss function itself.
+ To extract what Optimisers.jl needs, you can write `_, (_, ∇model) = Yota.grad(f, model, data)`
+ or, for the Flux model above:
+
+ ```julia
+ loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
+     sum(m(x))
+ end;
+ ```
+
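The gradient obtained this way can then be passed to `Optimisers.setup` and `update!` in the same
way as a Zygote gradient. A rough sketch, re-using the `model` and `∇model` from the block above:

```julia
opt_state = Optimisers.setup(Optimisers.Adam(), model)  # optimiser state for the Flux model
opt_state, model = Optimisers.update!(opt_state, model, ∇model);  # apply the Yota gradient
```
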
## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux is that the tree of parameters is separate from
the layer structure. It is these parameters which `setup` and `update` need to know about.

Lux describes this separation of parameter storage from model description as "explicit" parameters.
Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
- (If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
+ (If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly
identical trees of nested `NamedTuple`s.)

```julia
@@ -88,27 +103,37 @@ using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
- y, _ = Lux.apply(lux_model, images, params, lux_state);  # run the model
+ y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model

@show sum(y)  # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters

- ∇params, _ = gradient(params, images) do p, x  # gradient with respect to parameter tree
-     y, _ = Lux.apply(lux_model, x, p, lux_state)
-     sum(y)
- end;
+ (loss, lux_state), back = Zygote.pullback(params, images) do p, x
+     y, st = Lux.apply(lux_model, x, p, lux_state)
+     sum(y), st  # return both the loss, and the updated lux_state
+ end
+ ∇params, _ = back((one.(loss), nothing))  # gradient of only the loss, with respect to parameter tree

- opt_state, params = Optimisers.update!(opt_state, params, ∇params);
+ @show sum(loss)

- y, _ = Lux.apply(lux_model, images, params, lux_state);
- @show sum(y)
+ opt_state, params = Optimisers.update!(opt_state, params, ∇params);

```

Besides the parameters stored in `params` and gradually optimised, any other model state
- is stored in `lux_state`. For simplicity this example does not show how to propagate the
- updated `lux_state` to the next iteration, see Lux's documentation.
+ is stored in `lux_state`, and returned by `Lux.apply`.
+ This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
+ If you are certain there is no model state, then the gradient calculation can
+ be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
+
+ ```julia
+ ∇params, _ = gradient(params, images) do p, x
+     y, _ = Lux.apply(lux_model, x, p, lux_state)  # discards new lux_state
+     sum(y)
+ end;
+ ```
+

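To carry all three pieces of state forward from one iteration to the next, it can be convenient to
wrap the steps above in a function. This is only a sketch of one possible arrangement; the name
`train_step` and the use of `sum` as a stand-in loss are inventions of this example, not part of
Lux or Optimisers.jl:

```julia
function train_step(lux_model, params, lux_state, opt_state, xs)
    (loss, new_lux_state), back = Zygote.pullback(params, xs) do p, x
        y, st = Lux.apply(lux_model, x, p, lux_state)
        sum(y), st
    end
    ∇params, _ = back((one.(loss), nothing))
    opt_state, params = Optimisers.update!(opt_state, params, ∇params)
    return params, new_lux_state, opt_state, loss
end

# each call consumes and returns the parameters, Lux's model state, and Optimisers.jl's state;
# repeat it, feeding the returned values back in:
params, lux_state, opt_state, loss = train_step(lux_model, params, lux_state, opt_state, images)
```
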
## Obtaining a flat parameter vector