Commit cb76e9d

GPU docs (#2510)
1 parent c16291f commit cb76e9d

1 file changed

+110
-95
lines changed

docs/src/guide/gpu.md

Lines changed: 110 additions & 95 deletions
@@ -1,135 +1,150 @@
11
# GPU Support
22

3-
Starting with v0.14, Flux doesn't force a specific GPU backend and the corresponding package dependencies on the users.
4-
Thanks to the [package extension mechanism](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) introduced in julia v1.9, Flux conditionally loads GPU specific code once a GPU package is made available (e.g. through `using CUDA`).
3+
Most work on neural networks involves the use of GPUs, as they can typically perform the required computation much faster.
4+
This page describes how Flux co-operates with various other packages, which talk to GPU hardware.
55

6-
NVIDIA GPU support requires the packages `CUDA.jl` and `cuDNN.jl` to be installed in the environment. In the julia REPL, type `] add CUDA, cuDNN` to install them. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme.
6+
## Basic GPU use: from `Array` to `CuArray` with `cu`
77

8-
AMD GPU support is available since Julia 1.9 on systems with ROCm and MIOpen installed. For more details refer to the [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl) repository.
8+
Julia's GPU packages work with special array types, in place of the built-in `Array`.
9+
The most used is `CuArray` provided by [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), for GPUs made by NVIDIA.
10+
That package provides a function `cu` which converts an ordinary `Array` (living in CPU memory) to a `CuArray` (living in GPU memory).
11+
Functions like `*` and broadcasting specialise so that, when given `CuArray`s, all the computation happens on the GPU:
912

10-
Metal GPU acceleration is available on Apple Silicon hardware. For more details refer to the [Metal.jl](https://github.com/JuliaGPU/Metal.jl) repository. Metal support in Flux is experimental and many features are not yet available.
11-
12-
In order to trigger GPU support in Flux, you need to call `using CUDA`, `using AMDGPU` or `using Metal`
13-
in your code. Notice that for CUDA, explicitly loading also `cuDNN` is not required, but the package has to be installed in the environment.
13+
```julia
14+
W = randn(3, 4) # some weights, on CPU: 3×4 Array{Float64, 2}
15+
x = randn(4) # fake data
16+
y = tanh.(W * x) # computation on the CPU
1417

15-
## Basic GPU Usage
18+
using CUDA
1619

17-
Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), and [Metal.jl](https://github.com/JuliaGPU/Metal.jl).
18-
Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
20+
cu(W) isa CuArray{Float32}
21+
(cW, cx) = (W, x) |> cu # move both to GPU
22+
cy = tanh.(cW * cx) # computation on the GPU
23+
```
1924

20-
For example, we can use `CUDA.CuArray` (with the `CUDA.cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU.
25+
Notice that `cu` doesn't only move arrays, it also recurses into many structures, such as the tuple `(W, x)` above.
26+
(Notice also that it converts Julia's default `Float64` numbers to `Float32`, as this is what most GPUs support efficiently -- it calls itself "opinionated". Flux defaults to `Float32` in all cases.)
2127

22-
(Note that you need to have CUDA available to use CUDA.CuArray – please see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) instructions for more details.)
28+
To use CUDA with Flux, you can simply use `cu` to move both the model and the data.
29+
It will create a copy of the Flux model, with all of its parameter arrays moved to the GPU:
2330

2431
```julia
25-
using CUDA
32+
using Pkg; Pkg.add(["CUDA", "cuDNN"]) # do this once
2633

27-
W = cu(rand(2, 5)) # a 2×5 CuArray
28-
b = cu(rand(2))
34+
using Flux, CUDA
35+
CUDA.allowscalar(false) # recommended
2936

30-
predict(x) = W*x .+ b
31-
loss(x, y) = sum((predict(x) .- y).^2)
37+
model = Dense(W, true, tanh) # wrap the same matrix W in a Flux layer
38+
model(x) ≈ y # same result, still on CPU
3239

33-
x, y = cu(rand(5)), cu(rand(2)) # Dummy data
34-
loss(x, y) # ~ 3
40+
c_model = cu(model) # move all the arrays within model to the GPU
41+
c_model(cx) # computation on the GPU
3542
```
3643

37-
Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to cuda arrays. Taking derivatives and training works exactly as before.
44+
Notice that you need `using CUDA` (every time) but also `] add cuDNN` (once, when installing packages).
45+
This is a quirk of how these packages are set up.
46+
(The [`cuDNN.jl`](https://github.com/JuliaGPU/CUDA.jl/tree/master/lib/cudnn) sub-package handles operations such as convolutions, called by Flux via [NNlib.jl](https://github.com/FluxML/NNlib.jl).)
3847

39-
If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `fmap`, which allows you to alter all parameters of a model at once.
48+
Flux's `gradient`, and training functions like `setup`, `update!`, and `train!`, are all equally happy to accept GPU arrays and GPU models, and then perform all computations on the GPU.
49+
It is recommended that you move the model to the GPU before calling `setup`.
4050

4151
```julia
42-
d = Dense(10 => 5, σ)
43-
d = fmap(cu, d)
44-
d.weight # CuArray
45-
d(cu(rand(10))) # CuArray output
46-
47-
m = Chain(Dense(10 => 5, σ), Dense(5 => 2), softmax)
48-
m = fmap(cu, m)
49-
m(cu(rand(10)))
52+
grads = Flux.gradient((f,x) -> sum(abs2, f(x)), model, x) # on CPU
53+
c_grads = Flux.gradient((f,x) -> sum(abs2, f(x)), c_model, cx) # same result, all on GPU
54+
55+
c_opt = Flux.setup(Adam(), c_model) # setup optimiser after moving model to GPU
56+
57+
Flux.update!(c_opt, c_model, c_grads[1]) # mutates c_model but not model
5058
```
5159

52-
As a convenience, Flux provides the `gpu` function to convert models and data to the GPU if one is available. By default, it'll do nothing. So, you can safely call `gpu` on some data or model (as shown below), and the code will not error, regardless of whether the GPU is available or not. If a GPU library (e.g. CUDA) loads successfully, `gpu` will move data from the CPU to the GPU. As is shown below, this will change the type of something like a regular array to a `CuArray`.
60+
To move arrays and other objects back to the CPU, Flux provides a function `cpu`.
61+
This is recommended when saving models, for example with `Flux.state(c_model |> cpu)`; see below.
5362

5463
```julia
55-
julia> using Flux, CUDA
56-
57-
julia> m = Dense(10, 5) |> gpu
58-
Dense(10 => 5) # 55 parameters
59-
60-
julia> x = rand(10) |> gpu
61-
10-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
62-
0.066846445
63-
64-
0.76706964
65-
66-
julia> m(x)
67-
5-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
68-
-0.99992573
69-
70-
-0.547261
64+
cpu(cW) isa Array{Float32, 2}
65+
66+
model2 = cpu(c_model) # copy model back to CPU
67+
model2(x)
7168
```
7269

73-
The analogue `cpu` is also available for moving models and data back off of the GPU.
70+
!!! compat "Flux ≤ 0.13"
71+
Old versions of Flux automatically loaded CUDA.jl to provide GPU support. Starting from Flux v0.14, it has to be loaded separately. Julia's [package extensions](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) allow Flux to automatically load some GPU-specific code when needed.
72+
73+
## Other GPU packages for AMD & Apple
74+
75+
Non-NVIDIA graphics cards are supported by other packages. Each provides its own function which behaves like `cu`.
76+
AMD GPU support is provided by [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), on systems with ROCm and MIOpen installed.
77+
This package has a function `roc` which converts `Array` to `ROCArray`:
7478

7579
```julia
76-
julia> x = rand(10) |> gpu
77-
10-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
78-
0.8019236
79-
80-
0.7766742
81-
82-
julia> x |> cpu
83-
10-element Vector{Float32}:
84-
0.8019236
85-
86-
0.7766742
87-
```
80+
using Flux, AMDGPU
81+
AMDGPU.allowscalar(false)
8882

89-
## Using device objects
83+
r_model = roc(model)
84+
r_model(roc(x))
9085

91-
In Flux, you can create `device` objects which can be used to easily transfer models and data to GPUs (and defaulting to using the CPU if no GPU backend is available).
92-
These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux uses internally and re-exports.
86+
Flux.gradient((f,x) -> sum(abs2, f(x)), r_model, roc(x))
87+
```
9388

94-
Device objects can be automatically created using the [`cpu_device`](@ref MLDataDevices.cpu_device) and [`gpu_device`](@ref MLDataDevices.gpu_device) functions. For instance, the `gpu` and `cpu` functions are just convenience functions defined as
89+
Experimental support for Apple devices with M-series chips is provided by [Metal.jl](https://github.com/JuliaGPU/Metal.jl). This has a function [`mtl`](https://metal.juliagpu.org/stable/api/array/#Metal.mtl) which works like `cu`, converting `Array` to `MtlArray`:
9590

9691
```julia
97-
cpu(x) = cpu_device()(x)
98-
gpu(x) = gpu_device()(x)
92+
using Flux, Metal
93+
Metal.allowscalar(false)
94+
95+
m_model = mtl(model)
96+
m_y = m_model(mtl(x))
97+
98+
Flux.gradient((f,x) -> sum(abs2, f(x)), m_model, mtl(x))
9999
```
100100

101-
`gpu_device` performs automatic GPU device selection and returns a device object:
102-
- If no GPU is available, it returns a `CPUDevice` object.
103-
- If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Flux.gpu_backend!(<backend_name>)`. If the trigger package corresponding to the device is not loaded (e.g. with `using CUDA`), then a warning is displayed.
104-
- If no LocalPreferences option is present, then the first working GPU with loaded trigger package is used.
101+
!!! danger "Experimental"
102+
Metal support in Flux is experimental and many features are not yet available.
103+
AMD support is improving, but likely to have more rough edges than CUDA.
104+
105+
If you want your model to work with any brand of GPU, or none, then you may not wish to write `cu` everywhere.
106+
One simple way to be generic is, at the top of the file, to un-comment one of several lines which import a package and assign its "adaptor" to the same name:
105107

106-
Consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):
108+
```julia
109+
using CUDA: cu as device # after this, `device === cu`
110+
# using AMDGPU: roc as device
111+
# device = identity # do-nothing, for CPU
107112

108-
```julia-repl
109-
julia> using Flux, CUDA;
113+
using Flux
114+
model = Chain(...) |> device
115+
```
110116

111-
julia> device = gpu_device() # returns handle to an NVIDIA GPU if available
112-
(::CUDADevice{Nothing}) (generic function with 4 methods)
117+
!!! note "Adapt.jl"
118+
The functions `cu`, `mtl`, `roc` all use [Adapt.jl](https://github.com/JuliaGPU/Adapt.jl), to work within various wrappers.
119+
The reason they work on Flux models is that `Flux.@layer Layer` defines methods of `Adapt.adapt_structure(to, lay::Layer)`.
113120

114-
julia> model = Dense(2 => 3);
115121

116-
julia> model.weight # the model initially lives in CPU memory
117-
3×2 Matrix{Float32}:
118-
-0.984794 -0.904345
119-
0.720379 -0.486398
120-
0.851011 -0.586942
122+
## Automatic GPU choice with `gpu`
121123

122-
julia> model = model |> device # transfer model to the GPU
123-
Dense(2 => 3) # 9 parameters
124+
Flux also provides a more automatic way of choosing which GPU (or none) to use. This is the function `gpu`:
125+
* By default it does nothing.
126+
* If the package CUDA is loaded, and `CUDA.functional() === true`, then it behaves like `cu`.
127+
* If the package AMDGPU is loaded, and `AMDGPU.functional() === true`, then it behaves like `roc`.
128+
* If the package Metal is loaded, and `Metal.functional() === true`, then it behaves like `mtl`.
129+
* If two different GPU packages are loaded, the first one takes priority.
130+
131+
For the most part, this means that a script which says `model |> gpu` and `data |> gpu` will just work.
132+
It should always run, and if a GPU package is loaded (and finds the correct hardware) then that will be used.
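
As a minimal sketch of such a script (the layer sizes and fake data here are illustrative assumptions, not part of the original page), the following runs unchanged on the CPU when no GPU package is loaded. Adding `using CUDA`, `using AMDGPU` or `using Metal` at the top should be the only change needed to run it on a GPU:

```julia
using Flux

# With no GPU package loaded, `gpu` leaves everything on the CPU.
model = Chain(Dense(4 => 3, tanh), Dense(3 => 2)) |> gpu
x = randn(Float32, 4, 8) |> gpu     # a batch of 8 fake inputs

y = model(x)                        # runs wherever model and x live
grads = Flux.gradient(m -> sum(abs2, m(x)), model)

y_cpu = cpu(y)                      # back in CPU memory as an ordinary Array
```

Keeping the data as `Float32` from the start avoids relying on any element-type conversion during the transfer.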
124133

125-
julia> model.weight
126-
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
127-
-0.984794 -0.904345
128-
0.720379 -0.486398
129-
0.851011 -0.586942
134+
The function `gpu` uses a lower-level function called `gpu_device()` from [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl),
135+
which checks what to do & then returns some device object. In fact, the entire implementation is just this:
136+
137+
```julia
138+
gpu(x) = gpu_device()(x)
139+
cpu(x) = cpu_device()(x)
130140
```
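
The device objects can also be used directly. The sketch below assumes that `gpu_device` and `cpu_device` are available after `using Flux`, as the definitions above suggest; the variable names `device` and `host` are illustrative only:

```julia
using Flux

device = gpu_device()     # a CPU device object if no functional GPU backend is loaded
host = cpu_device()

model = Dense(2 => 3) |> device
x = rand(Float32, 2, 5) |> device

y = model(x) |> host      # bring the result back to CPU memory
```

Holding on to `device` means the choice of backend is made once, rather than on every call to `gpu`.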
131141

132142

143+
## Manually selecting devices
144+
145+
I thought there was a whole `Flux.gpu_backend!` and Preferences.jl story we had to tell??
146+
147+
133148
## Transferring Training Data
134149

135150
In order to train the model using the GPU both model and the training data have to be transferred to GPU memory. Moving the data can be done in two different ways:
@@ -173,7 +188,7 @@ In order to train the model using the GPU both model and the training data have
173188

174189
## Saving GPU-Trained Models
175190

176-
After the training process is done, one must always transfer the trained model back to the `cpu` memory scope before serializing or saving to disk. This can be done, as described in the previous section, with:
191+
After the training process is done, we must always transfer the trained model back to the CPU memory before serializing or saving to disk. This can be done with `cpu`:
177192
```julia
178193
model = cpu(model) # or model = model |> cpu
179194
```
@@ -275,11 +290,11 @@ Due to a limitation in `Metal.jl`, currently this kind of data movement across d
275290
## Distributed data parallel training
276291

277292
!!! danger "Experimental"
278-
293+
279294
Distributed support is experimental and could change in the future.
280295

281296

282-
Flux supports now distributed data parallel training with `DistributedUtils` module.
297+
Flux now supports distributed data parallel training with the `DistributedUtils` module.
283298
If you want to run your code on multiple GPUs, you have to install `MPI.jl` (see [docs](https://juliaparallel.org/MPI.jl/stable/usage/) for more info).
284299

285300
```julia-repl
@@ -347,7 +362,7 @@ DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(
347362
julia> st_opt = Optimisers.setup(opt, model)
348363
(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),)
349364
350-
julia> st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0)
365+
julia> st_opt = DistributedUtils.synchronize!!(backend, st_opt; root=0)
351366
(layers = ((weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(DistributedOptimizer{MPIBackend{Comm}}(MPIBackend{Comm}(Comm(1140850688)), Adam(0.001, (0.9, 0.999), 1.0e-8)), (Float32[0.0], Float32[0.0], (0.9, 0.999))), σ = ())),)
352367
```
353368

@@ -371,7 +386,7 @@ Epoch 3: Loss 0.012763695
371386
Remember that in order to run it on multiple GPUs you have to run from CLI `mpiexecjl --project=. -n <np> julia <filename>.jl`,
372387
where `<np>` is the number of processes that you want to use. The number of processes usually corresponds to the number of GPUs.
373388

374-
By default `MPI.jl` MPI installation is CUDA-unaware so if you want to run it in CUDA-aware mode, read more [here](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) on custom installation and rebuilding `MPI.jl`.
389+
By default `MPI.jl` MPI installation is CUDA-unaware so if you want to run it in CUDA-aware mode, read more [here](https://juliaparallel.org/MPI.jl/stable/usage/#CUDA-aware-MPI-support) on custom installation and rebuilding `MPI.jl`.
375390
Then test if your MPI is CUDA-aware by
376391
```julia-repl
377392
julia> import Pkg
@@ -385,7 +400,7 @@ julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true)
385400
```
386401

387402
!!! warning "Known shortcomings"
388-
403+
389404
We don't run CUDA-aware tests, so you're running it at your own risk.
390405

391406

@@ -419,4 +434,4 @@ julia> using Metal
419434

420435
julia> Metal.functional()
421436
true
422-
```
437+
```
