Commit c9bab66 (parent: 31dccd1)

make gpu(x) = gpu_device()(x) (#2502)

File tree: 18 files changed (+124, -579 lines)
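For context, the diff below implements the title change: `gpu` and `cpu` become thin wrappers over MLDataDevices.jl device objects. A minimal usage sketch (the `rand` input and variable names are illustrative, not part of the diff):

```julia
using Flux  # gpu_device/cpu_device are re-exported from MLDataDevices.jl

# After this commit the convenience functions are simply
#   cpu(x) = cpu_device()(x)
#   gpu(x) = gpu_device()(x)

x  = rand(Float32, 3, 2)  # illustrative data
gx = gpu(x)               # same as gpu_device()(x); stays on the CPU if no GPU backend is loaded
cx = cpu(gx)              # same as cpu_device()(x)
```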

Project.toml

Lines changed: 1 addition & 4 deletions
@@ -29,7 +29,6 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
-Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
@@ -40,7 +39,6 @@ FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
 FluxEnzymeExt = "Enzyme"
 FluxMPIExt = "MPI"
 FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
-FluxMetalExt = "Metal"
 
 [compat]
 AMDGPU = "1"
@@ -50,11 +48,10 @@ ChainRulesCore = "1.12"
 Compat = "4.10.0"
 Enzyme = "0.12, 0.13"
 Functors = "0.4"
-MLDataDevices = "1.2.0"
+MLDataDevices = "1.4.0"
 MLUtils = "0.4"
 MPI = "0.20.19"
 MacroTools = "0.5"
-Metal = "0.5, 1"
 NCCL = "0.1.1"
 NNlib = "0.9.22"
 OneHotArrays = "0.2.4"
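One practical effect of the Project.toml change above: Metal is no longer a Flux weak dependency, so Apple GPU transfers are expected to go through MLDataDevices.jl's own Metal support. A hedged sketch, assuming a machine with a functional Metal GPU:

```julia
using Flux, Metal  # loading Metal.jl is enough; Flux no longer ships a FluxMetalExt

device = gpu_device()             # expected to be a MetalDevice on a functional Apple GPU
model  = Dense(2 => 3) |> device  # parameters become Metal arrays (MtlArray) on such a system
```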

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ makedocs(
 "Flat vs. Nested" => "reference/destructure.md",
 "Callback Helpers" => "reference/training/callbacks.md",
 "Gradients -- Zygote.jl" => "reference/training/zygote.md",
+"Transfer Data to GPU -- MLDataDevices.jl" => "reference/data/mldatadevices.md",
 "Batching Data -- MLUtils.jl" => "reference/data/mlutils.md",
 "OneHotArrays.jl" => "reference/data/onehot.md",
 "Low-level Operations -- NNlib.jl" => "reference/models/nnlib.md",

docs/src/guide/gpu.md

Lines changed: 77 additions & 131 deletions
@@ -16,68 +16,13 @@ in your code. Notice that for CUDA, explicitly loading also `cuDNN` is not requi
 !!! compat "Flux ≤ 0.13"
     Old versions of Flux automatically installed CUDA.jl to provide GPU support. Starting from Flux v0.14, CUDA.jl is not a dependency anymore and has to be installed manually.
 
-## Checking GPU Availability
-
-By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following:
-
-```julia
-julia> using CUDA
-
-julia> CUDA.functional()
-true
-```
-
-For AMD GPU:
-
-```julia
-julia> using AMDGPU
-
-julia> AMDGPU.functional()
-true
-
-julia> AMDGPU.functional(:MIOpen)
-true
-```
-
-For Metal GPU:
-
-```julia
-julia> using Metal
-
-julia> Metal.functional()
-true
-```
-
-## Selecting GPU backend
-
-Available GPU backends are: `CUDA`, `AMDGPU` and `Metal`.
-
-Flux relies on [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) for selecting default GPU backend to use.
-
-There are two ways you can specify it:
-
-- From the REPL/code in your project, call `Flux.gpu_backend!("AMDGPU")` and restart (if needed) Julia session for the changes to take effect.
-- In `LocalPreferences.toml` file in you project directory specify:
-```toml
-[Flux]
-gpu_backend = "AMDGPU"
-```
-
-Current GPU backend can be fetched from `Flux.GPU_BACKEND` variable:
-
-```julia
-julia> Flux.GPU_BACKEND
-"CUDA"
-```
-
-The current backend will affect the behaviour of methods like the method `gpu` described below.
 
 ## Basic GPU Usage
 
 Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), and [Metal.jl](https://github.com/JuliaGPU/Metal.jl).
 Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
 
-For example, we can use `CUDA.CuArray` (with the `cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU.
+For example, we can use `CUDA.CuArray` (with the `CUDA.cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU.
 
 (Note that you need to have CUDA available to use CUDA.CuArray – please see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) instructions for more details.)
 
@@ -146,6 +91,50 @@ julia> x |> cpu
 0.7766742
 ```
 
+## Using device objects
+
+In Flux, you can create `device` objects which can be used to easily transfer models and data to GPUs (and defaulting to using the CPU if no GPU backend is available).
+These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux uses internally and re-exports.
+
+Device objects can be automatically created using the [`cpu_device`](@ref MLDataDevices.cpu_device) and [`gpu_device`](@ref MLDataDevices.gpu_device) functions. For instance, the `gpu` and `cpu` functions are just convenience functions defined as
+
+```julia
+cpu(x) = cpu_device()(x)
+gpu(x) = gpu_device()(x)
+```
+
+`gpu_device` performs automatic GPU device selection and returns a device object:
+- If no GPU is available, it returns a `CPUDevice` object.
+- If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Flux.gpu_backend!(<backend_name>)`. If the trigger package corresponding to the device is not loaded (e.g. with `using CUDA`), then a warning is displayed.
+- If no LocalPreferences option is present, then the first working GPU with loaded trigger package is used.
+
+Consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):
+
+```julia-repl
+julia> using Flux, CUDA;
+
+julia> device = gpu_device() # returns handle to an NVIDIA GPU if available
+(::CUDADevice{Nothing}) (generic function with 4 methods)
+
+julia> model = Dense(2 => 3);
+
+julia> model.weight # the model initially lives in CPU memory
+3×2 Matrix{Float32}:
+ -0.984794  -0.904345
+  0.720379  -0.486398
+  0.851011  -0.586942
+
+julia> model = model |> device # transfer model to the GPU
+Dense(2 => 3) # 9 parameters
+
+julia> model.weight
+3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
+ -0.984794  -0.904345
+  0.720379  -0.486398
+  0.851011  -0.586942
+```
+
+
 ## Transferring Training Data
 
 In order to train the model using the GPU both model and the training data have to be transferred to GPU memory. Moving the data can be done in two different ways:
@@ -227,65 +216,8 @@ To select specific devices by device id:
 $ export CUDA_VISIBLE_DEVICES='0,1'
 ```
 
-
 More information for conditional use of GPUs in CUDA.jl can be found in its [documentation](https://cuda.juliagpu.org/stable/installation/conditional/#Conditional-use), and information about the specific use of the variable is described in the [Nvidia CUDA blog post](https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/).
 
-## Using device objects
-
-As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement.
-These features are provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, that Flux's uses internally and re-exports.
-
-A `device` object can be created using the [`gpu_device`](@ref MLDataDevices.gpu_device) function.
-`gpu_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):
-
-```julia-repl
-julia> using Flux, CUDA;
-
-julia> device = gpu_device() # returns handle to an NVIDIA GPU if available
-(::CUDADevice{Nothing}) (generic function with 4 methods)
-
-julia> model = Dense(2 => 3);
-
-julia> model.weight # the model initially lives in CPU memory
-3×2 Matrix{Float32}:
- -0.984794  -0.904345
-  0.720379  -0.486398
-  0.851011  -0.586942
-
-julia> model = model |> device # transfer model to the GPU
-Dense(2 => 3) # 9 parameters
-
-julia> model.weight
-3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
- -0.984794  -0.904345
-  0.720379  -0.486398
-  0.851011  -0.586942
-```
-
-The device preference can also be set via the [`gpu_backend!`](@ref MLDataDevices.gpu_backend!) function. For instance, below we first set our device preference to `"AMDGPU"`:
-
-```julia-repl
-julia> gpu_backend!("AMDGPU")
-[ Info: GPU backend has been set to AMDGPU. Restart Julia to use the new backend.
-```
-If no functional GPU backend is available, the device will default to a CPU device.
-You can also explictly request a CPU device by calling the [`cpu_device`](@ref MLDataDevices.cpu_device) function.
-
-```julia-repl
-julia> using Flux, MLDataDevices
-
-julia> cdev = cpu_device()
-(::CPUDevice{Nothing}) (generic function with 4 methods)
-
-julia> gdev = gpu_device(force=true) # force GPU device, error if no GPU is available
-(::CUDADevice{Nothing}) (generic function with 4 methods)
-
-julia> model = Dense(2 => 3); # model in CPU memory
-
-julia> gmodel = model |> gdev; # transfer model to GPU
-
-julia> cmodel = gmodel |> cdev; # transfer model back to CPU
-```
 
 ## Data movement across GPU devices
 
@@ -344,24 +276,6 @@ CuDevice(1): NVIDIA TITAN RTX
 
 Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMDGPU` backends.
 
-!!! warning "Printing models after moving to a different device"
-
-    Due to a limitation in how GPU packages currently work, printing
-    models on the REPL after moving them to a GPU device which is different
-    from the current device will lead to an error.
-
-
-```@docs
-MLDataDevices.cpu_device
-MLDataDevices.default_device_rng
-MLDataDevices.get_device
-MLDataDevices.gpu_device
-MLDataDevices.gpu_backend!
-MLDataDevices.get_device_type
-MLDataDevices.reset_gpu_device!
-MLDataDevices.supported_gpu_backends
-MLDataDevices.DeviceIterator
-```
 
 ## Distributed data parallel training
 
@@ -479,3 +393,35 @@ julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true)
 
 We don't run CUDA-aware tests so you're running it at own risk.
 
+
+## Checking GPU Availability
+
+By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following:
+
+```julia
+julia> using CUDA
+
+julia> CUDA.functional()
+true
+```
+
+For AMD GPU:
+
+```julia
+julia> using AMDGPU
+
+julia> AMDGPU.functional()
+true
+
+julia> AMDGPU.functional(:MIOpen)
+true
+```
+
+For Metal GPU:
+
+```julia
+julia> using Metal
+
+julia> Metal.functional()
+true
+```
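Tying the reorganised guide together, here is a compact sketch of the documented workflow (model, data, and variable names are illustrative; assumes CUDA.jl is installed with a working NVIDIA GPU):

```julia
using Flux, CUDA

gdev = gpu_device()               # CUDADevice here; a CPUDevice if no GPU is available
cdev = cpu_device()

model = Dense(2 => 3) |> gdev     # move parameters to the GPU
x = rand(Float32, 2, 64) |> gdev  # move a data batch the same way

y = model(x)                      # runs on the GPU
y_cpu = y |> cdev                 # bring the result back to host memory
```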
docs/src/reference/data/mldatadevices.md

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Transferring data across devices
+
+Flux relies on the [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl/blob/main/src/public.jl) package to manage devices and transfer data across them. You don't have to explicitly use the package, as Flux re-exports the necessary functions and types.
+
+```@docs
+MLDataDevices.cpu_device
+MLDataDevices.default_device_rng
+MLDataDevices.functional
+MLDataDevices.get_device
+MLDataDevices.gpu_device
+MLDataDevices.gpu_backend!
+MLDataDevices.get_device_type
+MLDataDevices.isleaf
+MLDataDevices.loaded
+MLDataDevices.reset_gpu_device!
+MLDataDevices.set_device!
+MLDataDevices.supported_gpu_backends
+MLDataDevices.DeviceIterator
+```
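As a hedged illustration of part of the API surface documented on this new page (not part of the diff), a CPU-only session can already use the inspection helpers:

```julia
using MLDataDevices

x = rand(Float32, 4)

get_device(x)             # CPUDevice() for a plain Array
get_device_type(x)        # CPUDevice
supported_gpu_backends()  # e.g. ("CUDA", "AMDGPU", "Metal", "oneAPI")
```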

ext/FluxAMDGPUExt/FluxAMDGPUExt.jl

Lines changed: 2 additions & 29 deletions
@@ -3,49 +3,22 @@ module FluxAMDGPUExt
 import ChainRulesCore
 import ChainRulesCore: NoTangent
 import Flux
-import Flux: FluxCPUAdaptor, FluxAMDGPUAdaptor, _amd, adapt_storage, fmap
+import Flux: adapt_storage, fmap
 import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib
-using MLDataDevices: MLDataDevices
+using MLDataDevices
 using AMDGPU
 using Adapt
 using Random
 using Zygote
 
 const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
 
-# Set to boolean on the first call to check_use_amdgpu
-const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)
-
-
-function check_use_amdgpu()
-    if !isnothing(USE_AMDGPU[])
-        return
-    end
-
-    USE_AMDGPU[] = AMDGPU.functional()
-    if USE_AMDGPU[]
-        if !AMDGPU.functional(:MIOpen)
-            @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available."
-        end
-    else
-        @info """
-        The AMDGPU function is being called but AMDGPU.jl is not functional.
-        Defaulting back to the CPU. (No action is required if you want to run on the CPU).
-        """ maxlog=1
-    end
-    return
-end
-
-ChainRulesCore.@non_differentiable check_use_amdgpu()
 
 include("functor.jl")
 include("batchnorm.jl")
 include("conv.jl")
 
-function __init__()
-    Flux.AMDGPU_LOADED[] = true
-end
 
 # TODO
 # fail early if input to the model is not on the device (e.g. on the host)
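The deleted `check_use_amdgpu`/`USE_AMDGPU` machinery is superseded by MLDataDevices' generic checks; a hedged sketch of the equivalent queries, assuming AMDGPU.jl is installed:

```julia
using Flux, AMDGPU
using MLDataDevices: MLDataDevices, AMDGPUDevice

MLDataDevices.loaded(AMDGPUDevice)      # true once AMDGPU.jl has been loaded
MLDataDevices.functional(AMDGPUDevice)  # true only if the ROCm stack is actually usable

gdev = gpu_device()  # falls back to a CPUDevice when no functional GPU backend is found
```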
