Skip to content

Commit 2ec2549

Browse files
st--mzgubic
andauthored
Clarify "writing good rules" documentation (#468)
* Clarify order of @thunk and ProjectTo in docs * Apply suggestions from code review Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> * incorporate @oxinabox's draft * move paragraph to right (?) place * add comment on test_[fr]rule of @not_implemented differentials * typo fix * improve example Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
1 parent 084b5e7 commit 2ec2549

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

docs/src/writing_good_rules.md

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ julia> rrule(foo, 2)
2828
==#
2929
```
3030

31+
While this is more verbose, it ensures that if an error is thrown during the `pullback` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it.
32+
This makes it a lot simpler to debug from the stacktrace.
33+
3134
## Use `ZeroTangent()` as the return value
3235

3336
The `ZeroTangent()` object exists as an alternative to directly returning `0` or `zeros(n)`.
@@ -102,6 +105,18 @@ function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
102105
end
103106
```
104107

108+
!!! note "It is often good to `@thunk` your projections"
109+
The above example is potentially a good place for using a [`@thunk`](@ref).
110+
This is not required, but can in some cases be more computationally efficient, see [Use `Thunk`s appropriately](@ref).
111+
When combining thunks and projections, `@thunk()` must be the outermost call.
112+
113+
A more optimized implementation of the matrix-matrix multiplication example would have
114+
```julia
115+
times_pullback(ȳ) = NoTangent(), @thunk(project_A(ȳ * B')), @thunk(project_B(A' * ȳ))
116+
```
117+
within the `rrule`. This defers both the evaluation of the product rule and
118+
the projection until(/if) the tangent gets used.
119+
105120
## Structs: constructors and functors
106121

107122
To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
@@ -230,9 +245,9 @@ A thunk (either a [`Thunk`](@ref), or a [`InplaceableThunk`](@ref)), represents
230245
They can be thought of as a wrapper of the value the computation returns.
231246
In this sense they wrap either a natural or structural tangent.
232247

233-
!!! warning "You should to support AbstractThunk inputs even if you don't use thunks"
248+
!!! warning "You should support AbstractThunk inputs even if you don't use thunks"
234249
Unfortunately the AD sytems do not know which rules support thunks and which do not.
235-
So all rules have to; at least if they want to play nice with arbitary AD systems.
250+
So all rules have to; at least if they want to play nicely with arbitrary AD systems.
236251
Luckily it is not hard: much of the time they will duck-type as the object they wrap.
237252
If not, then just add a [`unthunk`](@ref) after the start of your pullback.
238253
(Even when they do duck-type, if they are used multiple times then unthunking at the start will prevent them from being recomputed.)
@@ -247,23 +262,22 @@ In this sense they wrap either a natural or structural tangent.
247262

248263
## Use `@not_implemented` appropriately
249264

250-
One can use [`@not_implemented`](@ref) to mark missing differentials.
251-
This is helpful if the function has multiple inputs or outputs, and one has worked out analytically and implemented some but not all differentials.
265+
You can use [`@not_implemented`](@ref) to mark missing differentials.
266+
This is helpful if the function has multiple inputs or outputs, and you have worked out analytically and implemented some but not all differentials.
252267

253268
It is recommended to include a link to a GitHub issue about the missing differential in the debugging information:
254269
```julia
255270
@not_implemented(
256-
"""
257-
derivatives of Bessel functions with respect to the order are not implemented:
258-
https://github.com/JuliaMath/SpecialFunctions.jl/issues/160
259-
"""
271+
"""
272+
derivatives of Bessel functions with respect to the order are not implemented:
273+
https://github.com/JuliaMath/SpecialFunctions.jl/issues/160
274+
"""
260275
)
261276
```
262277

263278
Do not use `@not_implemented` if the differential does not exist mathematically (use `NoTangent()` instead).
264279

265-
While this is more verbose, it ensures that if an error is thrown during the `pullback` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it.
266-
This makes it a lot simpler to debug from the stacktrace.
280+
Note: [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) marks `@not_implemented` differentials as "test broken".
267281

268282
## Use rule definition tools
269283

@@ -387,7 +401,7 @@ Take a look at the documentation or the existing [ChainRules.jl](https://github.
387401
388402
!!! warning
389403
Don't use analytical derivations for derivatives in the tests.
390-
Those are what you use to define the rules, and so can not be confidently used in the test.
404+
Those are what you use to define the rules, and so cannot be confidently used in the test.
391405
If you misread/misunderstood them, then your tests/implementation will have the same mistake.
392406
Use finite differencing methods instead, as they are based on the primal computation.
393407
@@ -401,10 +415,10 @@ In principle, a perfect AD system only needs rules for basic operations and can
401415
In practice, performance needs to be considered as well.
402416
403417
Some functions use `ccall` internally, for example [`^`](https://github.com/JuliaLang/julia/blob/v1.5.3/base/math.jl#L886).
404-
These functions can not be differentiated through by AD systems, and need custom rules.
418+
These functions cannot be differentiated through by AD systems, and need custom rules.
405419
406420
Other functions can in principle be differentiated through by an AD system, but there exists a mathematical insight that can dramatically improve the computation of the derivative.
407-
An example is numerical integration, where writing a rule removes the need to perform AD through numerical integration.
421+
An example is numerical integration, where writing a rule implementing the [fundamental theorem of calculus](https://en.wikipedia.org/wiki/Fundamental_theorem_of_calculus) removes the need to perform AD through numerical integration.
408422
409423
Furthermore, AD systems make different trade-offs in performance due to their design.
410424
This means that a certain rule will help one AD system, but not improve (and also not harm) another.
@@ -416,7 +430,7 @@ This may be resolved in the future by [allowing AD systems to opt-in or opt-out
416430
417431
### Patterns that need rules in [Zygote.jl](https://github.com/FluxML/Zygote.jl)
418432
419-
There are a few classes of functions that Zygote can not differentiate through.
433+
There are a few classes of functions that Zygote cannot differentiate through.
420434
Custom rules will need to be written for these to make AD work.
421435
422436
Other patterns can be AD'ed through, but the backward pass performance can be greatly improved by writing a rule.

0 commit comments

Comments
 (0)