Commit 8b3d525

Add docs for forward mutation support
1 parent ad9a5af commit 8b3d525

3 files changed

+75
-1
lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ makedocs(;
             "`@opt_out`" => "rule_author/superpowers/opt_out.md",
             "`RuleConfig`" => "rule_author/superpowers/ruleconfig.md",
             "Gradient accumulation" => "rule_author/superpowers/gradient_accumulation.md",
+            "Mutation Support (experimental)" => "rule_author/superpowers/mutation_support.md",
         ],
         "Converting ZygoteRules.@adjoint to rrules" => "rule_author/converting_zygoterules.md",
         "Tips for making your package work with AD" => "rule_author/tips_for_packages.md",

docs/src/api.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ Modules = [ChainRulesCore]
 Pages = [
     "tangent_types/abstract_zero.jl",
     "tangent_types/one.jl",
-    "tangent_types/tangent.jl",
+    "tangent_types/structural_tangent.jl",
     "tangent_types/thunks.jl",
     "tangent_types/abstract_tangent.jl",
     "tangent_types/notimplemented.jl",
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
# Mutation Support

ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD.
(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface.)

!!! warning "Experimental"
    This page documents an experimental feature.
    Expect breaking changes in minor versions while it remains experimental.
    It is not suitable for general use unless you are prepared to change how you use it with each minor release.
    If you do use it, we suggest putting _tilde_ bounds on the supported minor versions.

## `MutableTangent`
The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place.
It is required to be a structural tangent, having one tangent for each field of the primal object.

Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents,
just as not all `struct`s need to use `Tangent`.
Common exceptions are natural tangent types, such as those used for arrays.
However, if you are setting up a custom tangent type for a mutable struct, that is sufficiently far off the beaten path that we cannot provide much guidance.

## `zero_tangent`

The [`zero_tangent`](@ref) function gives you a zero (i.e. additive identity) for any primal value.
The [`ZeroTangent`](@ref) type also does this.
The difference is that [`zero_tangent`](@ref) returns (where possible) a full structural tangent mirroring the structure of the primal.
For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes.
To support this, you need to make sure that your zeros are created with [`zero_tangent`](@ref) rather than [`ZeroTangent`](@ref).

It is also useful for reasons of type stability, since it is always a structural tangent.
For this reason, AD system implementors might choose to use it to create the tangent for all literal values they encounter, mutable or not.

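The difference between the two kinds of zero can be seen in a minimal sketch like the one below. It assumes a ChainRulesCore version that includes the experimental `MutableTangent` and `zero_tangent`; the variable names and values are purely illustrative.

```julia
using ChainRulesCore: ChainRulesCore, MutableTangent, ZeroTangent, zero_tangent

r = Ref(1.0)  # an illustrative mutable primal

z = ZeroTangent()    # generic zero: carries no fields, so there is nothing to mutate
ṙ = zero_tangent(r)  # structural zero: mirrors the single field `x` of the RefValue

@assert ṙ isa MutableTangent{typeof(r)}
@assert ṙ.x == 0.0

# Because the tangent has its own mutable storage, it can follow changes to the primal.
r[] = 4.0
ṙ.x = 2.5
@assert ṙ.x == 2.5

# For a plain number the structural zero is just the numeric zero.
@assert zero_tangent(5.0) == 0.0
```

With `ZeroTangent()` there would be no field to assign to, which is why a rule that mutates its tangent needs the structural form.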
## Writing a frule for a mutating function
It is relatively straightforward to write a frule for a mutating function.
There are a few key points to follow:
 - There must be a mutable tangent input for every mutated primal input.
 - When the primal value is changed, the corresponding change must be made to its tangent partner.
 - When a value is returned, return its partnered tangent.

### Example
For example, consider a primal function that:
 1. takes two `Ref`s
 2. doubles the first one in place
 3. overwrites the second one's value with the literal `5.0`
 4. returns the first one

```julia
function foo!(a::Base.RefValue, b::Base.RefValue)
    a[] *= 2
    b[] = 5.0
    return a
end
```

The frule for this would be:
```julia
# The first tangent belongs to `foo!` itself and is unused, hence `_`.
function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue)
    @assert ȧ isa MutableTangent{typeof(a)}
    @assert ḃ isa MutableTangent{typeof(b)}

    a[] *= 2
    ȧ.x *= 2  # `.x` is the field that lives behind RefValues

    b[] = 5.0
    ḃ.x = zero_tangent(5.0)  # or, since we know the zero for a Float64 is `0.0`, we could write `ḃ.x = 0.0`

    return a, ȧ
end
```

Then, assuming the AD system does its part to make sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true), all is well and this rule will make mutation correct.
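
As a rough illustration of how an AD system (or a test) might call such a rule, here is a minimal self-contained sketch. It redefines `foo!` and its frule so it can run on its own, and it assumes a ChainRulesCore version that includes the experimental `MutableTangent` and `zero_tangent`. It also assumes the usual `frule` convention that the first tangent slot belongs to the function itself, passed here as `NoTangent()`. The input values and seed tangents are arbitrary.

```julia
using ChainRulesCore: ChainRulesCore, MutableTangent, NoTangent, frule, zero_tangent

function foo!(a::Base.RefValue, b::Base.RefValue)
    a[] *= 2
    b[] = 5.0
    return a
end

function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue)
    @assert ȧ isa MutableTangent{typeof(a)}
    @assert ḃ isa MutableTangent{typeof(b)}

    a[] *= 2
    ȧ.x *= 2                 # tangent follows the doubling of the primal

    b[] = 5.0
    ḃ.x = zero_tangent(5.0)  # a literal carries no sensitivity

    return a, ȧ
end

# Illustrative inputs: structural zeros give us mutable tangents to seed.
a, b = Ref(3.0), Ref(7.0)
ȧ, ḃ = zero_tangent(a), zero_tangent(b)
ȧ.x = 1.0
ḃ.x = 1.0

y, ẏ = frule((NoTangent(), ȧ, ḃ), foo!, a, b)
```

After the call, both the primals and their tangent partners have been updated in place: the returned tangent is the same object as `ȧ`, and the tangent of `b` has been reset because `b` was overwritten with a literal.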
