Commit c2bd39d

Adding device objects for selecting GPU backends (and defaulting to CPU if none exists). (#2297)
* Adding structs for cpu and gpu devices.
* Adding implementation of `Flux.get_device()`, which returns the most appropriate GPU backend (or CPU, if nothing is available).
* Adding docstrings for the new device types, and the `get_device` function.
* Adding `CPU` to the list of supported backends. Made corresponding changes in `gpu(x)`. Adding more details in docstring of `get_device`.
* Using `julia-repl` instead of `jldoctest`, and `@info` instead of `@warn`.
* Adding `DataLoader` functionality to device objects.
* Removing pkgids and defining new functions to check whether backend is available and functional.
* Correcting typographical errors, and removing useless imports.
* Adding `deviceID` to each device struct, and moving struct definitions to package extensions.
* Adding tutorial for using device objects in manual.
* Adding docstring for `get_device` in manual, and renaming internal functions.
* Minor change in docs.
* Removing structs from package extensions as it is bad practice.
* Adding more docstrings in manual.
* Removing redundant log messages.
* Adding kwarg to `get_device` for verbose output.
* Setting `deviceID` to `nothing` if GPU is not functional.
* Adding basic tests for device objects.
* Fixing minor errors in package extensions and tests.
* Minor fix in tests + docs.
* Moving device tests to extensions, and adding a basic data transfer test.
* Moving all device tests in single file per extension.
1 parent c565052 commit c2bd39d

File tree

10 files changed: +413 -5 lines changed

docs/src/gpu.md

Lines changed: 89 additions & 0 deletions
@@ -231,3 +231,92 @@ $ export CUDA_VISIBLE_DEVICES='0,1'
 
 More information for conditional use of GPUs in CUDA.jl can be found in its [documentation](https://cuda.juliagpu.org/stable/installation/conditional/#Conditional-use), and information about the specific use of the variable is described in the [Nvidia CUDA blog post](https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/).
 
+## Using device objects
+
+As a more convenient syntax, Flux allows the usage of GPU `device` objects, which can be used to easily transfer models to GPUs (defaulting to the CPU if no GPU backend is available). This syntax has a few advantages, including automatic selection of the GPU backend and type stability of data movement. To obtain such a device object, use the [`Flux.get_device`](@ref) function.
+
+`Flux.get_device` first checks for a GPU preference, and if possible returns a device for the preferred backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):
+
+```julia-repl
+julia> using Flux, CUDA;
+
+julia> device = Flux.get_device(; verbose=true) # returns handle to an NVIDIA GPU
+[ Info: Using backend set in preferences: CUDA.
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+julia> device.deviceID # check the id of the GPU
+CuDevice(0): NVIDIA GeForce GTX 1650
+
+julia> model = Dense(2 => 3);
+
+julia> model.weight # the model initially lives in CPU memory
+3×2 Matrix{Float32}:
+ -0.984794  -0.904345
+  0.720379  -0.486398
+  0.851011  -0.586942
+
+julia> model = model |> device # transfer model to the GPU
+Dense(2 => 3) # 9 parameters
+
+julia> model.weight
+3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
+ -0.984794  -0.904345
+  0.720379  -0.486398
+  0.851011  -0.586942
+```
+
+The device preference can also be set via the [`Flux.gpu_backend!`](@ref) function. For instance, below we first set our device preference to `"CPU"`:
+
+```julia-repl
+julia> using Flux; Flux.gpu_backend!("CPU")
+┌ Info: New GPU backend set: CPU.
+└ Restart your Julia session for this change to take effect!
+```
+
+Then, after restarting the Julia session, `Flux.get_device` returns a handle to the CPU:
+
+```julia-repl
+julia> using Flux, CUDA; # even if CUDA is loaded, we'll still get a CPU device
+
+julia> device = Flux.get_device(; verbose=true) # get a CPU device
+[ Info: Using backend set in preferences: CPU.
+(::Flux.FluxCPUDevice) (generic function with 1 method)
+
+julia> model = Dense(2 => 3);
+
+julia> model = model |> device
+Dense(2 => 3) # 9 parameters
+
+julia> model.weight # no change; model still lives on the CPU
+3×2 Matrix{Float32}:
+ -0.942968   0.856258
+  0.440009   0.714106
+ -0.419192  -0.471838
+```
+
+This means that the same code works for any GPU backend as well as the CPU.
+
+If the preferred backend isn't available or isn't functional, then [`Flux.get_device`](@ref) looks for a CUDA, AMD, or Metal backend, and returns a corresponding device (if that backend is available and functional). Otherwise, a CPU device is returned. In the example below, the GPU preference is `"CUDA"`:
+
+```julia-repl
+julia> using Flux; # preference is CUDA, but CUDA.jl is not loaded
+
+julia> device = Flux.get_device(; verbose=true) # this will fall back to automatic device selection
+[ Info: Using backend set in preferences: CUDA.
+┌ Warning: Trying to use backend: CUDA but its trigger package is not loaded.
+│ Please load the package and call this function again to respect the backend preference.
+└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:637
+[ Info: Using backend: CPU.
+(::Flux.FluxCPUDevice) (generic function with 1 method)
+```
+
+For detailed information about how the backend is selected, check the documentation for [`Flux.get_device`](@ref).
+
+```@docs
+Flux.AbstractDevice
+Flux.FluxCPUDevice
+Flux.FluxCUDADevice
+Flux.FluxAMDDevice
+Flux.FluxMetalDevice
+Flux.supported_devices
+Flux.get_device
+```
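Device objects are not limited to models: the commit also makes them callable on an `MLUtils.DataLoader` (see the `src/functor.jl` diff below), wrapping the underlying data with `mapobs(device, ...)` so that batches are moved as they are iterated. A minimal sketch of that usage, with illustrative array sizes:

```julia
using Flux, MLUtils

device = Flux.get_device()  # falls back to the CPU device if no GPU backend is loaded

X = rand(Float32, 2, 100)              # 100 samples with 2 features each
loader = DataLoader(X; batchsize=10)

device_loader = device(loader)         # a new DataLoader whose observations pass through `device`
for x in device_loader
    @assert size(x) == (2, 10)         # each batch now lives on the selected device
end
```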

ext/FluxAMDGPUExt/FluxAMDGPUExt.jl

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,9 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
 # Set to boolean on the first call to check_use_amdgpu
 const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)
 
+Flux._isavailable(::Flux.FluxAMDDevice) = true
+Flux._isfunctional(::Flux.FluxAMDDevice) = AMDGPU.functional()
+
 function check_use_amdgpu()
     if !isnothing(USE_AMDGPU[])
         return
@@ -44,6 +47,7 @@ include("conv.jl")
 
 function __init__()
     Flux.AMDGPU_LOADED[] = true
+    Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMD"]] = AMDGPU.functional() ? Flux.FluxAMDDevice(AMDGPU.device()) : Flux.FluxAMDDevice(nothing)
 end
 
 # TODO
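Each GPU extension hooks into device selection with the same two-method pattern seen above: Flux's generic `_isavailable`/`_isfunctional` fallbacks return `false`, and an extension overrides them for its own device type once its trigger package is loaded. Here is a minimal, self-contained sketch of that dispatch pattern; `FakeGPUDevice` is a hypothetical stand-in, not a Flux type:

```julia
# Sketch of the availability/functionality dispatch pattern used by the extensions.
abstract type AbstractDevice end
struct FakeGPUDevice <: AbstractDevice end   # hypothetical stand-in for e.g. FluxAMDDevice

# Generic fallbacks: every backend is unavailable until an extension says otherwise.
_isavailable(::Union{Nothing, AbstractDevice}) = false
_isfunctional(::Union{Nothing, AbstractDevice}) = false

# What an extension defines once its trigger package is loaded; the real
# methods delegate to the package, e.g. AMDGPU.functional().
_isavailable(::FakeGPUDevice) = true
_isfunctional(::FakeGPUDevice) = true

@assert _isavailable(FakeGPUDevice()) && !_isavailable(nothing)
```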

ext/FluxCUDAExt/FluxCUDAExt.jl

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,9 @@ import Adapt: adapt_storage
 
 const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing)
 
+Flux._isavailable(::Flux.FluxCUDADevice) = true
+Flux._isfunctional(::Flux.FluxCUDADevice) = CUDA.functional()
+
 function check_use_cuda()
     if !isnothing(USE_CUDA[])
         return
@@ -36,6 +39,9 @@ include("functor.jl")
 function __init__()
     Flux.CUDA_LOADED[] = true
 
+    ## add device to available devices
+    Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]] = CUDA.functional() ? Flux.FluxCUDADevice(CUDA.device()) : Flux.FluxCUDADevice(nothing)
+
     try
         Base.require(Main, :cuDNN)
     catch
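As the `__init__` above shows, loading the trigger package populates the corresponding slot of `Flux.DEVICES`, with `deviceID` set to `nothing` when the GPU stack is not functional. A sketch of inspecting that table directly, assuming this commit's internals and a machine with CUDA.jl installed:

```julia
using Flux, CUDA  # loading CUDA.jl runs FluxCUDAExt's __init__

# Look up the CUDA slot of the global device table.
cuda_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["CUDA"]]

if cuda_device.deviceID !== nothing
    println("CUDA device registered: ", cuda_device.deviceID)
else
    println("CUDA.jl is loaded but not functional on this machine.")
end
```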

ext/FluxMetalExt/FluxMetalExt.jl

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,9 @@ using Zygote
 
 const USE_METAL = Ref{Union{Nothing, Bool}}(nothing)
 
+Flux._isavailable(::Flux.FluxMetalDevice) = true
+Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional()
+
 function check_use_metal()
     isnothing(USE_METAL[]) || return
 
@@ -30,6 +33,7 @@ include("functor.jl")
 
 function __init__()
     Flux.METAL_LOADED[] = true
+    Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]] = Metal.functional() ? Flux.FluxMetalDevice(Metal.current_device()) : Flux.FluxMetalDevice(nothing)
 end
 
 end

src/functor.jl

Lines changed: 217 additions & 1 deletion
@@ -187,7 +187,16 @@ _isbitsarray(x) = false
 _isleaf(::AbstractRNG) = true
 _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
 
-const GPU_BACKENDS = ("CUDA", "AMD", "Metal")
+const GPU_BACKEND_ORDER = sort(
+    Dict(
+        "CUDA" => 1,
+        "AMD" => 2,
+        "Metal" => 3,
+        "CPU" => 4,
+    ),
+    byvalue = true
+)
+const GPU_BACKENDS = tuple(collect(keys(GPU_BACKEND_ORDER))...)
 const GPU_BACKEND = @load_preference("gpu_backend", "CUDA")
 
 function gpu_backend!(backend::String)
@@ -249,6 +258,8 @@ function gpu(x)
         gpu(FluxAMDAdaptor(), x)
     elseif GPU_BACKEND == "Metal"
         gpu(FluxMetalAdaptor(), x)
+    elseif GPU_BACKEND == "CPU"
+        cpu(x)
     else
         error("""
         Unsupported GPU backend: $GPU_BACKEND.
@@ -444,3 +455,208 @@ function gpu(d::MLUtils.DataLoader)
         d.rng,
     )
 end
+
+# Defining device interfaces.
+"""
+    Flux.AbstractDevice <: Function
+
+An abstract type representing `device` objects for different GPU backends. The currently supported backends are `"CUDA"`, `"AMD"`, `"Metal"` and `"CPU"`; the `"CPU"` backend is the fallback case when no GPU is available. GPU extensions of Flux define subtypes of this type.
+"""
+abstract type AbstractDevice <: Function end
+
+function (device::AbstractDevice)(d::MLUtils.DataLoader)
+    MLUtils.DataLoader(MLUtils.mapobs(device, d.data),
+        d.batchsize,
+        d.buffer,
+        d.partial,
+        d.shuffle,
+        d.parallel,
+        d.collate,
+        d.rng,
+    )
+end
+
+function _get_device_name(::T)::String where {T <: AbstractDevice} end
+
+## check device availability; more definitions in corresponding extensions
+_isavailable(::Nothing) = false
+_isfunctional(::Nothing) = false
+
+_isavailable(::AbstractDevice) = false
+_isfunctional(::AbstractDevice) = false
+
+"""
+    Flux.FluxCPUDevice <: Flux.AbstractDevice
+
+A type representing `device` objects for the `"CPU"` backend for Flux. This is the fallback case when no GPU is available to Flux.
+"""
+Base.@kwdef struct FluxCPUDevice <: AbstractDevice end
+
+(::FluxCPUDevice)(x) = cpu(x)
+_isavailable(::FluxCPUDevice) = true
+_isfunctional(::FluxCPUDevice) = true
+_get_device_name(::FluxCPUDevice) = "CPU"
+
+"""
+    FluxCUDADevice <: AbstractDevice
+
+A type representing `device` objects for the `"CUDA"` backend for Flux.
+"""
+Base.@kwdef struct FluxCUDADevice <: AbstractDevice
+    deviceID
+end
+
+(::FluxCUDADevice)(x) = gpu(FluxCUDAAdaptor(), x)
+_get_device_name(::FluxCUDADevice) = "CUDA"
+
+"""
+    FluxAMDDevice <: AbstractDevice
+
+A type representing `device` objects for the `"AMD"` backend for Flux.
+"""
+Base.@kwdef struct FluxAMDDevice <: AbstractDevice
+    deviceID
+end
+
+(::FluxAMDDevice)(x) = gpu(FluxAMDAdaptor(), x)
+_get_device_name(::FluxAMDDevice) = "AMD"
+
+"""
+    FluxMetalDevice <: AbstractDevice
+
+A type representing `device` objects for the `"Metal"` backend for Flux.
+"""
+Base.@kwdef struct FluxMetalDevice <: AbstractDevice
+    deviceID
+end
+
+(::FluxMetalDevice)(x) = gpu(FluxMetalAdaptor(), x)
+_get_device_name(::FluxMetalDevice) = "Metal"
+
+## device list. order is important
+const DEVICES = Ref{Vector{Union{Nothing, AbstractDevice}}}(Vector{Union{Nothing, AbstractDevice}}(nothing, length(GPU_BACKENDS)))
+DEVICES[][GPU_BACKEND_ORDER["CPU"]] = FluxCPUDevice()
+
+## get device
+
+"""
+    Flux.supported_devices()
+
+Get all supported backends for Flux, in order of preference.
+
+# Example
+
+```jldoctest
+julia> using Flux;
+
+julia> Flux.supported_devices()
+("CUDA", "AMD", "Metal", "CPU")
+```
+"""
+supported_devices() = GPU_BACKENDS
+
+"""
+    Flux.get_device(; verbose=false)::AbstractDevice
+
+Returns a `device` object for the most appropriate backend for the current Julia session.
+
+First, the function checks whether a backend preference has been set via the [`Flux.gpu_backend!`](@ref) function. If so, an attempt is made to load this backend. If the corresponding trigger package has been loaded and the backend is functional, a `device` corresponding to the given backend is returned. Otherwise, the backend is chosen automatically. To update the backend preference, use [`Flux.gpu_backend!`](@ref).
+
+If there is no preference, then for each of the `"CUDA"`, `"AMD"`, `"Metal"` and `"CPU"` backends in the given order, this function checks whether the given backend has been loaded via the corresponding trigger package, and whether the backend is functional. If so, the `device` corresponding to the backend is returned. If no GPU backend is available, a `Flux.FluxCPUDevice` is returned.
+
+If `verbose` is set to `true`, then the function prints informative log messages.
+
+# Examples
+For the example given below, the backend preference was set to `"AMD"` via the [`gpu_backend!`](@ref) function.
+
+```julia-repl
+julia> using Flux;
+
+julia> model = Dense(2 => 3)
+Dense(2 => 3) # 9 parameters
+
+julia> device = Flux.get_device(; verbose=true) # this will just load the CPU device
+[ Info: Using backend set in preferences: AMD.
+┌ Warning: Trying to use backend: AMD but its trigger package is not loaded.
+│ Please load the package and call this function again to respect the backend preference.
+└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:638
+[ Info: Using backend: CPU.
+(::Flux.FluxCPUDevice) (generic function with 1 method)
+
+julia> model = model |> device
+Dense(2 => 3) # 9 parameters
+
+julia> model.weight
+3×2 Matrix{Float32}:
+ -0.304362  -0.700477
+ -0.861201   0.67825
+ -0.176017   0.234188
+```
+
+Here is the same example, but using `"CUDA"`:
+
+```julia-repl
+julia> using Flux, CUDA;
+
+julia> model = Dense(2 => 3)
+Dense(2 => 3) # 9 parameters
+
+julia> device = Flux.get_device(; verbose=true)
+[ Info: Using backend set in preferences: AMD.
+┌ Warning: Trying to use backend: AMD but its trigger package is not loaded.
+│ Please load the package and call this function again to respect the backend preference.
+└ @ Flux ~/fluxml/Flux.jl/src/functor.jl:637
+[ Info: Using backend: CUDA.
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+julia> model = model |> device
+Dense(2 => 3) # 9 parameters
+
+julia> model.weight
+3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
+  0.820013   0.527131
+ -0.915589   0.549048
+  0.290744  -0.0592499
+```
+"""
+function get_device(; verbose=false)::AbstractDevice
+    backend = @load_preference("gpu_backend", nothing)
+
+    if backend !== nothing
+        allowed_backends = supported_devices()
+        idx = findfirst(isequal(backend), allowed_backends)
+        if backend ∉ allowed_backends
+            @warn """
+            `gpu_backend` preference is set to $backend, which is not allowed.
+            Defaulting to automatic device selection.
+            """ maxlog=1
+        else
+            verbose && @info "Using backend set in preferences: $backend."
+            device = DEVICES[][idx]
+
+            if !_isavailable(device)
+                @warn """
+                Trying to use backend: $backend but its trigger package is not loaded.
+                Please load the package and call this function again to respect the backend preference.
+                """
+            else
+                if _isfunctional(device)
+                    return device
+                else
+                    @warn "Backend: $backend from the set preferences is not functional. Defaulting to automatic device selection."
+                end
+            end
+        end
+    end
+
+    for backend in GPU_BACKENDS
+        device = DEVICES[][GPU_BACKEND_ORDER[backend]]
+        if _isavailable(device)
+            if _isfunctional(device)
+                verbose && @info "Using backend: $backend."
+                return device
+            end
+        end
+    end
+end
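Taken together, the new interface makes scripts backend-agnostic: select a device once, then pipe the model and data through it. Below is a minimal end-to-end sketch using this commit's API; the model, data, and optimiser are illustrative choices, not part of the diff:

```julia
using Flux

device = Flux.get_device()  # CUDA, AMD, Metal, or CPU, whichever is functional

# The same lines run unchanged whatever backend `device` wraps.
model = Dense(2 => 3) |> device
x = rand(Float32, 2, 16) |> device
y = rand(Float32, 3, 16) |> device

opt_state = Flux.setup(Adam(), model)
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
Flux.update!(opt_state, model, grads[1])
```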
