Commit 4764879: add draft of gradient zoo page

File changed: docs/src/tutorials/gradient_zoo.md (+330 −0)

# The Gradient Zoo

The heart of how deep learning works is backpropagation of error,
also known as reverse-mode automatic differentiation.
Given a model, some data, and a loss function, this answers the question
"what direction, in the space of the model's parameters, reduces the loss fastest?"

Julia's ecosystem has many versions of `gradient(f, x)`, which evaluates `y = f(x)` and then returns `∂y_∂x`. The details of how they do this vary, but the interface is similar. An incomplete list (roughly alphabetical):

```julia
julia> Diffractor.gradient(x -> sum(sqrt, x), [1 4 16.])
([0.5 0.25 0.125],)

julia> Enzyme.gradient(Reverse, x -> sum(sqrt, x), [1 4 16.])
1×3 Matrix{Float64}:
 0.5  0.25  0.125

julia> ForwardDiff.gradient(x -> sum(sqrt, x), [1 4 16.])
1×3 Matrix{Float64}:
 0.5  0.25  0.125

julia> ReverseDiff.gradient(x -> sum(sqrt, x), [1 4 16.])
1×3 Matrix{Float64}:
 0.5  0.25  0.125

julia> DifferentiationInterface.gradient(x -> sum(sqrt, x), AutoTapir(), [1 4 16.])
1×3 Matrix{Float64}:
 0.5  0.25  0.125

julia> Tracker.gradient(x -> sum(sqrt, x), [1 4 16.])
([0.5 0.25 0.125] (tracked),)

julia> Yota.grad(x -> sum(sqrt, x), [1 4 16.])
(7.0, (ChainRulesCore.ZeroTangent(), [0.5 0.25 0.125]))

julia> Zygote.withgradient(x -> sum(sqrt, x), [1 4 16.])
(val = 7.0, grad = ([0.5 0.25 0.125],))
```

These all show the same `∂y_∂x` with respect to the array `x`. Sometimes the result is wrapped in a tuple or a NamedTuple.

However, the parameters of a Flux model are encapsulated inside the various layers: the model is a set of nested structures, and the gradients `∂loss_∂model` which Flux uses are similarly nested objects.
For example, let's set up a simple model & loss:

```julia
julia> model = Chain(Embedding(reshape(1:6, 2,3) .+ 0.0), softmax)
Chain(
  Embedding(3 => 2),                    # 6 parameters
  NNlib.softmax,
)

julia> model.layers[1].weight  # this is the wrapped parameter array
2×3 Matrix{Float64}:
 1.0  3.0  5.0
 2.0  4.0  6.0

julia> loss(m) = sum(abs2, m(1))
loss (generic function with 3 methods)

julia> loss(model)  # returns a number
0.6067761335170363
```

Then we can find the same gradient using several packages:

```julia
julia> val, grads_z = Zygote.withgradient(loss, model)
(val = 0.6067761335170363, grad = ((layers = ((weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), nothing),),))

julia> _, grads_t = Tracker.withgradient(loss, model)
(val = 0.6067761335170363, grad = ((layers = ((weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), nothing),),))

julia> grads_d = Diffractor.gradient(loss, model)
(Tangent{Chain{Tuple{Embedding{Matrix{Float64}}, typeof(softmax)}}}(layers = (Tangent{Embedding{Matrix{Float64}}}(weight = [-0.18171549534589682 0.0 0.0; 0.18171549534589682 0.0 0.0],), ChainRulesCore.NoTangent()),),)

julia> grad_e = Enzyme.gradient(Reverse, loss, model)
Chain(
  Embedding(3 => 2),                    # 6 parameters
  NNlib.softmax,
)
```

While the type returned for `∂loss_∂model` varies, they all have the same nested structure, matching that of the model. This is all that Flux needs.

```julia
julia> grads_z[1].layers[1].weight
2×3 Matrix{Float64}:
 -0.181715  0.0  0.0
  0.181715  0.0  0.0

julia> grad_e.layers[1].weight
2×3 Matrix{Float64}:
 -0.181715  0.0  0.0
  0.181715  0.0  0.0
```

Here's Flux updating the model using each gradient:
<!--- perhaps we should trim this?? --->

```julia
julia> opt = Flux.setup(Descent(1/3), model)
(layers = ((weight = Leaf(Descent(0.333333), nothing),), ()),)

julia> Flux.update!(opt, deepcopy(model), grads_t[1])[2][1].weight
2×3 Matrix{Float64}:
 1.06057  3.0  5.0
 1.93943  4.0  6.0

julia> Flux.update!(opt, deepcopy(model), grads_z[1])[2][1].weight
2×3 Matrix{Float64}:
 1.06057  3.0  5.0
 1.93943  4.0  6.0

julia> Flux.update!(opt, deepcopy(model), grads_d[1])[2][1].weight
2×3 Matrix{Float64}:
 1.06057  3.0  5.0
 1.93943  4.0  6.0

julia> Flux.update!(opt, deepcopy(model), grad_e)[2][1].weight
2×3 Matrix{Float64}:
 1.06057  3.0  5.0
 1.93943  4.0  6.0
```

In this case they are all identical, but there are some caveats, explored below.

As an aside, Tapir does not seem to work on this example just yet:

```julia
julia> Tapir_grad(f, xs...) = Tapir.value_and_pullback!!(Tapir.build_rrule(f, xs...), 1.0, f, xs...);

julia> _, grad_p = Tapir_grad(loss, model)
(0.6067761335170363, (NoTangent(), Tangent{@NamedTuple{layers::Tuple{Tangent{@NamedTuple{weight::Matrix{Float64}}}, NoTangent}}}((layers = (Tangent{@NamedTuple{weight::Matrix{Float64}}}((weight = [0.0 0.0 0.0; 0.0 0.0 0.0],)), NoTangent()),))))

julia> grad_p.fields.layers[1].fields.weight
2×3 Matrix{Float64}:
 0.0  0.0  0.0
 0.0  0.0  0.0
```

<!--- I made an issue... perhaps fixed now?? --->

<hr/>

## Packages

Both Zygote and Tracker were written for Flux. At present, Flux loads Zygote, exports `Zygote.gradient`, and calls this within `Flux.train!`; but apart from that, there is very little coupling between Flux and the automatic differentiation package.

This page has very brief notes on how all these packages compare, as a guide for anyone wanting to experiment with them. We stress "experiment" since Zygote is (at present) by far the best-tested.

### [Zygote.jl](https://github.com/FluxML/Zygote.jl)

Source-to-source, within Julia.

* By far the best-tested option for Flux models.

* Long compilation times, on the first call.

* Allows mutation of structs, but not of arrays. This is the most common source of errors: sometimes because you mutate an array deliberately, often because you call some function which, internally, creates the array it wants to return and then fills it in (see the sketches after this list).

* Custom rules via `ZygoteRules.@adjoint` or, better, `ChainRulesCore.rrule`.

* Returns nested NamedTuples and Tuples, and uses `nothing` to mean zero.
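
To illustrate the mutation restriction, here is a minimal sketch (the function names `double_mutating` and `double_broadcast` are made up): the first version fills in an array it allocated, so Zygote cannot differentiate it, while the non-mutating version is fine.

```julia
using Zygote

# Mutating version: allocates an array, then fills it in element by element.
function double_mutating(x)
    y = similar(x)
    for i in eachindex(x)
        y[i] = 2 * x[i]        # setindex! -- Zygote cannot differentiate this
    end
    return sum(y)
end

# Non-mutating version of the same computation.
double_broadcast(x) = sum(2 .* x)

Zygote.gradient(double_broadcast, [1.0, 2.0])   # ([2.0, 2.0],)
# Zygote.gradient(double_mutating, [1.0, 2.0])  # errors: Zygote does not support mutating arrays
```

And a hedged sketch of a custom rule via `ChainRulesCore.rrule`, for a made-up function `mycube`, which Zygote (and other ChainRules-aware packages) will use instead of looking inside the function:

```julia
using ChainRulesCore, Zygote

mycube(x::Real) = x^3

function ChainRulesCore.rrule(::typeof(mycube), x::Real)
    y = x^3
    mycube_pullback(ȳ) = (NoTangent(), 3 * x^2 * ȳ)   # (tangent for the function, ∂y_∂x * ȳ)
    return y, mycube_pullback
end

Zygote.gradient(mycube, 2.0)   # (12.0,)
```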

### Zygote, implicit mode

Flux's default used to work like this, instead of using deeply nested trees for gradients as above:

```julia
julia> ps = Flux.params(model)
Params([Float32[1.0 3.0 5.0; 2.0 4.0 6.0]])

julia> val, grad = Zygote.withgradient(() -> loss(model), ps)
(val = 0.6067761f0, grad = Grads(...))

julia> grad[model.layers[1].weight]  # dictionary, indexed by parameter arrays
2×3 Matrix{Float32}:
 0.0  0.0  -0.181715
 0.0  0.0   0.181715
```

The code inside Zygote is much the same -- do not expect large changes in speed, nor any changes in what works and what does not.
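
For completeness, a hedged sketch of how such implicit gradients were consumed, assuming your Flux version still provides the older `Flux.Optimise.update!(opt, ps, grads)` method for `Params`/`Grads`:

```julia
# Legacy implicit-style update; unlike the deepcopy examples above,
# this mutates the original model's arrays in place.
opt_implicit = Flux.Optimise.Descent(1/3)

Flux.Optimise.update!(opt_implicit, ps, grad)   # ps and grad from the block above

model.layers[1].weight                          # the parameter array has now changed
```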

### [Tracker.jl](https://github.com/FluxML/Tracker.jl)

Uses a `TrackedArray` type to build a tape. The recommended interface `Tracker.withgradient` hides this, and works much like the Zygote one. Notice in particular that this cannot work:

```julia
julia> val = loss(model)  # computed outside gradient context
0.6067761f0

julia> Tracker.withgradient(_ -> val, model)  # this won't work!
(val = 0.6067761f0, grad = (nothing,))
```

It can also be used in lower-level ways which directly expose the tracked types:

```julia
julia> model_tracked = Flux.fmap(x -> x isa Array ? Tracker.param(x) : x, model)
Chain(
  Embedding(3 => 2),                    # 6 parameters
  NNlib.softmax,
)

julia> val_tracked = loss(model_tracked)
0.6067761f0 (tracked)

julia> Tracker.back!(val_tracked)

julia> model_tracked.layers[1].weight.grad
2×3 Matrix{Float32}:
 0.0  0.0  -0.181715
 0.0  0.0   0.181715
```

* Quick to run, on the first call.

* Generally slower than Zygote, allocates more, and supports fewer operations.

* Custom rules via its own `track` and `@grad` (see the sketch below).
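
A hedged sketch of such a rule, for a made-up function `mynorm2`; the pattern follows Tracker's `track`/`@grad` mechanism, and details may vary with the Tracker version:

```julia
using Tracker
using Tracker: TrackedArray, track, @grad, data

mynorm2(x) = sum(abs2, x)                        # plain function

mynorm2(x::TrackedArray) = track(mynorm2, x)     # intercept tracked arrays

@grad function mynorm2(x)
    xd = data(x)                                  # unwrap to a plain Array
    return sum(abs2, xd), Δ -> (2 .* xd .* Δ,)    # forward value, and pullback returning a tuple
end

Tracker.gradient(mynorm2, [1.0, 2.0])   # ([2.0, 4.0] (tracked),)
```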

### [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)

A newer package which works on the LLVM code to which Julia compiles.

* Allows mutation of arrays.

* Long compilation times, on the first call.

* Does not at present work on all Flux models, due to missing rules.

* Does not always handle type instability.

* Custom rules via its own rule system. There are generally fewer such rules than for Zygote, and they are written at a lower level -- applied to `BLAS.gemm!` rather than to `*`.

* Returns another struct of the same type as the model, such as the `Chain` above. Non-differentiable objects are left alone, not replaced by a zero. (A sketch of the lower-level `autodiff` call, which accumulates the gradient into such a struct, follows this list.)
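
A hedged sketch of that lower-level call, assuming a recent Enzyme version which provides `make_zero`:

```julia
using Enzyme

# Zero-initialised "shadow" copy of the model, with the same nested structure.
shadow = Enzyme.make_zero(model)

# Reverse-mode AD: the gradient is accumulated in place into `shadow`.
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, shadow))

shadow.layers[1].weight   # now holds ∂loss_∂weight, like grad_e above
```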

### Tapir.jl

Another new AD to watch. There are many similarities in its approach to Enzyme.jl, but it operates entirely in Julia.

### [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl)

To a first approximation, Diffractor may be thought of as a re-write of Zygote, aiming to reduce compilation times, and to handle higher-order derivatives much better.

At present, development is focused on the forward-mode part. Reverse-mode `gradient` exists, but fails on many Flux models.

* Custom rules via `ChainRulesCore.rrule`.

* Returns nested `Tangent` types, from ChainRulesCore, with zeros indicated by `NoTangent()`.

### [Yota.jl](https://github.com/dfdx/Yota.jl)

Another Julia source-to-source reverse-mode AD.

* Does not work on Julia 1.10.

* Does not handle branches based on runtime values, due to how its tape works.

* Custom rules via `ChainRulesCore.rrule`.

* Returns nested `Tangent` types, from ChainRulesCore.

### [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)

Forward mode is a different algorithm: instead of propagating sensitivities backwards from one scalar loss, it pushes perturbations forwards through the computation, at a cost which grows with the number of parameters.

* Needs the parameters as a flat vector, so a nested Flux model must be flattened first, e.g. with `destructure` (see the sketch after this list).

* Forward mode is generally not what you want for training neural networks, since its cost scales with the number of parameters rather than with the number of outputs.

* `gradient(f, x)` will call `f(x)` multiple times. Layers like `BatchNorm` with state may get confused.
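
A hedged sketch of using it on the model from above, flattening with `Flux.destructure` (which returns the flat parameter vector and a function to rebuild the model):

```julia
using Flux, ForwardDiff

flat, rebuild = Flux.destructure(model)   # flat::Vector, rebuild(flat) reconstructs the Chain

grad_fd = ForwardDiff.gradient(p -> loss(rebuild(p)), flat)  # flat gradient, same length as `flat`
```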

### [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)

* Like Tracker, this passes a special `TrackedArray` type through your function. It allows you to record and compile the tape, and to pre-allocate things.

* Needs the parameters as a flat vector, as for ForwardDiff (see the sketch after this list).

* No support for GPU.
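
A hedged sketch, again flattening with `Flux.destructure`; ReverseDiff also lets you pre-record and compile a `GradientTape`, not shown here:

```julia
using Flux, ReverseDiff

flat, rebuild = Flux.destructure(model)

grad_rd = ReverseDiff.gradient(p -> loss(rebuild(p)), flat)  # flat gradient vector
```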

<hr/>

## Second-order

If you calculate some `gradient(f, x)` inside the loss function, then `f` needs to be differentiated twice for the final result.

### Zygote over Zygote

In principle this works, but in practice it is fragile; best to start small.
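
For instance, a hedged sketch of a toy gradient-penalty term differentiated twice by Zygote; the function names are made up, and larger models may hit unsupported operations:

```julia
using Zygote

inner(x) = sum(abs2, x)                               # toy inner function, ∇inner = 2x

penalty(x) = sum(abs2, Zygote.gradient(inner, x)[1])  # ‖∇inner(x)‖² = 4 * sum(abs2, x)

Zygote.gradient(penalty, [1.0, 2.0, 3.0])[1]          # [8.0, 16.0, 24.0] expected
```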

### ForwardDiff over Zygote

`Zygote.hessian` works like this, applying forward mode to the gradient computed by reverse mode.
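
For example, a small sketch; the Hessian of this toy function is constant:

```julia
using Zygote

Zygote.hessian(x -> sum(abs2, x), [1.0, 2.0])
# 2×2 Matrix{Float64}:
#  2.0  0.0
#  0.0  2.0
```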

### Enzyme.jl

We have not really tried this, but it ought to work.

<hr/>

## Meta-packages

Besides AD packages, several packages have been written aiming to provide a unified interface to many options. These may offer useful ways to quickly switch between things you are trying.

### [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl)

The original meta-package, wrapping several AD backends behind one interface.

### [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl)

A newer attempt to build a simpler such interface; this is what the `AutoTapir()` example above uses.
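
A hedged sketch of swapping backends through it, on the flattened model, assuming the backend types from ADTypes.jl (such as `AutoZygote`) are available:

```julia
using Flux, DifferentiationInterface
import Zygote, ForwardDiff
using ADTypes: AutoZygote, AutoForwardDiff

flat, rebuild = Flux.destructure(model)
flatloss(p) = loss(rebuild(p))

DifferentiationInterface.gradient(flatloss, AutoZygote(), flat)       # reverse mode
DifferentiationInterface.gradient(flatloss, AutoForwardDiff(), flat)  # forward mode
```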

### [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)

Not an AD package itself, but really `rrule_via_ad` is another mechanism for calling back into AD from within a rule, though only for 3 systems.
