Skip to content

Commit f054f16

Browse files
oxinaboxmzgubicniklasschmitzmattBrzezinski
authored
Document about Gradient Accumulation (#287)
* Document about Gradient Accumulation * fix filename * lock the version of julia we build docs on * heading level * Apply suggestions from code review Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> * Update docs/src/gradient_accumulation.md Co-authored-by: Niklas Schmitz <niklas.f.schmitz@gmail.com> * Update docs/src/gradient_accumulation.md * Apply suggestions from code review Co-authored-by: mattBrzezinski <matt.brzezinski@invenia.ca> Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com> Co-authored-by: Niklas Schmitz <niklas.f.schmitz@gmail.com> Co-authored-by: mattBrzezinski <matt.brzezinski@invenia.ca>
1 parent f1fbb4c commit f054f16

File tree

4 files changed

+137
-2
lines changed

4 files changed

+137
-2
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ jobs:
8282
- uses: actions/checkout@v2
8383
- uses: julia-actions/setup-julia@v1
8484
with:
85-
version: '1'
85+
version: '1.5'
8686
- run: |
8787
julia --project=docs -e '
8888
using Pkg

docs/Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ version = "0.5.10"
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.9.26"
16+
version = "0.9.27"
1717

1818
[[Compat]]
1919
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]

docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ makedocs(
3737
"Complex Numbers" => "complex.md",
3838
"Deriving Array Rules" => "arrays.md",
3939
"Debug Mode" => "debug_mode.md",
40+
"Gradient Accumulation" => "gradient_accumulation.md",
4041
"Usage in AD" => [
4142
"Overview" => "autodiff/overview.md",
4243
"Operator Overloading" => "autodiff/operator_overloading.md"

docs/src/gradient_accumulation.md

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Gradient Accumulation
2+
3+
Consider some function
4+
$$f(x) = g(x) + h(x)$$.
5+
If we would like the derivative of $f$ with respect to $x$ we must compute it for each part and then sum them, i.e.
6+
$$\frac{\partial f}{\partial x} = \frac{\partial g}{\partial x} + \frac{\partial h}{\partial x}$$.
7+
In general, we must accumulate (sum) gradients from each sub-part of a program where a variable is used.
8+
9+
10+
Consider for example:
11+
```julia
12+
function sum_first_and_second(X::Array{Float64})
13+
a = X[1]
14+
b = X[2]
15+
y = a + b
16+
return y
17+
end
18+
```
19+
The AD software must transform that into something which repeatedly sums up the gradient of each part:
20+
`X̄ = ā + b̄`.
21+
22+
This requires that all differential types `D` must implement `+`: `+(::D, ::D)::D`.
23+
24+
We can note that in this particular case `` and `` will both be arrays.
25+
This operation (`X̄ = ā + b̄`) will allocate one array to hold `ā`, another one to hold ``, and a third one to hold `ā + b̄`.
26+
This is three allocations.
27+
Allocations are not free, they increase the time the program takes to run by a nontrivial amount, even with a good allocator and a good garbage collector.
28+
29+
### Maybe-mutating accumulation (`add!!`)
30+
We can note that in the above that neither `` nor `` are ever used again after accumulating to get ``.
31+
Furthermore, `Array`s are mutable.
32+
That means we could over-write either `` or `` and use the result as ``:
33+
34+
```julia
35+
.+=
36+
=
37+
```
38+
39+
This cuts our allocations down to 2, just `` and ``.
40+
41+
However, we have a bit of a problem that not all types are mutable, so this pattern is hard to apply in general.
42+
To deal with that ChainRulesCore provides [`add!!`](@ref).
43+
Per the [BangBang.jl](https://github.com/JuliaFolds/BangBang.jl) convention, this is a maybe mutating addition.
44+
It may mutate its first argument (if it is mutable), but it will definitely return the correct result.
45+
We would write using that as `X̄ = add!!(ā, b̄)`: which would in this case give us just 2 allocations.
46+
AD systems can generate `add!!` instead of `+` when accumulating gradient to take advantage of this.
47+
48+
### Inplaceable Thunks (`InplaceableThunks`) avoid allocating values in the first place.
49+
We got down to two allocations from using [`add!!`](@ref), but can we do better?
50+
We can think of having a differential type which acts on a partially accumulated result, to mutate it to contain its current value plus the partial derivative being accumulated.
51+
Rather than having an actual computed value, we can just have a thing that will act on a value to perform the addition.
52+
Let's illustrate it with our example.
53+
54+
`` is the partial for `X[2]` and its value can be computed by:
55+
56+
```julia
57+
= zeros(size(X))
58+
b̄[2] =# the scalar sensitivity of the `mysum` output
59+
`` is a matrix entirely of zeros, except for at the index `2`, where it is set to the output sensitivity ``.
60+
`` is similar, except with the non-zero at index `1`.
61+
62+
What is the action of `` upon ``, to get the same result as `X̄ = add!!(ā, b̄)` (or `X̄ = ā + b̄` for that matter)?
63+
It is:
64+
65+
```julia
66+
function b̄_add!(ā)
67+
ā[2] += ȳ
68+
return ā
69+
end
70+
We don't need to worry about all those zeros since `x + 0 == x`.
71+
72+
[`InplaceableThunk`](@ref) is the type we have to represent derivatives as gradient accumulating actions.
73+
We must note that to do this we do need a value form of `ā` for `b̄` to act upon.
74+
For this reason every inplaceable thunk has both a `val` field holding the value representation, and a `add!` field holding the action representation.
75+
The `val` field use a plain [`Thunk`](@ref) to avoid the computation (and thus allocation) if it is unused.
76+
77+
!!! note "Do we need both representations?"
78+
Right now every [`InplaceableThunk`](@ref) has two fields that need to be specified.
79+
The value form (represented as a the [`Thunk`](@ref) typed field), and the action form (represented as the `add!` field).
80+
It is possible in a future version of ChainRulesCore.jl we will work out a clever way to find the zero differential for arbitrary primal values.
81+
Given that, we could always just determine the value form from `inplaceable.add!(zero_differential(primal))`.
82+
There are some technical difficulties in finding the zero differentials, but this may be solved at some point.
83+
84+
85+
The `+` operation on `InplaceableThunk`s is overloaded to [`unthunk`](@ref) that `val` field to get the value form.
86+
Where as the [`add!!`](@ref) operation is overloaded to call `add!` to invoke the action.
87+
88+
With `getindex` defined to return an `InplaceableThunk`, we now get to `X̄ = add!!(ā, b̄)` requires only a single allocation.
89+
This allocation occurs when `unthunk`ing `ā`, which is then mutated to become `X̄`.
90+
This is basically as good as we can get: if we want `X̄` to be an `Array` then at some point we need to allocate that array.
91+
92+
!!! note "Can we do more? Deferred accumulation"
93+
We could keep going further to drop allocations if we really wanted.
94+
If we didn't care about `X̄` being an `Array` then we could defer its computation too.
95+
`X̄ = @thunk add!!(ā, b̄)`.
96+
This kind of deferral will work fine and you can keep chaining it.
97+
It does start to burn stack space, and might make the compiler's optimization passes cry.
98+
But it's valid and should work fine.
99+
100+
### Examples of InplaceableThunks
101+
102+
#### `getindex`
103+
104+
The aforementioned `getindex` is really the poster child for this.
105+
Consider something like:
106+
```julia
107+
function mysum(X::Array{Float64})
108+
total = 0.0
109+
for i in eachindex(X)
110+
total += X[i]
111+
end
112+
return total
113+
end
114+
```
115+
If one only has value representation of derivatives one ends up having to allocate a derivative array for every single element of the original array `X`.
116+
That's terrible.
117+
On the other hand, with the action representation that `InplaceableThunk`s provide, there is just a single `Array` allocated.
118+
One can see [the `getindex` rule in ChainRules.jl for the implementation](https://github.com/JuliaDiff/ChainRules.jl/blob/v0.7.49/src/rulesets/Base/indexing.jl).
119+
120+
121+
#### matmul etc (`*`)
122+
Multiplication of scalars/vectors/matrices of compatible dimensions can all also have their derivatives represented as an `InplaceableThunk`.
123+
These tend to pivot around that `add!` action being defined along the lines of:
124+
`X̄ -> mul!(X̄, A', Ȳ, true, true)`.
125+
Where 5-arg `mul!` is the in place [multiply-add operation](https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.mul!).
126+
`mul!(X̄, A', Ȳ, true, true)` has the same effect as `(X̄ .+= A'*Ȳ)` but avoids allocating the matrix `A'*Ȳ`
127+
This is one of the fundamental operations provided by BLAS -- including the application of the conjugate transpose.
128+
e.g. the Matrix-Matrix form is [`GEMM` (GEneralized Matrix-Matrix Multiplication)](http://www.netlib.org/lapack/explore-html/d1/d54/group__double__blas__level3_gaeda3cbd99c8fb834a60a6412878226e1.html#gaeda3cbd99c8fb834a60a6412878226e1),
129+
the Matrix-Vector form is [`GEMV` (GEneralized Matrix-Vector Multiplication)](http://www.netlib.org/lapack/explore-html/d7/d15/group__double__blas__level2_gadd421a107a488d524859b4a64c1901a9.html#gadd421a107a488d524859b4a64c1901a9) etc.
130+
Under the hood doing it out of place is going to call one of these methods anyway, but on a freshly allocated output array.
131+
So we are going to hit a very efficient implementation and get the addition for free.
132+
133+
134+
One can see [the `*` rules in ChainRules.jl for the implementations](https://github.com/JuliaDiff/ChainRules.jl/blob/v0.7.49/src/rulesets/Base/arraymath.jl#L22-L95)

0 commit comments

Comments
 (0)