|
180 | 180 | ```
|
181 | 181 | to define the rules.
|
182 | 182 |
|
| 183 | +## Ensure your pullback can accept the right types |
| 184 | +As a rule the number of types you need to accept in a pullback is theoretically unlimitted, but practically highly constrained to be in line with the primal return type. |
| 185 | +The three kinds of inputs you will practically need to accept one or more of: _natural tangents_, _structural tangents_, and _thunks_. |
| 186 | +You do not in general have to handle `AbstractZero`s as the AD system will not call the pullback if the input is a zero, since the output will also be. |
| 187 | +Some more background information on these types can be found in [the design notes](@ref manytypes). |
| 188 | +In many cases all these tangents can be treated the same: tangent types overload a bunch of linear-operators, and the majority of functions used inside a pullback are linear operators. |
| 189 | +If you find linear operators from Base/stdlibs that are not supported, consider opening an issue or a PR on the [ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/). |
| 190 | + |
| 191 | +### Natural tangents |
| 192 | +Natural tangent types are the types you might feel the tangent should be, to represent a small change in the primal value. |
| 193 | +For example, if the primal is a `Float32`, the natural tangent is also a `Float32`. |
| 194 | +Slightly more complex, for a `ComplexF64` the natural tangent is again also a `ComplexF64`, we almost never want to use the structural tangent `Tangent{ComplexF64}(re=..., im=...)` which is defined. |
| 195 | +For other cases, this gets a little more complicated, see below. |
| 196 | +These are a purely human notion, they are the types the user wants to use because they make the math easy. |
| 197 | +There is currently no formal definition of what constitutes a natural tangent, but there are a few heuristics. |
| 198 | +For example, if a primal type `P` overloads subtraction (`-(::P,::P)`) then that generally returns a natural tangent type for `P`; but this is not required to be defined and sometimes it is defined poorly. |
| 199 | + |
| 200 | +Common cases for types that represent a [vector-space](https://en.wikipedia.org/wiki/Vector_space) (e.g. `Float64`, `Array{Float64}`) is that the natural tangent type is the same as the primal type. |
| 201 | +However, this is not always the case. |
| 202 | +For example for a [`PDiagMat`](https://github.com/JuliaStats/PDMats.jl) a natural tangent is `Diagonal` since there is no requirement that a positive definite diagonal matrix has a positive definite tangent. |
| 203 | +Another example is for a `DateTime`, any `Period` subtype, such as `Millisecond` or `Nanosecond` is a natural differential. |
| 204 | +There are often many different natural tangent types for a given primal type. |
| 205 | +However, they are generally closely related and duck-type the same. |
| 206 | +For example, for most `AbstractArray` subtypes, most other `AbstractArray`s (of right size and element type) can be considered as natural tangent types. |
| 207 | + |
| 208 | +Not all types have natural tangent types. |
| 209 | +For example there is no natural differential for a `Tuple`. |
| 210 | +It is not a `Tuple` since that doesn't have any method for `+`. |
| 211 | +Similar is true for many `struct`s. |
| 212 | +For those cases there is only a structural differential. |
| 213 | + |
| 214 | +### Structural tangents |
| 215 | + |
| 216 | +Structural tangents are tangent types that shadow the structure of the primal type. |
| 217 | +They are represented by the [`Tangent`](@ref) type. |
| 218 | +They can represent any composite type, such as a tuple, or a structure (or a `NamedTuple`) etc. |
| 219 | + |
| 220 | + |
| 221 | +!!! info "Do I have to support the structural tangents as well?" |
| 222 | + Technically, you might not actually have to write rules to accept structural tangents; if the AD system never has to decompose down to the level of `getfield`. |
| 223 | + This is common for types that don't support user `getfield`/`getproperty` access, and that have a lot of rules for the ways they are accessed (such cases include some `AbstractArray` subtypes). |
| 224 | + You really should support it just in case; especially if the primal type in question is not restricted to a well-tested concrete type. |
| 225 | + But if it is causing struggles, then you can leave it off til someone complains. |
| 226 | + |
| 227 | +### Thunks |
| 228 | + |
| 229 | +A thunk (either a [`Thunk`](@ref), or a [`InplaceableThunk`](@ref)), represents a delayed computation. |
| 230 | +They can be thought of as a wrapper of the value the computation returns. |
| 231 | +In this sense they wrap either a natural or structural tangent. |
| 232 | + |
| 233 | +!!! warning "You should to support AbstractThunk inputs even if you don't use thunks" |
| 234 | + Unfortunately the AD sytems do not know which rules support thunks and which do not. |
| 235 | + So all rules have to; at least if they want to play nice with arbitary AD systems. |
| 236 | + Luckily it is not hard: much of the time they will duck-type as the object they wrap. |
| 237 | + If not, then just add a [`unthunk`](@ref) after the start of your pullback. |
| 238 | + (Even when they do duck-type, if they are used multiple times then unthunking at the start will prevent them from being recomputed.) |
| 239 | + If you are using [`@thunk`](@ref) and the input is only needed for one of them then the `unthunk` should be in that one. |
| 240 | + If not, and you have a bunch of pullbacks you might like to write a little helper `unthunking(f) = x̄ -> f(unthunk(x̄))` that you can wrap your pullback function in before returning it from the `rrule`. |
| 241 | + Yes, this is a bit of boiler-plate, and it is unfortunate. |
| 242 | + Sadly, it is needed because if the AD wants to benefit it can't get that benifit unless things are not unthunked unnecessarily. |
| 243 | + Which eventually allows them in some cases to never be unthunked at all. |
| 244 | + There are two ways common things are never unthunked. |
| 245 | + One is if the unthunking happens inside a `@thunk` which is never unthunked itself because it is the tangent for a primal input that never has it's tangent queried. |
| 246 | + The second is if they are not unthunked because the rule does not need to know what is inside: consider the pullback for `identity`: `x̄ -> (NoTangent(), x̄)`. |
| 247 | + |
183 | 248 | ## Use `@not_implemented` appropriately
|
184 | 249 |
|
185 | 250 | One can use [`@not_implemented`](@ref) to mark missing differentials.
|
|
0 commit comments