Skip to content

Commit 37a6fec

Browse files
authored
Merge pull request #285 from JuliaDiff/mz/docs
add "when to write rules" advice
2 parents 9fad48c + 3468968 commit 37a6fec

File tree

3 files changed

+197
-8
lines changed

3 files changed

+197
-8
lines changed

docs/Manifest.toml

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,25 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
1010
version = "0.5.10"
1111

1212
[[ChainRulesCore]]
13-
deps = ["LinearAlgebra", "SparseArrays"]
13+
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.9.24"
16+
version = "0.9.26"
17+
18+
[[Compat]]
19+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
20+
git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b"
21+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
22+
version = "3.25.0"
1723

1824
[[Dates]]
1925
deps = ["Printf"]
2026
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
2127

28+
[[DelimitedFiles]]
29+
deps = ["Mmap"]
30+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
31+
2232
[[Distributed]]
2333
deps = ["Random", "Serialization", "Sockets"]
2434
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -115,13 +125,21 @@ version = "0.1.0"
115125
[[Serialization]]
116126
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
117127

128+
[[SharedArrays]]
129+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
130+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
131+
118132
[[Sockets]]
119133
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
120134

121135
[[SparseArrays]]
122136
deps = ["LinearAlgebra", "Random"]
123137
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
124138

139+
[[Statistics]]
140+
deps = ["LinearAlgebra", "SparseArrays"]
141+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
142+
125143
[[Test]]
126144
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
127145
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
45
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
56
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"

docs/src/writing_good_rules.md

Lines changed: 176 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ Rule definition tools can help you write more `frule`s and the `rrule`s with les
7575

