Skip to content

Commit 2208660

Browse files
oxinaboxwilltebbuttmzgubic
authored
FAQ: What types does my pullback need to accept? (#428)
* FAQ: What types does my pullback need to accept? * Update docs/src/FAQ.md * move pullback types into writing good rules * mention that natural tangent does not have a formal definition Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> * Update docs/src/writing_good_rules.md * Apply suggestions from code review Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> * Update docs/src/writing_good_rules.md * Update docs/src/writing_good_rules.md Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
1 parent 09c133b commit 2208660

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

docs/src/FAQ.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,4 @@ This is morally the same as similar issues [discussed in ColPrac](https://github
111111

112112
On a practical level, it's important that this is the case because thunks are a bit of a hack,
113113
and over time it is hoped that the need for them will reduce, as they increase
114-
code-complexity and place additional stress on the compiler.
114+
code-complexity and place additional stress on the compiler.

docs/src/writing_good_rules.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,71 @@ end
180180
```
181181
to define the rules.
182182

183+
## Ensure your pullback can accept the right types
184+
As a rule the number of types you need to accept in a pullback is theoretically unlimitted, but practically highly constrained to be in line with the primal return type.
185+
The three kinds of inputs you will practically need to accept one or more of: _natural tangents_, _structural tangents_, and _thunks_.
186+
You do not in general have to handle `AbstractZero`s as the AD system will not call the pullback if the input is a zero, since the output will also be.
187+
Some more background information on these types can be found in [the design notes](@ref manytypes).
188+
In many cases all these tangents can be treated the same: tangent types overload a bunch of linear-operators, and the majority of functions used inside a pullback are linear operators.
189+
If you find linear operators from Base/stdlibs that are not supported, consider opening an issue or a PR on the [ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/).
190+
191+
### Natural tangents
192+
Natural tangent types are the types you might feel the tangent should be, to represent a small change in the primal value.
193+
For example, if the primal is a `Float32`, the natural tangent is also a `Float32`.
194+
Slightly more complex, for a `ComplexF64` the natural tangent is again also a `ComplexF64`, we almost never want to use the structural tangent `Tangent{ComplexF64}(re=..., im=...)` which is defined.
195+
For other cases, this gets a little more complicated, see below.
196+
These are a purely human notion, they are the types the user wants to use because they make the math easy.
197+
There is currently no formal definition of what constitutes a natural tangent, but there are a few heuristics.
198+
For example, if a primal type `P` overloads subtraction (`-(::P,::P)`) then that generally returns a natural tangent type for `P`; but this is not required to be defined and sometimes it is defined poorly.
199+
200+
Common cases for types that represent a [vector-space](https://en.wikipedia.org/wiki/Vector_space) (e.g. `Float64`, `Array{Float64}`) is that the natural tangent type is the same as the primal type.
201+
However, this is not always the case.
202+
For example for a [`PDiagMat`](https://github.com/JuliaStats/PDMats.jl) a natural tangent is `Diagonal` since there is no requirement that a positive definite diagonal matrix has a positive definite tangent.
203+
Another example is for a `DateTime`, any `Period` subtype, such as `Millisecond` or `Nanosecond` is a natural differential.
204+
There are often many different natural tangent types for a given primal type.
205+
However, they are generally closely related and duck-type the same.
206+
For example, for most `AbstractArray` subtypes, most other `AbstractArray`s (of right size and element type) can be considered as natural tangent types.
207+
208+
Not all types have natural tangent types.
209+
For example there is no natural differential for a `Tuple`.
210+
It is not a `Tuple` since that doesn't have any method for `+`.
211+
Similar is true for many `struct`s.
212+
For those cases there is only a structural differential.
213+
214+
### Structural tangents
215+
216+
Structural tangents are tangent types that shadow the structure of the primal type.
217+
They are represented by the [`Tangent`](@ref) type.
218+
They can represent any composite type, such as a tuple, or a structure (or a `NamedTuple`) etc.
219+
220+
221+
!!! info "Do I have to support the structural tangents as well?"
222+
Technically, you might not actually have to write rules to accept structural tangents; if the AD system never has to decompose down to the level of `getfield`.
223+
This is common for types that don't support user `getfield`/`getproperty` access, and that have a lot of rules for the ways they are accessed (such cases include some `AbstractArray` subtypes).
224+
You really should support it just in case; especially if the primal type in question is not restricted to a well-tested concrete type.
225+
But if it is causing struggles, then you can leave it off til someone complains.
226+
227+
### Thunks
228+
229+
A thunk (either a [`Thunk`](@ref), or a [`InplaceableThunk`](@ref)), represents a delayed computation.
230+
They can be thought of as a wrapper of the value the computation returns.
231+
In this sense they wrap either a natural or structural tangent.
232+
233+
!!! warning "You should to support AbstractThunk inputs even if you don't use thunks"
234+
Unfortunately the AD sytems do not know which rules support thunks and which do not.
235+
So all rules have to; at least if they want to play nice with arbitary AD systems.
236+
Luckily it is not hard: much of the time they will duck-type as the object they wrap.
237+
If not, then just add a [`unthunk`](@ref) after the start of your pullback.
238+
(Even when they do duck-type, if they are used multiple times then unthunking at the start will prevent them from being recomputed.)
239+
If you are using [`@thunk`](@ref) and the input is only needed for one of them then the `unthunk` should be in that one.
240+
If not, and you have a bunch of pullbacks you might like to write a little helper `unthunking(f) = x̄ -> f(unthunk(x̄))` that you can wrap your pullback function in before returning it from the `rrule`.
241+
Yes, this is a bit of boiler-plate, and it is unfortunate.
242+
Sadly, it is needed because if the AD wants to benefit it can't get that benifit unless things are not unthunked unnecessarily.
243+
Which eventually allows them in some cases to never be unthunked at all.
244+
There are two ways common things are never unthunked.
245+
One is if the unthunking happens inside a `@thunk` which is never unthunked itself because it is the tangent for a primal input that never has it's tangent queried.
246+
The second is if they are not unthunked because the rule does not need to know what is inside: consider the pullback for `identity`: `x̄ -> (NoTangent(), x̄)`.
247+
183248
## Use `@not_implemented` appropriately
184249

185250
One can use [`@not_implemented`](@ref) to mark missing differentials.

0 commit comments

Comments
 (0)