[ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) provides tools for writing tests based on [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl).
Take a look at the documentation or the existing [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) tests to see how to write the tests.
!!! warning
    Use finite differencing to test derivatives.
    Don't use analytical derivations for derivatives in the tests.
    Those are what you use to define the rules, and so cannot be confidently used in the test.
    If you misread/misunderstood them, then your tests/implementation will have the same mistake.
    Use finite differencing methods instead, as they are based on the primal computation.
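
For instance, here is a minimal sketch of such a check against [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl) (the function and its claimed derivative are made up for illustration):

```julia
using FiniteDifferences

# Suppose a rule claims d/dx x^3 == 3x^2.
# Check the claim numerically instead of re-deriving it by hand.
claimed_derivative(x) = 3x^2

fdm = central_fdm(5, 1)  # 5-point central estimate of the first derivative

x = 0.7
isapprox(claimed_derivative(x), fdm(x -> x^3, x); atol=1e-8)  # true
```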
## CAS systems are your friends.
It is very easy to check gradients or derivatives with a computer algebra system (CAS) like [WolframAlpha](https://www.wolframalpha.com/input/?i=gradient+atan2%28x%2Cy%29).
## Which functions need rules?
In principle, a perfect AD system only needs rules for basic operations and can infer the rules for more complicated functions automatically.
In practice, performance needs to be considered as well.
Some functions use `ccall` internally, for example [`^`](https://github.com/JuliaLang/julia/blob/v1.5.3/base/math.jl#L886).
These functions cannot be differentiated through by AD systems, and need custom rules.
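
As a sketch of what such a rule looks like, suppose `myexp` below stands in for a `ccall`-based function that AD cannot look inside (a hypothetical example, using the same `NO_FIELDS` convention as the rules later in this page):

```julia
using ChainRulesCore

# Stand-in for a function implemented via `ccall`, opaque to AD.
myexp(x::Float64) = exp(x)

function ChainRulesCore.rrule(::typeof(myexp), x)
    y = myexp(x)
    myexp_pullback(ȳ) = (NO_FIELDS, y * ȳ)  # d/dx exp(x) = exp(x)
    return y, myexp_pullback
end
```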
Other functions can in principle be differentiated through by an AD system, but there exists a mathematical insight that can dramatically improve the computation of the derivative.
An example is numerical integration, where writing a rule removes the need to perform AD through numerical integration.
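
As a minimal sketch (the quadrature helper `trapz` and the integrand here are hypothetical), the fundamental theorem of calculus gives the derivative directly, so the pullback never has to differentiate through the quadrature loop:

```julia
using ChainRulesCore

# Hypothetical quadrature: approximate ∫₀ˣ f(t) dt with the trapezoid rule.
function trapz(f, x; n=1000)
    h = x / n
    return h * (f(0.0) / 2 + sum(f(i * h) for i in 1:n-1) + f(x) / 2)
end

intsin(x) = trapz(sin, x)

function ChainRulesCore.rrule(::typeof(intsin), x)
    y = intsin(x)
    # d/dx ∫₀ˣ sin(t) dt = sin(x), so no AD through the loop is needed.
    intsin_pullback(ȳ) = (NO_FIELDS, sin(x) * ȳ)
    return y, intsin_pullback
end
```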
Furthermore, AD systems make different trade-offs in performance due to their design.
This means that a certain rule may help one AD system, while neither improving nor harming another.
Below, we list some patterns relevant for the [Zygote.jl](https://github.com/FluxML/Zygote.jl) AD system.
### Patterns that need rules in [Zygote.jl](https://github.com/FluxML/Zygote.jl)
There are a few classes of functions that Zygote cannot differentiate through.
Custom rules will need to be written for these to make AD work.
Other patterns can be AD'ed through, but the backward pass performance can be greatly improved by writing a rule.
#### Functions which mutate arrays
For example,
```julia
function addone!(array)
    array .+= 1
    return sum(array)
end
```
complains that
```julia
julia> using Zygote
julia> a = rand(3);
julia> gradient(addone!, a)
ERROR: Mutating arrays is not supported
```
However, upon adding the `rrule` (restart the REPL after calling `gradient`)
```julia
using ChainRules

function ChainRules.rrule(::typeof(addone!), a)
    y = addone!(a)
    function addone!_pullback(ȳ)
        return NO_FIELDS, ȳ .* ones(length(a))
    end
    return y, addone!_pullback
end
```
the gradient can be evaluated:
```julia
julia> gradient(addone!, a)
([1.0, 1.0, 1.0],)
```
!!! note "Why restart the REPL after calling `gradient`?"
    When `gradient` is called in `Zygote` for a function with no `rrule` defined, a backward pass for the function call is generated and cached.
    When `gradient` is called for the second time on the same function signature, the backward pass is reused without checking whether an `rrule` has been defined between the two calls to `gradient`.

    If an `rrule` is defined before the first call to `gradient`, it should register the rule and use it, but that prevents comparing what happens before and after the `rrule` is defined.
    To compare both versions with and without an `rrule` in the REPL simultaneously, define a function `f(x) = <body>` (no `rrule`), another function `f_cr(x) = f(x)`, and an `rrule` for `f_cr`, as sketched below.
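
    A minimal sketch of that setup (`f` and `f_cr` are names chosen for illustration):
    ```julia
    using Zygote, ChainRulesCore

    f(x) = x^2      # no rrule: Zygote generates and caches its own pullback
    f_cr(x) = f(x)  # the same computation, but with a custom rrule below

    function ChainRulesCore.rrule(::typeof(f_cr), x)
        y = f_cr(x)
        f_cr_pullback(ȳ) = (NO_FIELDS, 2 * x * ȳ)
        return y, f_cr_pullback
    end

    gradient(f, 3.0)     # uses the Zygote-generated pullback
    gradient(f_cr, 3.0)  # uses the custom rrule
    ```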
#### Exception handling
Zygote does not support differentiating through `try`/`catch` statements.
For example, differentiating through
```julia
function exception(x)
    try
        return x^2
    catch e
        println("could not square input")
        throw(e)
    end
end
```
does not work
```julia
julia> gradient(exception, 3.0)
ERROR: Compiling Tuple{typeof(exception),Float64}: try/catch is not supported.
```
without an `rrule` defined (restart the REPL after calling `gradient`)
```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(exception), x)
    y = exception(x)
    function exception_pullback(ȳ)
        return NO_FIELDS, 2 * x * ȳ
    end
    return y, exception_pullback
end
```
```julia
julia> gradient(exception, 3.0)
(6.0,)
```
#### Loops
Julia runs loops fast.
Unfortunately, Zygote differentiates through loops slowly.
So, for example, computing the mean squared error by using a loop
```julia
function mse(y, ŷ)
    N = length(y)
    s = 0.0
    for i in 1:N
        s += (y[i] - ŷ[i])^2.0
    end
    return s / N
end
```
takes a lot longer to AD through
```julia
julia> using BenchmarkTools
julia> y = rand(30);
julia> ŷ = rand(30);
julia> @btime gradient(mse, $y, $ŷ);
38.180 μs (993 allocations: 65.00 KiB)
```
than if we supply an `rrule` (restart the REPL after calling `gradient`)
```julia
using ChainRules

function ChainRules.rrule(::typeof(mse), x, x̂)
    output = mse(x, x̂)
    function mse_pullback(ȳ)
        N = length(x)
        g = (2 / N) .* (x .- x̂) .* ȳ
        return NO_FIELDS, g, -g
    end
    return output, mse_pullback
end
```
which is much faster
```julia
julia> @btime gradient(mse, $y, $ŷ);
143.697 ns (2 allocations: 672 bytes)
```
#### In-place accumulation
In-place accumulation of gradients is slow in `Zygote`.
The issue, demonstrated in the following example, is that the gradient of `getindex` allocates an array of zeros with a single non-zero element.
```julia
function sum3(array)
    x = array[1]
    y = array[2]
    z = array[3]
    return x + y + z
end
```
```julia
julia> @btime gradient(sum3, rand(30));
424.510 ns (9 allocations: 2.06 KiB)
```
Computing the gradient with only a single array allocation using an `rrule` (restart the REPL after calling `gradient`) is much faster.
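
One possible sketch of such a rule, assuming `sum3` as defined above (so the gradient is non-zero only in the first three entries):

```julia
using ChainRules

function ChainRules.rrule(::typeof(sum3), a)
    y = sum3(a)
    function sum3_pullback(ȳ)
        # Allocate the full gradient once, instead of one zeros-array
        # per `getindex` call as in the AD-generated pullback.
        ā = zeros(length(a))
        ā[1:3] .= ȳ
        return NO_FIELDS, ā
    end
    return y, sum3_pullback
end
```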
0 commit comments