7676
For non-differentiable functions the [`@non_differentiable`](@ref) macro can be used.
7777
For example, instead of manually defining the `frule` and the `rrule` for string concatenation `*(String..)`, the macro call
78-
```
78+
```julia
7979
@non_differentiable *(String...)
8080
```
8181
defines the following `frule` and `rrule` automatically
82-
```
82+
```julia
8383
function ChainRulesCore.frule(var"##_#1600", ::Core.Typeof(*), String::Any...; kwargs...)
8484
return (*(String...; kwargs...), DoesNotExist())
8585
end
@@ -103,16 +103,186 @@ In fact, any number of scalar arguments is supported, as is returning a tuple of
103103
See docstrings for the comprehensive usage instructions.
104104
## Write tests
105105

106-
In [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl)
107-
there are fairly decent tools for writing tests based on [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl).
108-
Take a look at existing [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) tests and you should see how to do stuff.
106+
[ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl)
107+
provides tools for writing tests based on [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl).
108+
Take a look at the documentation or the existing [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) tests to see how to write the tests.
109109

110110
!!! warning
111-
Use finite differencing to test derivatives.
112111
Don't use analytical derivations for derivatives in the tests.
113112
Those are what you use to define the rules, and so can not be confidently used in the test.
114113
If you misread/misunderstood them, then your tests/implementation will have the same mistake.
114+
Use finite differencing methods instead, as they are based on the primal computation.
115115

116116
## CAS systems are your friends.
117117

118118
It is very easy to check gradients or derivatives with a computer algebra system (CAS) like [WolframAlpha](https://www.wolframalpha.com/input/?i=gradient+atan2%28x%2Cy%29).
119+
120+
## Which functions need rules?
121+
122+
In principle, a perfect AD system only needs rules for basic operations and can infer the rules for more complicated functions automatically.
123+
In practice, performance needs to be considered as well.
124+
125+
Some functions use `ccall` internally, for example [`^`](https://github.com/JuliaLang/julia/blob/v1.5.3/base/math.jl#L886).
126+
These functions can not be differentiated through by AD systems, and need custom rules.
127+
128+
Other functions can in principle be differentiated through by an AD system, but there exists a mathematical insight that can dramatically improve the computation of the derivative.
129+
An example is numerical integration, where writing a rule removes the need to perform AD through numerical integration.
130+
131+
Furthermore, AD systems make different trade-offs in performance due to their design.
132+
This means that a certain rule will help one AD system, but not improve (and also not harm) another.
133+
Below, we list some patterns relevant for the [Zygote.jl](https://github.com/FluxML/Zygote.jl) AD system.
134+
135+
### Patterns that need rules in [Zygote.jl](https://github.com/FluxML/Zygote.jl)
136+
137+
There are a few classes of functions that Zygote can not differentiate through.
138+
Custom rules will need to be written for these to make AD work.
139+
140+
Other patterns can be AD'ed through, but the backward pass performance can be greatly improved by writing a rule.
141+
142+
#### Functions which mutate arrays
143+
For example,
144+
```julia
145+
function addone!(array)
146+
array .+= 1
147+
return sum(array)
148+
end
149+
```
150+
complains that
151+
```julia
152+
julia> using Zygote
153+
julia> gradient(addone!, a)
154+
ERROR: Mutating arrays is not supported
155+
```
156+
However, upon adding the `rrule` (restart the REPL after calling `gradient`)
157+
```julia
158+
function ChainRules.rrule(::typeof(addone!), a)
159+
y = addone!(a)
160+
function addone!_pullback(ȳ)
161+
return NO_FIELDS, ones(length(a))
162+
end
163+
return y, addone!_pullback
164+
end
165+
```
166+
the gradient can be evaluated:
167+
```julia
168+
julia> gradient(addone!, a)
169+
([1.0, 1.0, 1.0],)
170+
```
171+
172+
!!! note "Why restarting REPL after calling `gradient`?"
173+
When `gradient` is called in `Zygote` for a function with no `rrule` defined, a backward pass for the function call is generated and cached.
174+
When `gradient` is called for the second time on the same function signature, the backward pass is reused without checking whether an an `rrule` has been defined between the two calls to `gradient`.
175+
176+
If an `rrule` is defined before the first call to `gradient` it should register the rule and use it, but that prevents comparing what happens before and after the `rrule` is defined.
177+
To compare both versions with and without an `rrule` in the REPL simultaneously, define a function `f(x) = <body>` (no `rrule`), another function `f_cr(x) = f(x)`, and an `rrule` for `f_cr`.
178+
179+
#### Exception handling
180+
181+
Zygote does not support differentiating through `try`/`catch` statements.
182+
For example, differentiating through
183+
```julia
184+
function exception(x)
185+
try
186+
return x^2
187+
catch e
188+
println("could not square input")
189+
throw(e)
190+
end
191+
end
192+
```
193+
does not work
194+
```julia
195+
julia> gradient(exception, 3.0)
196+
ERROR: Compiling Tuple{typeof(exception),Int64}: try/catch is not supported.
197+
```
198+
without an `rrule` defined (restart the REPL after calling `gradient`)
199+
```julia
200+
function ChainRulesCore.rrule(::typeof(exception), x)
201+
y = exception(x)
202+
function exception_pullback(ȳ)
203+
return NO_FIELDS, 2*x
204+
end
205+
return y, exception_pullback
206+
end
207+
```
208+
209+
```julia
210+
julia> gradient(exception, 3.0)
211+
(6.0,)
212+
```
213+
214+
215+
#### Loops
216+
217+
Julia runs loops fast.
218+
Unfortunately Zygote differentiates through loops slowly.
219+
So, for example, computing the mean squared error by using a loop
220+
```julia
221+
function mse(y, ŷ)
222+
N = length(y)
223+
s = 0.0
224+
for i in 1:N
225+
s += (y[i] - ŷ[i])^2.0
226+
end
227+
return s/N
228+
end
229+
```
230+
takes a lot longer to AD through
231+
```julia
232+
julia> y = rand(30)
233+
julia>= rand(30)
234+
julia> @btime gradient(mse, $y, $ŷ)
235+
38.180 μs (993 allocations: 65.00 KiB)
236+
```
237+
than if we supply an `rrule`, (restart the REPL after calling `gradient`)
238+
```julia
239+
function ChainRules.rrule(::typeof(mse), x, x̂)
240+
output = mse(x, x̂)
241+
function mse_pullback(ȳ)
242+
N = length(x)
243+
g = (2 ./ N) .* (x .- x̂) .*
244+
return NO_FIELDS, g, -g
245+
end
246+
return output, mse_pullback
247+
end
248+
```
249+
which is much faster
250+
```julia
251+
julia> @btime gradient(mse, $y, $ŷ)
252+
143.697 ns (2 allocations: 672 bytes)
253+
```
254+
255+
#### Inplace accumulation
256+
257+
Inplace accumulation of gradients is slow in `Zygote`.
258+
The issue, demonstrated in the folowing example, is that the gradient of `getindex` allocates an array of zeros with a single non-zero element.
259+
```julia
260+
function sum3(array)
261+
x = array[1]
262+
y = array[2]
263+
z = array[3]
264+
return x+y+z
265+
end
266+
```
267+
```julia
268+
julia> @btime gradient(sum3, rand(30))
269+
424.510 ns (9 allocations: 2.06 KiB)
270+
```
271+
Computing the gradient with only a single array allocation using an `rrule` (restart the REPL after calling `gradient`)
272+
```julia
273+
function ChainRulesCore.rrule(::typeof(sum3), a)
274+
y = sum3(a)
275+
function sum3_pullback(ȳ)
276+
grad = zeros(length(a))
277+
grad[1:3] .+= 1.0
278+
return NO_FIELDS, grad
279+
end
280+
return y, sum3_pullback
281+
end
282+
```
283+
turns out to be significantly faster
284+
```julia
285+
julia> @btime gradient(sum3, rand(30))
286+
192.818 ns (3 allocations: 784 bytes)
287+
```
288+

0 commit comments

Comments
 (0)