
Commit 36fbdf1

Implement interface for data transfer across GPU devices. (#2308)
* Adding new `get_device` method to return a CUDA device with a particular ordinal.
* Adding an `adapt` function for `AbstractArray` to handle movement across devices.
* Making the `get_device` interface simpler, and some minor changes.
* Adding CPU option to `get_device`.
* Removing `KernelAbstractions` from deps.
* Adding new `get_device` method to return a particular AMD device.
* Adding new `adapt_storage` function for moving arrays. Also passing ordinal information through `FluxAMDAdaptor`.
* Moving relevant function definitions to extensions.
* Making `_metal` accept an ordinal.
* Adding new `get_device` method to return a particular Metal device.
* Adding new `adapt_storage` method for Metal arrays.
* Fixing minor error.
* Fixing minor error and spelling mistake.
* Fixing package name: `AMDGPU` instead of `AMD`.
* Reverting back to old Metal functionality.
* Adding tests for moving models between CPU and NVIDIA devices.
* Adding tests for data movement on AMD devices.
* Fixing index error while choosing AMD gpu device.
* Fixing AMD ordinal starting index.
* Adding docstring for new `get_device` method.
* Removing global name conflicts in tests.
* Minor fix to AMD's device id tests.
* Disambiguating test variables.
* Adding more info in docstring of `get_device`, and writing some documentation in the guide.
* Fixing minor error in AMD code.
* Fixing yet another ordinal index error in AMD code.
* Fixing another ordinal index error in AMD code.
* Fixing spelling mistake.
* Replacing type checks for `nothing` with equality checks.
1 parent 656175d commit 36fbdf1
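
In short, the commit adds an ordinal-aware `Flux.get_device(backend, ordinal)` entry point and makes the returned device objects callable, so a model can be piped onto a specific GPU. A minimal sketch of the new interface, assuming a machine with at least two CUDA devices (ordinals and outputs below are illustrative):

```julia
# Sketch of the interface added by this commit (assumes ≥ 2 CUDA GPUs).
using Flux, CUDA

device0 = Flux.get_device("CUDA", 0)   # handle to the CUDA device with ordinal 0
device1 = Flux.get_device("CUDA", 1)   # handle to the CUDA device with ordinal 1

model = Dense(2 => 3)
model = model |> device0               # parameters now live on GPU 0
model = model |> device1               # moved again: now on GPU 1

CUDA.device(model.weight)              # CuDevice(1): ...
```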

File tree

12 files changed: +291 −54 lines changed

Project.toml

Lines changed: 3 additions & 5 deletions
```diff
@@ -24,8 +24,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [extensions]
 FluxAMDGPUExt = "AMDGPU"
@@ -56,8 +56,8 @@ julia = "1.9"
 [extras]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
-ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
@@ -68,6 +68,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra",
-        "FillArrays", "ComponentArrays", "BSON", "Pkg",
-        "CUDA", "cuDNN", "Metal", "AMDGPU"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU"]
```

docs/src/gpu.md

Lines changed: 65 additions & 0 deletions
````diff
@@ -311,6 +311,71 @@ julia> device = Flux.get_device(; verbose=true) # this will resort to auto
 ```
 For detailed information about how the backend is selected, check the documentation for [`Flux.get_device`](@ref).
 
+## Data movement across GPU devices
+
+Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another GPU
+device from the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices:
+
+```julia-repl
+julia> using Flux, CUDA;
+
+julia> CUDA.devices()
+CUDA.DeviceIterator() for 3 devices:
+0. GeForce RTX 2080 Ti
+1. GeForce RTX 2080 Ti
+2. TITAN X (Pascal)
+
+```
+
+Then, let's select the device with ordinal `0`:
+
+```julia-repl
+julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMD"
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+```
+
+Then, let's move a simple dense layer to the GPU represented by `device0`:
+
+```julia-repl
+julia> dense_model = Dense(2 => 3)
+Dense(2 => 3)       # 9 parameters
+
+julia> dense_model = dense_model |> device0;
+
+julia> dense_model.weight
+3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
+  0.695662   0.816299
+ -0.204763  -0.10232
+ -0.955829   0.538412
+
+julia> CUDA.device(dense_model.weight)      # check the GPU to which dense_model is attached
+CuDevice(0): GeForce RTX 2080 Ti
+
+```
+
+Next, we'll get a handle to the device with ordinal `1`, and move `dense_model` to that device:
+
+```julia-repl
+julia> device1 = Flux.get_device("CUDA", 1)
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+julia> dense_model = dense_model |> device1;    # don't directly print the model; see warning below
+
+julia> CUDA.device(dense_model.weight)
+CuDevice(1): GeForce RTX 2080 Ti
+
+```
+
+Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMD` backends.
+
+!!! warning "Printing models after moving to a different device"
+
+    Due to a limitation in how GPU packages currently work, printing
+    models on the REPL after moving them to a GPU device which is different
+    from the current device will lead to an error.
+
 ```@docs
 Flux.AbstractDevice
 Flux.FluxCPUDevice
````
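
One way to inspect such a model without hitting the printing limitation above is to bring a copy of it back to the host first; a minimal sketch, assuming the same multi-GPU setup as in the guide (the `cpu` round-trip is just one workaround, not part of this commit):

```julia
# Sketch: avoid printing a model that lives on a non-current GPU by
# copying it back to host memory first (assumes ≥ 2 CUDA GPUs).
using Flux, CUDA

device1 = Flux.get_device("CUDA", 1)
m = Dense(2 => 3) |> device1

CUDA.device(m.weight)    # confirms which GPU holds the parameters
m_host = m |> cpu        # host copy of the parameters
println(m_host)          # printing the host copy is safe
```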

ext/FluxAMDGPUExt/FluxAMDGPUExt.jl

Lines changed: 8 additions & 0 deletions
```diff
@@ -17,6 +17,14 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
 # Set to boolean on the first call to check_use_amdgpu
 const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)
 
+function (device::Flux.FluxAMDDevice)(x)
+    if device.deviceID === nothing
+        Flux.gpu(Flux.FluxAMDAdaptor(), x)
+    else
+        return Flux.gpu(Flux.FluxAMDAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
+    end
+end
+Flux._get_device_name(::Flux.FluxAMDDevice) = "AMD"
 Flux._isavailable(::Flux.FluxAMDDevice) = true
 Flux._isfunctional(::Flux.FluxAMDDevice) = AMDGPU.functional()
```
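
The new method makes `FluxAMDDevice` objects callable, so the `model |> device` pattern from the guide works for AMD GPUs too; note the `- 1` because `AMDGPU.device_id` is 1-based while Flux ordinals are 0-based. A small usage sketch, assuming an AMD GPU is available:

```julia
# Sketch: moving a model via the callable FluxAMDDevice (assumes AMDGPU hardware).
using Flux, AMDGPU

device = Flux.get_device("AMD", 0)   # FluxAMDDevice for the GPU with ordinal 0
model  = Dense(2 => 3)
model  = model |> device             # invokes (device::Flux.FluxAMDDevice)(x) above,
                                     # i.e. Flux.gpu(Flux.FluxAMDAdaptor(0), model)
```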

ext/FluxAMDGPUExt/functor.jl

Lines changed: 47 additions & 8 deletions
```diff
@@ -1,10 +1,41 @@
 # Convert Float64 to Float32, but preserve Float16.
-adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray =
-    isbits(x) ? x : ROCArray(x)
-adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} =
-    isbits(x) ? x : ROCArray{Float32, N}(x)
-adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N =
-    isbits(x) ? x : ROCArray{Float16, N}(x)
+function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
+    if to.ordinal === nothing
+        if (typeof(x) <: AbstractArray{Float16, N} where N)
+            N = length(size(x))
+            return isbits(x) ? x : ROCArray{Float16, N}(x)
+        elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
+            N = length(size(x))
+            return isbits(x) ? x : ROCArray{Float32, N}(x)
+        else
+            return isbits(x) ? x : ROCArray(x)
+        end
+    end
+
+    old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0
+
+    if !(x isa ROCArray)
+        AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) # adding 1 because ordinals start from 0
+        if (typeof(x) <: AbstractArray{Float16, N} where N)
+            N = length(size(x))
+            x_new = isbits(x) ? x : ROCArray{Float16, N}(x)
+        elseif (typeof(x) <: AbstractArray{T, N} where {T <: AbstractFloat, N})
+            N = length(size(x))
+            x_new = isbits(x) ? x : ROCArray{Float32, N}(x)
+        else
+            x_new = isbits(x) ? x : ROCArray(x)
+        end
+        AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
+        return x_new
+    elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal
+        return x
+    else
+        AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1])
+        x_new = copy(x)
+        AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
+        return x_new
+    end
+end
 
 adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) =
     ROCArray(collect(x))
@@ -45,10 +76,10 @@ Flux._isleaf(::AMD_CONV) = true
 _exclude(x) = Flux._isleaf(x)
 _exclude(::CPU_CONV) = true
 
-function _amd(x)
+function _amd(ordinal::Union{Nothing, Int}, x)
     check_use_amdgpu()
     USE_AMDGPU[] || return x
-    fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_exclude)
+    fmap(x -> Adapt.adapt(FluxAMDAdaptor(ordinal), x), x; exclude=_exclude)
 end
 
 # CPU -> GPU
@@ -74,3 +105,11 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV)
     Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
     Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
 end
+
+function Flux.get_device(::Val{:AMD}, ordinal::Int) # ordinal should start from 0
+    old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0
+    AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0
+    device = Flux.FluxAMDDevice(AMDGPU.device())
+    AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
+    return device
+end
```
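
Both `adapt_storage` and `get_device` above follow the same switch-and-restore pattern: remember the active device, switch to the target ordinal, do the work, then switch back. A hypothetical helper capturing the same idea (not part of the commit; `with_device` is an illustrative name, and a `try`/`finally` is added here for safety):

```julia
# Hypothetical helper mirroring the switch-and-restore logic used above;
# `with_device` is illustrative and not defined by this commit.
using AMDGPU

function with_device(f, ordinal::Int)                 # ordinal is 0-based, like Flux's
    old = AMDGPU.device()                             # remember the active device
    AMDGPU.device!(AMDGPU.devices()[ordinal + 1])     # devices() is 1-based
    try
        return f()
    finally
        AMDGPU.device!(old)                           # always restore the previous device
    end
end

# e.g. build an array on GPU 1 without changing the session's active device:
# x = with_device(() -> ROCArray(rand(Float32, 3, 2)), 1)
```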

ext/FluxCUDAExt/FluxCUDAExt.jl

Lines changed: 8 additions & 0 deletions
```diff
@@ -14,6 +14,14 @@ import Adapt: adapt_storage
 
 const USE_CUDA = Ref{Union{Nothing, Bool}}(nothing)
 
+function (device::Flux.FluxCUDADevice)(x)
+    if device.deviceID === nothing
+        return Flux.gpu(Flux.FluxCUDAAdaptor(), x)
+    else
+        return Flux.gpu(Flux.FluxCUDAAdaptor(device.deviceID.handle), x)
+    end
+end
+Flux._get_device_name(::Flux.FluxCUDADevice) = "CUDA"
 Flux._isavailable(::Flux.FluxCUDADevice) = true
 Flux._isfunctional(::Flux.FluxCUDADevice) = CUDA.functional()
```
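
Here `device.deviceID` is a `CuDevice`, and its `handle` field is the zero-based ordinal that `FluxCUDAAdaptor` expects, so calling the device object is equivalent to calling `Flux.gpu` with an ordinal-carrying adaptor. A short sketch, assuming at least two CUDA devices:

```julia
# Sketch: the CuDevice handle stored in a FluxCUDADevice is the 0-based ordinal
# that the adaptor uses (assumes ≥ 2 CUDA GPUs).
using Flux, CUDA

device1 = Flux.get_device("CUDA", 1)
device1.deviceID          # CuDevice(1): ...
device1.deviceID.handle   # 1

m = Dense(2 => 3) |> device1   # same as Flux.gpu(Flux.FluxCUDAAdaptor(1), m)
```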

ext/FluxCUDAExt/functor.jl

Lines changed: 30 additions & 3 deletions
```diff
@@ -1,5 +1,24 @@
-
 adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
+function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray)
+    to.ordinal === nothing && return CUDA.cu(x)
+
+    # remember current device
+    old_ordinal = CUDA.device().handle
+
+    if !(x isa CuArray)
+        CUDA.device!(to.ordinal)
+        x_new = CUDA.cu(x)
+        CUDA.device!(old_ordinal)
+        return x_new
+    elseif CUDA.device(x).handle == to.ordinal
+        return x
+    else
+        CUDA.device!(to.ordinal)
+        x_new = copy(x)
+        CUDA.device!(old_ordinal)
+        return x_new
+    end
+end
 adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
 adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
 adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
@@ -25,8 +44,16 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) =
 ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) =
     adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ)))
 
-function _cuda(x)
+function _cuda(ordinal::Union{Nothing, Int}, x)
     check_use_cuda()
     USE_CUDA[] || return x
-    fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude=Flux._isleaf)
+    fmap(x -> Adapt.adapt(FluxCUDAAdaptor(ordinal), x), x; exclude=Flux._isleaf)
+end
+
+function Flux.get_device(::Val{:CUDA}, ordinal::Int)
+    old_ordinal = CUDA.device().handle
+    CUDA.device!(ordinal)
+    device = Flux.FluxCUDADevice(CUDA.device())
+    CUDA.device!(old_ordinal)
+    return device
 end
```
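
The new `adapt_storage` method distinguishes three cases: a host array is allocated on the target device, a `CuArray` already on the target is returned untouched, and a `CuArray` on another device is copied across. A hedged sketch of what this looks like through `Adapt.adapt`, assuming at least two CUDA devices:

```julia
# Sketch of the three cases handled by adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray)
# (assumes ≥ 2 CUDA GPUs; purely illustrative).
using Flux, CUDA, Adapt

to = Flux.FluxCUDAAdaptor(1)        # target the device with ordinal 1

x_cpu = rand(Float32, 3, 2)
x1 = Adapt.adapt(to, x_cpu)         # host array -> allocated on device 1

Adapt.adapt(to, x1) === x1          # already on device 1 -> returned as-is

x0 = CUDA.cu(x_cpu)                 # lives on the current device (say ordinal 0)
x1c = Adapt.adapt(to, x0)           # CuArray on another device -> copied to device 1
```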

ext/FluxMetalExt/FluxMetalExt.jl

Lines changed: 2 additions & 0 deletions
```diff
@@ -12,6 +12,8 @@ using Zygote
 
 const USE_METAL = Ref{Union{Nothing, Bool}}(nothing)
 
+(::Flux.FluxMetalDevice)(x) = Flux.gpu(Flux.FluxMetalAdaptor(), x)
+Flux._get_device_name(::Flux.FluxMetalDevice) = "Metal"
 Flux._isavailable(::Flux.FluxMetalDevice) = true
 Flux._isfunctional(::Flux.FluxMetalDevice) = Metal.functional()
```

src/functor.jl

Lines changed: 53 additions & 16 deletions
````diff
@@ -332,13 +332,15 @@ trainable(c::Cholesky) = ()
 
 # CUDA extension. ########
 
-struct FluxCUDAAdaptor end
+Base.@kwdef struct FluxCUDAAdaptor
+    ordinal::Union{Nothing, Int} = nothing
+end
 
 const CUDA_LOADED = Ref{Bool}(false)
 
-function gpu(::FluxCUDAAdaptor, x)
+function gpu(to::FluxCUDAAdaptor, x)
     if CUDA_LOADED[]
-        return _cuda(x)
+        return _cuda(to.ordinal, x)
     else
         @info """
         The CUDA functionality is being called but
@@ -353,13 +355,15 @@ function _cuda end
 
 # AMDGPU extension. ########
 
-struct FluxAMDAdaptor end
+Base.@kwdef struct FluxAMDAdaptor
+    ordinal::Union{Nothing, Int} = nothing
+end
 
 const AMDGPU_LOADED = Ref{Bool}(false)
 
-function gpu(::FluxAMDAdaptor, x)
+function gpu(to::FluxAMDAdaptor, x)
     if AMDGPU_LOADED[]
-        return _amd(x)
+        return _amd(to.ordinal, x)
     else
         @info """
         The AMDGPU functionality is being called but
@@ -500,9 +504,6 @@ Base.@kwdef struct FluxCUDADevice <: AbstractDevice
     deviceID
 end
 
-(::FluxCUDADevice)(x) = gpu(FluxCUDAAdaptor(), x)
-_get_device_name(::FluxCUDADevice) = "CUDA"
-
 """
     FluxAMDDevice <: AbstractDevice
 
@@ -512,9 +513,6 @@ Base.@kwdef struct FluxAMDDevice <: AbstractDevice
     deviceID
 end
 
-(::FluxAMDDevice)(x) = gpu(FluxAMDAdaptor(), x)
-_get_device_name(::FluxAMDDevice) = "AMD"
-
 """
     FluxMetalDevice <: AbstractDevice
 
@@ -524,9 +522,6 @@ Base.@kwdef struct FluxMetalDevice <: AbstractDevice
     deviceID
 end
 
-(::FluxMetalDevice)(x) = gpu(FluxMetalAdaptor(), x)
-_get_device_name(::FluxMetalDevice) = "Metal"
-
 ## device list. order is important
 const DEVICES = Ref{Vector{Union{Nothing, AbstractDevice}}}(Vector{Union{Nothing, AbstractDevice}}(nothing, length(GPU_BACKENDS)))
 DEVICES[][GPU_BACKEND_ORDER["CPU"]] = FluxCPUDevice()
@@ -550,7 +545,7 @@ julia> Flux.supported_devices()
 supported_devices() = GPU_BACKENDS
 
 """
-    Flux.get_device(; verbose=false)::AbstractDevice
+    Flux.get_device(; verbose=false)::Flux.AbstractDevice
 
 Returns a `device` object for the most appropriate backend for the current Julia session.
@@ -653,3 +648,45 @@ function get_device(; verbose=false)::AbstractDevice
     end
     end
 end
+
+"""
+    Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice
+
+Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values
+of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `ordinal` must be an integer value between `0` and the number of available devices.
+
+# Examples
+
+```julia-repl
+julia> using Flux, CUDA;
+
+julia> CUDA.devices()
+CUDA.DeviceIterator() for 3 devices:
+0. GeForce RTX 2080 Ti
+1. GeForce RTX 2080 Ti
+2. TITAN X (Pascal)
+
+julia> device0 = Flux.get_device("CUDA", 0)
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+julia> device0.deviceID
+CuDevice(0): GeForce RTX 2080 Ti
+
+julia> device1 = Flux.get_device("CUDA", 1)
+(::Flux.FluxCUDADevice) (generic function with 1 method)
+
+julia> device1.deviceID
+CuDevice(1): GeForce RTX 2080 Ti
+
+julia> cpu_device = Flux.get_device("CPU")
+(::Flux.FluxCPUDevice) (generic function with 1 method)
+
+```
+"""
+function get_device(backend::String, ordinal::Int = 0)
+    if backend == "CPU"
+        return FluxCPUDevice()
+    else
+        return get_device(Val(Symbol(backend)), ordinal)
+    end
+end
````
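
The string method is deliberately thin: apart from the `"CPU"` special case, it forwards to `get_device(Val(Symbol(backend)), ordinal)`, and the backend-specific methods live in the package extensions shown earlier, so they only exist once CUDA.jl or AMDGPU.jl is loaded. A self-contained sketch of that dispatch pattern in isolation (names here are illustrative, not Flux's):

```julia
# Minimal sketch of the Val-based dispatch used by Flux.get_device;
# `MyDevice`/`backend_device` are illustrative names, not part of Flux.
struct MyDevice
    backend::String
    ordinal::Int
end

backend_device(backend::String, ordinal::Int = 0) =
    backend_device(Val(Symbol(backend)), ordinal)

backend_device(::Val{:CPU}, ordinal::Int) = MyDevice("CPU", 0)

# A GPU backend adds its own method, typically from a package extension that
# is only loaded when the corresponding GPU package is available:
backend_device(::Val{:CUDA}, ordinal::Int) = MyDevice("CUDA", ordinal)

backend_device("CUDA", 1)   # MyDevice("CUDA", 1)
backend_device("CPU")       # MyDevice("CPU", 0)
```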
