Commit c203e94

break up writing good rules (#543)
1 parent 3e11e62 commit c203e94

4 files changed

+209
-206
lines changed

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,8 @@ makedocs(;
52 52
"Introduction" => "rule_author/intro.md",
53 53
"Pedagogical example" => "rule_author/example.md",
54 54
"Tangent types" => "rule_author/tangents.md",
55-
#"`frule` and `rrule`" => "rule_author/rules.md", # TODO: a complete example
55+
"Which functions need rules?" => "rule_author/which_functions_need_rules.md",
56+
"Rule definition tools" => "rule_author/rule_definition_tools.md",
56 57
"Writing good rules" => "rule_author/writing_good_rules.md",
57 58
"Testing your rules" => "rule_author/testing.md",
58 59
"Superpowers" => [
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# [Using rule definition tools](@id ruletools)
2+
3+
Rule definition tools can help you write more `frule`s and `rrule`s with fewer lines of code.
4+
5+
## [`@non_differentiable`](@ref)
6+
7+
For non-differentiable functions the [`@non_differentiable`](@ref) macro can be used.
8+
For example, instead of manually defining the `frule` and the `rrule` for string concatenation `*(String...)`, the macro call
9+
```julia
10+
@non_differentiable *(String...)
11+
```
12+
defines the following `frule` and `rrule` automatically:
13+
```julia
14+
function ChainRulesCore.frule(var"##_#1600", ::Core.Typeof(*), String::Any...; kwargs...)
15+
return (*(String...; kwargs...), NoTangent())
16+
end
17+
function ChainRulesCore.rrule(::Core.Typeof(*), String::Any...; kwargs...)
18+
return (*(String...; kwargs...), function var"*_pullback"(_)
19+
(ZeroTangent(), ntuple((_->NoTangent()), 0 + length(String))...)
20+
end)
21+
end
22+
```
23+
Note that the types of arguments are propagated to the `frule` and `rrule` definitions.
24+
This is needed in case the function is differentiable for some argument types but not for others.
25+
For example, `*(1, 2, 3)` is differentiable, so it is not covered by the macro call above.
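
As a minimal sketch of how argument types carry over to the generated rules, the following marks only the `Vector{Float64}` method of a hypothetical function `count_positive` (invented here for illustration, not part of any package) as non-differentiable. It assumes a reasonably recent ChainRulesCore; the tangent returned for the function itself may differ between versions.

```julia
using ChainRulesCore

# Hypothetical function, defined here only for illustration.
count_positive(xs) = count(>(0), xs)

# Only calls with a Vector{Float64} argument get the generated rules.
@non_differentiable count_positive(xs::Vector{Float64})

# A call matching the signature finds the generated rrule ...
y, pullback = rrule(count_positive, [1.0, -2.0, 3.0])
x̄ = pullback(1.0)[2]   # tangent for `xs`

# ... while a call with other argument types hits the fallback (no rule).
no_rule = rrule(count_positive, (1.0, -2.0))
```

Here `x̄` should be `NoTangent()`, and `no_rule` should be `nothing` because the tuple argument does not match the signature given to the macro.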
26+
27+
## [`@scalar_rule`](@ref)
28+
29+
For functions involving only scalars, i.e. subtypes of `Number` (no `struct`s, `String`s...), both the `frule` and the `rrule` can be defined using a single [`@scalar_rule`](@ref) macro call.
30+
31+
Note that the function does not have to be $\mathbb{R} \rightarrow \mathbb{R}$.
32+
In fact, any number of scalar arguments is supported, as is returning a tuple of scalars.
33+
34+
See the docstrings for comprehensive usage instructions.
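
The following sketch shows what such macro calls might look like for one and two scalar arguments. The functions `cube` and `weighted` are hypothetical and exist only for this illustration; the partial derivatives are written by hand.

```julia
using ChainRulesCore

# Hypothetical scalar functions, defined here only for illustration.
cube(x) = x^3
weighted(x, y) = 2x + x * y

# One input, one output: a single partial derivative.
@scalar_rule(cube(x), 3 * x^2)

# Two inputs, one output: a tuple holding one partial per input.
@scalar_rule(weighted(x, y), (2 + y, x))

# Forward mode: returns the primal and the propagated tangent.
Ω, Ω̇ = frule((NoTangent(), 1.0), cube, 2.0)

# Reverse mode: returns the primal and a pullback.
z, pullback = rrule(weighted, 3.0, 4.0)
_, ∂x, ∂y = pullback(1.0)
```

Each macro call generates both the `frule` and the `rrule`, so the two derivative definitions above replace four hand-written methods.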
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Which functions need rules?
2+
3+
In principle, a perfect AD system only needs rules for basic operations and can infer the rules for more complicated functions automatically.
4+
In practice, performance needs to be considered as well.
5+
6+
Some functions use `ccall` internally, for example [`^`](https://github.com/JuliaLang/julia/blob/v1.5.3/base/math.jl#L886).
7+
These functions cannot be differentiated through by AD systems, and need custom rules.
8+
9+
Other functions can in principle be differentiated through by an AD system, but there exists a mathematical insight that can dramatically improve the computation of the derivative.
10+
An example is numerical integration, where writing a rule implementing the [fundamental theorem of calculus](https://en.wikipedia.org/wiki/Fundamental_theorem_of_calculus) removes the need to perform AD through numerical integration.
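
To make the integration example concrete, here is a minimal sketch of such a rule. The function `integrate` is a hypothetical trapezoidal-rule helper written for this illustration, and the rule assumes the integrand `f` carries no differentiable parameters, so its tangent is `NoTangent()`.

```julia
using ChainRulesCore

# Hypothetical helper: composite trapezoidal rule on `n` subintervals.
function integrate(f, a, b; n=1_000)
    h = (b - a) / n
    s = (f(a) + f(b)) / 2
    for i in 1:(n - 1)
        s += f(a + i * h)
    end
    return s * h
end

function ChainRulesCore.rrule(::typeof(integrate), f, a, b; n=1_000)
    y = integrate(f, a, b; n=n)
    function integrate_pullback(ȳ)
        # Fundamental theorem of calculus:
        #   ∂/∂a ∫ₐᵇ f(t) dt = -f(a),   ∂/∂b ∫ₐᵇ f(t) dt = f(b)
        # Assumption: `f` has no differentiable parameters.
        return NoTangent(), NoTangent(), -f(a) * ȳ, f(b) * ȳ
    end
    return y, integrate_pullback
end
```

With this rule the pullback costs two evaluations of `f`, instead of a reverse pass through every iteration of the quadrature loop.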
11+
12+
Furthermore, AD systems make different trade-offs in performance due to their design.
13+
This means that a given rule may help one AD system while neither improving nor harming another.
14+
Below, we list some patterns relevant for the [Zygote.jl](https://github.com/FluxML/Zygote.jl) AD system.
15+
16+
Rules for functions which mutate their arguments, e.g. `sort!`, should not be written at the moment.
17+
While technically they are supported, they would break [Zygote.jl](https://github.com/FluxML/Zygote.jl) such that [it would sometimes quietly return the wrong answer](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/242).
18+
This may be resolved in the future by [allowing AD systems to opt-in or opt-out of certain types of rules](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/270).
19+
20+
### Patterns that need rules in [Zygote.jl](https://github.com/FluxML/Zygote.jl)
21+
22+
There are a few classes of functions that Zygote cannot differentiate through.
23+
Custom rules will need to be written for these to make AD work.
24+
25+
Other patterns can be AD'ed through, but the backward pass performance can be greatly improved by writing a rule.
26+
27+
#### Functions which mutate arrays
28+
For example,
29+
```julia
30+
function addone!(array)
31+
array .+= 1
32+
return sum(array)
33+
end
34+
```
35+
cannot be differentiated by Zygote, which complains that
36+
```julia
37+
julia> using Zygote
38+
julia> a = [1.0, 2.0, 3.0];
julia> gradient(addone!, a)
39+
ERROR: Mutating arrays is not supported
40+
```
41+
However, upon adding the `rrule` (restart the REPL after calling `gradient`)
42+
```julia
43+
function ChainRulesCore.rrule(::typeof(addone!), a)
44+
y = addone!(a)
45+
function addone!_pullback(ȳ)
46+
return NoTangent(), fill(ȳ, length(a))  # each element contributes ȳ to the sum
47+
end
48+
return y, addone!_pullback
49+
end
50+
```
51+
the gradient can be evaluated:
52+
```julia
53+
julia> gradient(addone!, a)
54+
([1.0, 1.0, 1.0],)
55+
```
56+
57+
!!! note "Why restart the REPL after calling `gradient`?"
58+
When `gradient` is called in `Zygote` for a function with no `rrule` defined, a backward pass for the function call is generated and cached.
59+
When `gradient` is called for the second time on the same function signature, the backward pass is reused without checking whether an `rrule` has been defined between the two calls to `gradient`.
60+
61+
If an `rrule` is defined before the first call to `gradient`, Zygote should pick up the rule and use it, but that prevents comparing what happens before and after the `rrule` is defined.
62+
To compare both versions with and without an `rrule` in the REPL simultaneously, define a function `f(x) = <body>` (no `rrule`), another function `f_cr(x) = f(x)`, and an `rrule` for `f_cr`.
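
A minimal sketch of this side-by-side pattern, using a hypothetical function body `x^2 + 1` chosen only for illustration:

```julia
using ChainRulesCore

# Hypothetical function without a custom rule; the AD system differentiates it directly.
f(x) = x^2 + 1

# Same computation under a different name, which receives the custom rule.
f_cr(x) = f(x)

function ChainRulesCore.rrule(::typeof(f_cr), x)
    y = f_cr(x)
    # d/dx (x^2 + 1) = 2x, scaled by the incoming cotangent ȳ.
    f_cr_pullback(ȳ) = (NoTangent(), 2x * ȳ)
    return y, f_cr_pullback
end
```

With Zygote loaded, `gradient(f, 3.0)` and `gradient(f_cr, 3.0)` can then be compared or benchmarked in the same session without restarting the REPL.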
63+
64+
#### Exception handling
65+
66+
Zygote does not support differentiating through `try`/`catch` statements.
67+
For example, differentiating through
68+
```julia
69+
function exception(x)
70+
try
71+
return x^2
72+
catch e
73+
println("could not square input")
74+
throw(e)
75+
end
76+
end
77+
```
78+
does not work
79+
```julia
80+
julia> gradient(exception, 3.0)
81+
ERROR: Compiling Tuple{typeof(exception),Int64}: try/catch is not supported.
82+
```
83+
unless an `rrule` is defined (restart the REPL after calling `gradient`)
84+
```julia
85+
function ChainRulesCore.rrule(::typeof(exception), x)
86+
y = exception(x)
87+
function exception_pullback(ȳ)
88+
return NoTangent(), 2*x*ȳ  # scale the local derivative by the incoming cotangent
89+
end
90+
return y, exception_pullback
91+
end
92+
```
93+
94+
```julia
95+
julia> gradient(exception, 3.0)
96+
(6.0,)
97+
```
98+
99+
100+
#### Loops
101+
102+
Julia runs loops fast.
103+
Unfortunately Zygote differentiates through loops slowly.
104+
So, for example, computing the mean squared error by using a loop
105+
```julia
106+
function mse(y, ŷ)
107+
N = length(y)
108+
s = 0.0
109+
for i in 1:N
110+
s += (y[i] - ŷ[i])^2.0
111+
end
112+
return s/N
113+
end
114+
```
115+
takes a lot longer to AD through
116+
```julia
117+
julia> y = rand(30)
118+
julia> ŷ = rand(30)
119+
julia> @btime gradient(mse, $y, $ŷ)
120+
38.180 μs (993 allocations: 65.00 KiB)
121+
```
122+
than if we supply an `rrule` (restart the REPL after calling `gradient`)
123+
```julia
124+
function ChainRulesCore.rrule(::typeof(mse), x, x̂)
125+
output = mse(x, x̂)
126+
function mse_pullback(ȳ)
127+
N = length(x)
128+
g = (2 ./ N) .* (x .- x̂) .* ȳ
129+
return NoTangent(), g, -g
130+
end
131+
return output, mse_pullback
132+
end
133+
```
134+
which is much faster
135+
```julia
136+
julia> @btime gradient(mse, $y, $ŷ)
137+
143.697 ns (2 allocations: 672 bytes)
138+
```
139+
140+
#### Inplace accumulation
141+
142+
Inplace accumulation of gradients is slow in `Zygote`.
143+
The issue, demonstrated in the following example, is that the gradient of `getindex` allocates an array of zeros with a single non-zero element.
144+
```julia
145+
function sum3(array)
146+
x = array[1]
147+
y = array[2]
148+
z = array[3]
149+
return x+y+z
150+
end
151+
```
152+
```julia
153+
julia> @btime gradient(sum3, rand(30))
154+
424.510 ns (9 allocations: 2.06 KiB)
155+
```
156+
Computing the gradient with only a single array allocation using an `rrule` (restart the REPL after calling `gradient`)
157+
```julia
158+
function ChainRulesCore.rrule(::typeof(sum3), a)
159+
y = sum3(a)
160+
function sum3_pullback(ȳ)
161+
grad = zeros(length(a))
162+
grad[1:3] .+= ȳ
163+
return NoTangent(), grad
164+
end
165+
return y, sum3_pullback
166+
end
167+
```
168+
turns out to be significantly faster
169+
```julia
170+
julia> @btime gradient(sum3, rand(30))
171+
192.818 ns (3 allocations: 784 bytes)
172+
```
