# Which functions need rules?

In principle, a perfect AD system only needs rules for basic operations and can infer the rules for more complicated functions automatically.
In practice, performance needs to be considered as well.

Some functions use `ccall` internally, for example [`^`](https://github.com/JuliaLang/julia/blob/v1.5.3/base/math.jl#L886).
These functions cannot be differentiated through by AD systems, and need custom rules.

Other functions can in principle be differentiated through by an AD system, but there exists a mathematical insight that can dramatically improve the computation of the derivative.
An example is numerical integration, where writing a rule implementing the [fundamental theorem of calculus](https://en.wikipedia.org/wiki/Fundamental_theorem_of_calculus) removes the need to perform AD through numerical integration.
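As a sketch of this idea, consider a hypothetical `integrate(f, a, b)` (the name and the trapezoidal-rule implementation below are illustrative, not an existing API). By the fundamental theorem of calculus, the derivatives with respect to the integration limits are simply `-f(a)` and `f(b)`, so the pullback never has to differentiate through the integration loop:

```julia
using ChainRulesCore

# Hypothetical numerical integrator: trapezoidal rule on n subintervals.
function integrate(f, a, b; n=1000)
    xs = range(a, b; length=n + 1)
    h = (b - a) / n
    return h * (sum(f, xs) - (f(a) + f(b)) / 2)
end

# By the fundamental theorem of calculus, ∂/∂b ∫ₐᵇ f = f(b) and ∂/∂a = -f(a),
# so the pullback is two function evaluations, not AD through the loop above.
function ChainRulesCore.rrule(::typeof(integrate), f, a, b; n=1000)
    y = integrate(f, a, b; n=n)
    integrate_pullback(ȳ) = (NoTangent(), NoTangent(), -f(a) * ȳ, f(b) * ȳ)
    return y, integrate_pullback
end
```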
Furthermore, AD systems make different performance trade-offs due to their design.
This means that a rule that helps one AD system may neither help nor harm another.
Below, we list some patterns relevant for the [Zygote.jl](https://github.com/FluxML/Zygote.jl) AD system.

Rules for functions which mutate their arguments, e.g. `sort!`, should not be written at the moment.
While they are technically supported, they would break [Zygote.jl](https://github.com/FluxML/Zygote.jl) such that [it would sometimes quietly return the wrong answer](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/242).
This may be resolved in the future by [allowing AD systems to opt in or out of certain types of rules](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/270).
### Patterns that need rules in [Zygote.jl](https://github.com/FluxML/Zygote.jl)

There are a few classes of functions that Zygote cannot differentiate through.
Custom rules need to be written for these to make AD work.

Other patterns can be AD'ed through, but the backward pass performance can be greatly improved by writing a rule.

#### Functions which mutate arrays
For example,
```julia
function addone!(array)
    array .+= 1
    return sum(array)
end
```
complains that
```julia
julia> using Zygote
julia> a = rand(3);
julia> gradient(addone!, a)
ERROR: Mutating arrays is not supported
```
However, upon adding the `rrule` (restart the REPL after calling `gradient`)
```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(addone!), a)
    y = addone!(a)
    function addone!_pullback(ȳ)
        return NoTangent(), ȳ .* ones(length(a))
    end
    return y, addone!_pullback
end
```
the gradient can be evaluated:
```julia
julia> gradient(addone!, a)
([1.0, 1.0, 1.0],)
```
!!! note "Why restart the REPL after calling `gradient`?"
    When `gradient` is called in `Zygote` for a function with no `rrule` defined, a backward pass for the function call is generated and cached.
    When `gradient` is called for the second time on the same function signature, the backward pass is reused without checking whether an `rrule` has been defined between the two calls to `gradient`.

    If an `rrule` is defined before the first call to `gradient`, it should register the rule and use it, but that prevents comparing what happens before and after the `rrule` is defined.
    To compare both versions with and without an `rrule` in the REPL simultaneously, define a function `f(x) = <body>` (no `rrule`), another function `f_cr(x) = f(x)`, and an `rrule` for `f_cr`.
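A minimal sketch of this pattern (the names `f` and `f_cr` are just the placeholders from the note above, shown here with `x^2` as an example body):

```julia
using ChainRulesCore

f(x) = x^2        # no rrule: the AD system differentiates through the body
f_cr(x) = f(x)    # identical behaviour, but dispatches to the custom rrule below

function ChainRulesCore.rrule(::typeof(f_cr), x)
    f_cr_pullback(ȳ) = (NoTangent(), 2x * ȳ)
    return f_cr(x), f_cr_pullback
end
```

Now `Zygote.gradient(f, 3.0)` (AD through the body) and `Zygote.gradient(f_cr, 3.0)` (custom rule) can be compared in the same session.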
#### Exception handling

Zygote does not support differentiating through `try`/`catch` statements.
For example, differentiating through
```julia
function exception(x)
    try
        return x^2
    catch e
        println("could not square input")
        throw(e)
    end
end
```
does not work
```julia
julia> gradient(exception, 3.0)
ERROR: Compiling Tuple{typeof(exception),Float64}: try/catch is not supported.
```
without an `rrule` defined (restart the REPL after calling `gradient`)
```julia
function ChainRulesCore.rrule(::typeof(exception), x)
    y = exception(x)
    function exception_pullback(ȳ)
        return NoTangent(), 2x * ȳ
    end
    return y, exception_pullback
end
```

```julia
julia> gradient(exception, 3.0)
(6.0,)
```
#### Loops

Julia runs loops fast.
Unfortunately, Zygote differentiates through loops slowly.
So, for example, computing the mean squared error by using a loop
```julia
function mse(y, ŷ)
    N = length(y)
    s = 0.0
    for i in 1:N
        s += (y[i] - ŷ[i])^2.0
    end
    return s/N
end
```
takes a lot longer to AD through
```julia
julia> using BenchmarkTools
julia> y = rand(30);
julia> ŷ = rand(30);
julia> @btime gradient(mse, $y, $ŷ);
  38.180 μs (993 allocations: 65.00 KiB)
```
than if we supply an `rrule` (restart the REPL after calling `gradient`)
```julia
function ChainRulesCore.rrule(::typeof(mse), x, x̂)
    output = mse(x, x̂)
    function mse_pullback(ȳ)
        N = length(x)
        g = (2 ./ N) .* (x .- x̂) .* ȳ
        return NoTangent(), g, -g
    end
    return output, mse_pullback
end
```
which is much faster:
```julia
julia> @btime gradient(mse, $y, $ŷ);
  143.697 ns (2 allocations: 672 bytes)
```
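A custom rule is not the only fix here: rewriting the loop in vectorized form lets Zygote reuse the existing rules for broadcasting and `sum`, which is typically fast without any hand-written `rrule`. A sketch (the name `mse_vec` is ours, for illustration):

```julia
# Vectorized mean squared error: broadcast the subtraction, square with
# `abs2`, and reduce with `sum` — all operations Zygote already has rules for.
mse_vec(y, ŷ) = sum(abs2, y .- ŷ) / length(y)
```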
#### Inplace accumulation

Inplace accumulation of gradients is slow in `Zygote`.
The issue, demonstrated in the following example, is that the gradient of `getindex` allocates an array of zeros with a single non-zero element.
```julia
function sum3(array)
    x = array[1]
    y = array[2]
    z = array[3]
    return x+y+z
end
```
```julia
julia> @btime gradient(sum3, rand(30));
  424.510 ns (9 allocations: 2.06 KiB)
```
Computing the gradient with only a single array allocation using an `rrule` (restart the REPL after calling `gradient`)
```julia
function ChainRulesCore.rrule(::typeof(sum3), a)
    y = sum3(a)
    function sum3_pullback(ȳ)
        grad = zeros(length(a))
        grad[1:3] .+= ȳ
        return NoTangent(), grad
    end
    return y, sum3_pullback
end
```
turns out to be significantly faster:
```julia
julia> @btime gradient(sum3, rand(30));
  192.818 ns (3 allocations: 784 bytes)
```
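When writing pullbacks like these by hand, it is worth checking them against finite differences ([ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl)'s `test_rrule` automates such checks). A minimal hand-rolled sketch for the `sum3` gradient above, using only the standard library:

```julia
sum3(array) = array[1] + array[2] + array[3]

# The gradient the hand-written pullback produces (with ȳ = 1).
function sum3_grad(a)
    grad = zeros(length(a))
    grad[1:3] .+= 1.0
    return grad
end

# Central finite differences, one coordinate at a time.
function fd_grad(f, a; h=1e-6)
    g = zeros(length(a))
    for i in eachindex(a)
        e = zeros(length(a)); e[i] = h
        g[i] = (f(a .+ e) - f(a .- e)) / (2h)
    end
    return g
end

a = rand(5)
sum3_grad(a) ≈ fd_grad(sum3, a)   # the two gradients should agree
```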