|
| 1 | +# Mutation Support |
| 2 | + |
| 3 | +ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD. |
| 4 | +(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface.) |
| 5 | + |
| 6 | +!!! warning "Experimental" |
| 7 | + This page documents an experimental feature. |
| 8 | + Expect breaking changes in minor versions while this remains. |
| 9 | + It is not suitable for general use unless you are prepared to change how you use it with each minor release. |
| 10 | + If you do use it, it is therefore suggested that you put _tilde_ bounds on the supported minor versions. |
| 11 | + |
| 12 | + |
| 13 | +## `MutableTangent` |
| 14 | +The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place. |
| 15 | +It is required to be a structural tangent, having one tangent for each field of the primal object. |
| 16 | + |
| 17 | +Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents, |
| 18 | +just as not all `struct`s need to use `Tangent`. |
| 19 | +Common exceptions are types with natural tangent types, such as arrays. |
| 20 | +However, if you are setting up a custom tangent type for a `mutable struct`, you are sufficiently far off the beaten path that we cannot provide much guidance. |
| 21 | + |
| 22 | +## `zero_tangent` |
| 23 | + |
| 24 | +The [`zero_tangent`](@ref) function gives you a zero (i.e. additive identity) for any primal value. |
| 25 | +The [`ZeroTangent`](@ref) type also does this. |
| 26 | +The difference is that [`zero_tangent`](@ref) is (where possible) a full structural tangent mirroring the structure of the primal. |
| 27 | +For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes. |
| 28 | +To support this, make sure that wherever you create zeros you use [`zero_tangent`](@ref) rather than [`ZeroTangent`](@ref). |
| 29 | + |
| 30 | +It is also useful for type stability, since it is always a structural tangent. |
| 31 | +For this reason AD system implementors might choose to use it to create the tangent for all literal values they encounter, mutable or not. |
| 32 | + |
| 33 | +## Writing a frule for a mutating function |
| 34 | +It is relatively straightforward to write a frule for a mutating function. |
| 35 | +There are a few key points to follow: |
| 36 | + - There must be a mutable tangent input for every mutated primal input. |
| 37 | + - When the primal value is changed, the corresponding change must be made to its tangent partner. |
| 38 | + - When a value is returned, its partnered tangent must be returned with it. |
| 39 | + |
| 40 | + |
| 41 | +### Example |
| 42 | +For example, consider a primal function that: |
| 43 | +1. takes two `Ref`s |
| 44 | +2. doubles the first one in place |
| 45 | +3. overwrites the second one's value with the literal `5.0` |
| 46 | +4. returns the first one |
| 47 | + |
| 48 | + |
| 49 | +```julia |
| 50 | +function foo!(a::Base.RefValue, b::Base.RefValue) |
| 51 | + a[] *= 2 |
| 52 | + b[] = 5.0 |
| 53 | + return a |
| 54 | +end |
| 55 | +``` |
| 56 | + |
| 57 | +The frule for this would be: |
| 58 | +```julia |
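| 59 | +# The first entry of the tangent tuple is the tangent of `foo!` itself, which is unused here. |
| 59 | +function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue) |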
| 60 | + @assert ȧ isa MutableTangent{typeof(a)} |
| 61 | + @assert ḃ isa MutableTangent{typeof(b)} |
| 62 | + |
| 63 | + a[] *= 2 |
| 64 | + ȧ.x *= 2 # `.x` is the field that lives behind RefValues |
| 65 | + |
| 66 | + b[] = 5.0 |
| 67 | + ḃ.x = zero_tangent(5.0) # or, since the zero for a Float64 is known to be 0.0, simply `ḃ.x = 0.0` |
| 68 | + |
| 69 | + return a, ȧ |
| 70 | +end |
| 71 | +``` |
| 72 | + |
| 73 | +Then, assuming the AD system does its part to make sure you are indeed given mutable values to mutate (i.e. those `@assert`s hold), all is well and this rule will handle the mutation correctly. |
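
To see how an AD system might drive such a rule, here is a minimal sketch that calls the frule by hand. It restates `foo!` and its rule so that it is self-contained, and it follows the usual ChainRulesCore convention that the first entry of the tangent tuple belongs to the function itself. It assumes, as this page describes, that `zero_tangent` applied to a `Ref` yields a `MutableTangent` whose `.x` field can be assigned; because the feature is experimental, treat this as an illustration and not as a guaranteed API. The seed value `1.0` is an arbitrary choice made for the example.

```julia
using ChainRulesCore

function foo!(a::Base.RefValue, b::Base.RefValue)
    a[] *= 2
    b[] = 5.0
    return a
end

function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue)
    @assert ȧ isa MutableTangent{typeof(a)}
    @assert ḃ isa MutableTangent{typeof(b)}

    a[] *= 2
    ȧ.x *= 2                   # the tangent follows the same doubling as the primal

    b[] = 5.0
    ḃ.x = zero_tangent(5.0)    # a literal carries no derivative information

    return a, ȧ
end

a, b = Ref(3.0), Ref(1.0)

# Assumption: zero_tangent on a Ref gives a mutable structural tangent.
ȧ = zero_tangent(a)
ḃ = zero_tangent(b)
ȧ.x = 1.0    # arbitrary seed: perturb `a` in the unit direction
ḃ.x = 1.0    # arbitrary seed: this will be overwritten by the rule

y, ẏ = frule((NoTangent(), ȧ, ḃ), foo!, a, b)

@show a[] ȧ.x b[] ḃ.x
```

If the experimental API behaves as described, the primal `a` is doubled and its tangent is doubled with it, while the tangent of `b` is reset to zero because `b` was overwritten by a constant. The returned pair `y, ẏ` is the very same `a` and `ȧ` that were passed in, which is what keeps primal and tangent partnered after the mutation.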