@@ -38,7 +38,7 @@ to adjust the model:

```julia
- using Flux, Metalhead, Optimisers
+ using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu  # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@@ -52,7 +52,7 @@ state = Optimisers.setup(rule, model); # initialise this optimiser's momentum e
end;

state, model = Optimisers.update(state, model, ∇model);
- @show sum(model(image));
+ @show sum(model(image));  # reduced

```
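In a real training run, the gradient call and `Optimisers.update` are simply repeated over the data. A minimal sketch, keeping the dummy `sum` loss above; the `train` helper and the `data` iterator of input batches are illustrative, not part of Optimisers.jl:

```julia
function train(model, state, data)
    for x in data
        # gradient with respect to the model (the gradient for x is discarded)
        ∇model, _ = Zygote.gradient((m, x) -> sum(m(x)), model, x)
        state, model = Optimisers.update(state, model, ∇model)
    end
    return state, model  # both must be kept for the next epoch
end

state, model = train(model, state, [image, image]);  # e.g. two dummy "batches"
```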
@@ -62,8 +62,14 @@ tree formed by the model and update the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
- The method of `apply!` for each rule is likewise free to mutate arrays within its state;
- they are defensively copied when this rule is used with `update`.
+ (The method of `apply!` above is likewise free to mutate arrays within its state;
+ they are defensively copied when this rule is used with `update`.)
+ For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
+
+ ```julia
+ Base.summarysize(model) / 1024^2  # about 45MB
+ Base.summarysize(state) / 1024^2  # about 90MB
+ ```
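Because `update!` takes the same arguments and likewise returns the new `(state, model)` pair, switching to the in-place variant is a one-line change; a sketch reusing `state`, `model`, and `∇model` from the example above:

```julia
state, model = Optimisers.update!(state, model, ∇model);  # may reuse storage from the old state & model
```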
Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
@@ -72,14 +78,34 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
+
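For a concrete contrast, here is a sketch of the two styles for the Flux `model` and `image` above; only the explicit form yields the tree structure that `Optimisers.update` expects:

```julia
# "Explicit" style: differentiate with respect to the model itself;
# the result is nested NamedTuples mirroring the model's structure.
∇model, _ = Zygote.gradient((m, x) -> sum(m(x)), model, image);

# "Implicit" style: a dictionary-like Grads object keyed by parameter arrays,
# which Optimisers.jl does not consume directly.
grads = Zygote.gradient(() -> sum(model(image)), Flux.params(model));
```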
+ ## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
+
+ Yota is another modern automatic differentiation package, an alternative to Zygote.
+
+ Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
+ but also returns a gradient component for the loss function itself.
+ To extract what Optimisers.jl needs, you can write (for the Flux model above):
+
+ ```julia
+ using Yota
+
+ loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
+     sum(m(x))
+ end;
+
+ # Or else, this may save computing ∇image:
+ loss, (_, ∇model) = grad(m -> sum(m(image)), model);
+ ```
+
## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

- The main design difference of Lux is that the tree of parameters is separate from
+ The main design difference of Lux from Flux is that the tree of parameters is separate from
the layer structure. It is these parameters which `setup` and `update` need to know about.

Lux describes this separation of parameter storage from model description as "explicit" parameters.
Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
- (If the same model is written in Flux and Lux, `∇model` above and `∇params` below will often be
+ (If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly
identical trees of nested `NamedTuple`s.)
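To see this separation on a small scale, here is a sketch of a tiny hand-written Lux model (the `mini_model` name and layer sizes are illustrative): the layer object itself holds no arrays, while `Lux.setup` returns the parameter tree and the state separately.

```julia
using Lux, Random

mini_model = Lux.Chain(Lux.Dense(3 => 8, tanh), Lux.Dense(8 => 2))  # structure only, no arrays inside
params, lux_state = Lux.setup(Random.default_rng(), mini_model)     # parameters & state live outside the model
```

The example below instead gets all three pieces at once from `Boltz.resnet`: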
```julia
@@ -88,27 +114,47 @@ using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
- y, _ = Lux.apply(lux_model, images, params, lux_state);  # run the model
- @show sum(y)  # initial dummy loss
+ y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model
+ @show sum(y);  # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters

- ∇params, _ = gradient(params, images) do p, x  # gradient with respect to parameter tree
-     y, _ = Lux.apply(lux_model, x, p, lux_state)
-     sum(y)
+ (loss, lux_state), back = Zygote.pullback(params, images) do p, x
+     y, st = Lux.apply(lux_model, x, p, lux_state)
+     sum(y), st  # return both the loss, and the updated lux_state
end;
+ ∇params, _ = back((one.(loss), nothing));  # gradient of only the loss, with respect to parameter tree
+ loss == sum(y)  # not yet changed

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

- y, _ = Lux.apply(lux_model, images, params, lux_state);
- @show sum(y)
+ y, lux_state = Lux.apply(lux_model, images, params, lux_state);
+ @show sum(y);  # now reduced

```

Besides the parameters stored in `params` and gradually optimised, any other model state
- is stored in `lux_state`. For simplicity this example does not show how to propagate the
- updated `lux_state` to the next iteration, see Lux's documentation.
+ is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
+ This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
+
+ ```julia
+ Base.summarysize(lux_model) / 1024    # just 2KB
+ Base.summarysize(params) / 1024^2     # about 45MB, same as Flux model
+ Base.summarysize(lux_state) / 1024    # 40KB
+ Base.summarysize(opt_state) / 1024^2  # about 90MB, with Adam
+ ```
+
+ If you are certain there is no model state, then the gradient calculation can
+ be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
+
+ ```julia
+ ∇params, _ = gradient(params, images) do p, x
+     y, _ = Lux.apply(lux_model, x, p, lux_state)  # discards new lux_state
+     sum(y)
+ end;
+ ```
+
## Non-`trainable` Parameters