|
| 1 | +# Pedagogical Example |
| 2 | + |
| 3 | +This pedagogical example will show you how to write an `rrule`. |
| 4 | +See [On writing good `rrule` / `frule` methods](@ref) section for more tips and gotchas. |
| 5 | +If you want to learn about `frule`s, you should still read and understand this example as many concepts are shared, and then look for real world `frule` examples in ChainRules.jl. |
| 6 | + |
| 7 | +## The primal |
| 8 | + |
| 9 | +We define a struct `Foo` |
| 10 | +```julia |
| 11 | +struct Foo |
| 12 | + A::Matrix |
| 13 | + c::Float64 |
| 14 | +end |
| 15 | +``` |
| 16 | +and a function that multiplies `Foo` with an `AbstractArray`: |
| 17 | +```julia |
| 18 | +function foo_mul(foo::Foo, b::AbstractArray) |
| 19 | + return foo.A * b |
| 20 | +end |
| 21 | +``` |
| 22 | +Note that field `c` is ignored in the calculation. |
| 23 | + |
| 24 | +## The `rrule` |
| 25 | + |
| 26 | +The `rrule` method for our primal computation should extend the `ChainRulesCore.rrule` function. |
| 27 | +```julia |
| 28 | +function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray) |
| 29 | + y = foo_mul(foo, b) |
| 30 | + function foo_mul_pullback(ȳ) |
| 31 | + f̄ = NoTangent() |
| 32 | + f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent()) |
| 33 | + b̄ = @thunk(foo.A' * ȳ) |
| 34 | + return f̄, f̄oo, b̄ |
| 35 | + end |
| 36 | + return y, foo_mul_pullback |
| 37 | +end |
| 38 | +``` |
| 39 | +Now let's examine the rule in more detail: |
| 40 | +```julia |
| 41 | +function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray) |
| 42 | + ... |
| 43 | + return y, foo_mul_pullback |
| 44 | +end |
| 45 | +``` |
| 46 | +The `rrule` dispatches on the `typeof` of the function we are writing the `rrule` for, as well as the types of its arguments. |
| 47 | +Read more about writing rules for constructors and callable objects [here](@ref structs). |
| 48 | +The `rrule` returns the primal result `y`, and the pullback function. |
| 49 | +It is a _very_ good idea to name your pullback function, so that they are helpful when appearing in the stacktrace. |
| 50 | +```julia |
| 51 | + y = foo_mul(foo, b) |
| 52 | +``` |
| 53 | +Computes the primal result. |
| 54 | +It is possible to change the primal computation so that work can be shared between the primal and the pullback. |
| 55 | +See e.g. [the rule for `sort`](https://github.com/JuliaDiff/ChainRules.jl/blob/a75193768775975fac5578c89d1e5f50d7f358c2/src/rulesets/Base/sort.jl#L19-L35), where the sorting is done only once. |
| 56 | +```julia |
| 57 | + function foo_mul_pullback(ȳ) |
| 58 | + ... |
| 59 | + return f̄, f̄oo, b̄ |
| 60 | + end |
| 61 | +``` |
| 62 | +The pullback function takes in the tangent of the primal output (`ȳ`) and returns the tangents of the primal inputs. |
| 63 | +Note that it returns a tangent for the primal function in addition to the tangents of primal arguments. |
| 64 | + |
| 65 | +Finally, computing the tangents of primal inputs: |
| 66 | +```julia |
| 67 | + f̄ = NoTangent() |
| 68 | +``` |
| 69 | +The function `foo_mul` has no fields (i.e. it is not a closure) and can not be perturbed. |
| 70 | +Therefore its tangent (`f̄`) is a `NoTangent`. |
| 71 | +```julia |
| 72 | + f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent()) |
| 73 | +``` |
| 74 | +The struct `foo::Foo` gets a `Tangent{Foo}` structural tangent, which stores the tangents of fields of `foo`. |
| 75 | + |
| 76 | +The tangent of the field `A` is `ȳ * b'`, |
| 77 | + |
| 78 | +The tangent of the field `c` is `ZeroTangent()`, because `c` can be perturbed but has no effect on the primal output. |
| 79 | +```julia |
| 80 | + b̄ = @thunk(foo.A' * ȳ) |
| 81 | +``` |
| 82 | +The tangent of `b` is `foo.A' * ȳ`, but we have wrapped it into a `Thunk`, a tangent type that represents delayed computation. |
| 83 | +The idea is that in case the tangent is not used anywhere, the computation never happens. |
| 84 | +Use [`InplaceableThunk`](@ref) if you are interested in [accumulating gradients inplace](@ref grad_acc). |
| 85 | +Note that in practice one would also `@thunk` the `f̄oo.A` tangent, but it was omitted in this example for clarity. |
| 86 | + |
| 87 | +As a final note, Since `b` is an `AbstractArray`, its tangent `b̄` should be projected to the right subspace. |
| 88 | +See the [`ProjectTo` the primal subspace](@ref projectto) section for more information and an example that motivates the projection operation. |
0 commit comments