|
# Gradient Accumulation

Consider some function
$$f(x) = g(x) + h(x)$$.
If we would like the derivative of $f$ with respect to $x$, we must compute it for each part and then sum them, i.e.
$$\frac{\partial f}{\partial x} = \frac{\partial g}{\partial x} + \frac{\partial h}{\partial x}$$.
In general, we must accumulate (sum) gradients from each sub-part of a program where a variable is used.
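For instance, choosing (purely for illustration) $g(x) = x^2$ and $h(x) = \sin(x)$ gives
$$\frac{\partial f}{\partial x} = 2x + \cos(x)$$.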

Consider for example:
```julia
function sum_first_and_second(X::Array{Float64})
    a = X[1]
    b = X[2]
    y = a + b
    return y
end
```
The AD software must transform that into something which repeatedly sums up the gradient of each part:
`X̄ = ā + b̄`.

This requires that all differential types `D` implement `+`: `+(::D, ::D)::D`.
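As a rough sketch of what that accumulation looks like, a hand-written reverse pass for `sum_first_and_second` might read as follows. The name `sum_first_and_second_pullback` is made up for illustration; this is not code any particular AD system generates.

```julia
# Hypothetical hand-written reverse pass for `sum_first_and_second`.
# `ȳ` is the sensitivity of the output `y`.
function sum_first_and_second_pullback(X::Array{Float64}, ȳ::Float64)
    ā = zeros(size(X)); ā[1] = ȳ   # partial flowing back through `a = X[1]`
    b̄ = zeros(size(X)); b̄[2] = ȳ   # partial flowing back through `b = X[2]`
    X̄ = ā + b̄                      # accumulate, because `X` is used in two places
    return X̄
end
```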

Note that in this particular case `ā` and `b̄` will both be arrays.
The operation `X̄ = ā + b̄` will allocate one array to hold `ā`, another to hold `b̄`, and a third to hold `ā + b̄`.
That is three allocations.
Allocations are not free: they increase the time the program takes to run by a nontrivial amount, even with a good allocator and a good garbage collector.
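One way to see this cost is to measure the out-of-place sum directly; this is a rough sketch, and exact byte counts depend on the array size and Julia version.

```julia
# Sketch: the out-of-place accumulation allocates a fresh array for the result.
ā = randn(1_000)             # stands in for computing `ā` (first allocation)
b̄ = randn(1_000)             # stands in for computing `b̄` (second allocation)
bytes = @allocated(ā + b̄)    # third allocation: a new 1000-element array, roughly 8 KiB
```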

### Maybe-mutating accumulation (`add!!`)
Note that in the above neither `ā` nor `b̄` is ever used again after accumulating to get `X̄`.
Furthermore, `Array`s are mutable.
That means we could overwrite either `ā` or `b̄` and use the result as `X̄`:

```julia
ā .+= b̄
X̄ = ā
```

This cuts our allocations down to 2: just `ā` and `b̄`.

However, not all types are mutable, so this pattern is hard to apply in general.
To deal with that, ChainRulesCore provides [`add!!`](@ref).
Per the [BangBang.jl](https://github.com/JuliaFolds/BangBang.jl) convention, this is a maybe-mutating addition.
It may mutate its first argument (if it is mutable), but it will definitely return the correct result.
Using it, we would write `X̄ = add!!(ā, b̄)`, which in this case gives us just 2 allocations.
AD systems can generate `add!!` instead of `+` when accumulating gradients to take advantage of this.
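A quick sketch of that behaviour, assuming a plain `Vector{Float64}` (which `add!!` is expected to update in place; treat the exact mutation behaviour as an implementation detail):

```julia
using ChainRulesCore

ā = [1.0, 2.0]
b̄ = [10.0, 20.0]

X̄ = add!!(ā, b̄)       # Arrays are mutable: `ā` can be updated in place and returned
X̄ === ā               # expected to be true here: no third array needed

add!!(1.5, 2.5)       # Floats are immutable: behaves like ordinary `+`, returns 4.0
```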

### Inplaceable Thunks (`InplaceableThunk`) avoid allocating values in the first place
We got down to two allocations by using [`add!!`](@ref), but can we do better?
We can think of having a differential type which acts on a partially accumulated result, mutating it to contain its current value plus the partial derivative being accumulated.
Rather than holding an actual computed value, it just holds a thing that will act on a value to perform the addition.
Let's illustrate this with our example.

`b̄` is the partial for `X[2]`, and its value can be computed by:

```julia
b̄ = zeros(size(X))
b̄[2] = ȳ  # the scalar sensitivity of the `sum_first_and_second` output
```

`b̄` is an array entirely of zeros, except at index `2`, where it is set to the output sensitivity `ȳ`.
`ā` is similar, except with the non-zero at index `1`.

What is the action of `b̄` upon `ā` that gives the same result as `X̄ = add!!(ā, b̄)` (or `X̄ = ā + b̄` for that matter)?
It is:

```julia
function b̄_add!(ā)
    ā[2] += ȳ
    return ā
end
```

We don't need to worry about all those zeros, since `x + 0 == x`.

[`InplaceableThunk`](@ref) is the type we use to represent derivatives as gradient-accumulating actions.
Note that to do this we still need a value form of `ā` for `b̄` to act upon.
For this reason every inplaceable thunk has both a `val` field holding the value representation, and an `add!` field holding the action representation.
The `val` field uses a plain [`Thunk`](@ref) to avoid the computation (and thus the allocation) if it is never used.
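As a sketch, `b̄` from the example could be packaged up as below. This is illustrative only: it assumes the two-argument constructor takes the value `Thunk` first and the `add!` function second, matching the field order described above; check the [`InplaceableThunk`](@ref) docstring for the exact constructor in your version.

```julia
using ChainRulesCore

X = [1.0, 2.0, 3.0]
ȳ = 1.0   # output sensitivity

b̄ = InplaceableThunk(
    @thunk((v = zeros(size(X)); v[2] = ȳ; v)),  # value form: only computed if someone asks for it
    acc -> (acc[2] += ȳ; acc),                  # action form: accumulate into `acc` in place
)

ā = [10.0, 20.0, 30.0]   # some partially accumulated gradient
unthunk(b̄)               # materialises the value form: [0.0, 1.0, 0.0]
add!!(ā, b̄)              # invokes the action form: `ā` becomes [10.0, 21.0, 30.0]
```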

!!! note "Do we need both representations?"
    Right now every [`InplaceableThunk`](@ref) has two fields that need to be specified:
    the value form (represented as the [`Thunk`](@ref)-typed `val` field), and the action form (represented as the `add!` field).
    It is possible that in a future version of ChainRulesCore.jl we will work out a clever way to find the zero differential for arbitrary primal values.
    Given that, we could always just determine the value form from `inplaceable.add!(zero_differential(primal))`.
    There are some technical difficulties in finding the zero differentials, but this may be solved at some point.


The `+` operation on `InplaceableThunk`s is overloaded to [`unthunk`](@ref) the `val` field to get the value form.
Whereas the [`add!!`](@ref) operation is overloaded to call the `add!` field to invoke the action.

With the rule for `getindex` defined to return an `InplaceableThunk`, `X̄ = add!!(ā, b̄)` now requires only a single allocation.
That allocation occurs when `unthunk`ing `ā`, which is then mutated to become `X̄`.
This is basically as good as we can get: if we want `X̄` to be an `Array`, then at some point we need to allocate that array.

!!! note "Can we do more? Deferred accumulation"
    We could go even further and drop more allocations if we really wanted.
    If we didn't care about `X̄` being an `Array`, then we could defer its computation too:
    `X̄ = @thunk add!!(ā, b̄)`.
    This kind of deferral will work fine and you can keep chaining it.
    It does start to burn stack space, and might make the compiler's optimization passes cry.
    But it's valid and should work fine.


### Examples of InplaceableThunks

#### `getindex`

The aforementioned `getindex` is really the poster child for this.
Consider something like:
```julia
function mysum(X::Array{Float64})
    total = 0.0
    for i in eachindex(X)
        total += X[i]
    end
    return total
end
```
If one only has the value representation of derivatives, one ends up having to allocate a derivative array for every single element of the original array `X`.
That's terrible.
On the other hand, with the action representation that `InplaceableThunk`s provide, there is just a single `Array` allocated.
One can see [the `getindex` rule in ChainRules.jl for the implementation](https://github.com/JuliaDiff/ChainRules.jl/blob/v0.7.49/src/rulesets/Base/indexing.jl).
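To make that concrete, here is a rough sketch of accumulating the per-element partials of `mysum` with only one array allocated for the result. This is not the actual ChainRules.jl rule: the helper `onehot` is made up for illustration, and the value-thunk-first constructor order is assumed as above.

```julia
using ChainRulesCore

X = rand(100)
ȳ = 1.0   # sensitivity of `total` with respect to itself

# Value form of the partial for `X[i]`: a dense array that allocates if it is ever used.
onehot(i) = (v = zeros(size(X)); v[i] = ȳ; v)

partials = [
    InplaceableThunk(@thunk(onehot(i)), acc -> (acc[i] += ȳ; acc))
    for i in eachindex(X)
]

# One allocation for the accumulator; every partial is then applied in place.
X̄ = foldl(add!!, partials; init = zeros(size(X)))
X̄ == fill(ȳ, size(X))   # true: d(total)/dX[i] == 1 for every element
```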


#### matmul etc. (`*`)
Multiplication of scalars/vectors/matrices of compatible dimensions can also have its derivatives represented as `InplaceableThunk`s.
These tend to pivot around the `add!` action being defined along the lines of
`X̄ -> mul!(X̄, A', Ȳ, true, true)`,
where the 5-argument `mul!` is the in-place [multiply-add operation](https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.mul!).
`mul!(X̄, A', Ȳ, true, true)` has the same effect as `X̄ .+= A'*Ȳ`, but avoids allocating the matrix `A'*Ȳ`.
This is one of the fundamental operations provided by BLAS -- including the application of the conjugate transpose.
e.g. the matrix-matrix form is [`GEMM` (GEneralized Matrix-Matrix Multiplication)](http://www.netlib.org/lapack/explore-html/d1/d54/group__double__blas__level3_gaeda3cbd99c8fb834a60a6412878226e1.html#gaeda3cbd99c8fb834a60a6412878226e1),
the matrix-vector form is [`GEMV` (GEneralized Matrix-Vector Multiplication)](http://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_gadd421a107a488d524859b4a64c1901a9.html#gadd421a107a488d524859b4a64c1901a9), etc.
Under the hood, doing it out of place is going to call one of these methods anyway, but on a freshly allocated output array.
So we are going to hit a very efficient implementation and get the addition for free.


One can see [the `*` rules in ChainRules.jl for the implementations](https://github.com/JuliaDiff/ChainRules.jl/blob/v0.7.49/src/rulesets/Base/arraymath.jl#L22-L95).
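As an illustrative sketch of what such a pullback boils down to for `Y = A * B` (this is not the actual ChainRules.jl code, the name `times_pullback_sketch` is made up, and the value-thunk-first constructor order is assumed as before):

```julia
using ChainRulesCore, LinearAlgebra

# Sketch of the reverse pass for `Y = A * B`, with both partials as InplaceableThunks.
function times_pullback_sketch(A::AbstractMatrix, B::AbstractMatrix, Ȳ::AbstractMatrix)
    Ā = InplaceableThunk(
        @thunk(Ȳ * B'),                     # value form: allocates the product Ȳ*B'
        X̄ -> mul!(X̄, Ȳ, B', true, true),    # action form: X̄ .+= Ȳ*B' with no temporary
    )
    B̄ = InplaceableThunk(
        @thunk(A' * Ȳ),                     # value form: allocates the product A'*Ȳ
        X̄ -> mul!(X̄, A', Ȳ, true, true),    # action form: X̄ .+= A'*Ȳ with no temporary
    )
    return Ā, B̄
end
```

The deferred value forms are only materialised if something calls `unthunk` (or `+`) on them; an AD system accumulating with `add!!` hits the `mul!` path instead and never allocates the intermediate products.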