
Commit f532045

allow get_device("Metal") and informative error messages (#2319)
* Update Project.toml
* Update Project.toml
* cl/ext
* cleanup
* remove some is functional guard
* cleanup
* ordinal -> id
* fix tests
* cleanup buildkite
* user facing error
1 parent b887018 commit f532045
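In practice, the changes below mean that `Flux.get_device` accepts `"Metal"` as a backend string, and that asking for a backend whose package has not been imported produces a readable error instead of a missing-method error. A rough sketch of the intended usage; the REPL output shown is illustrative, not captured from this commit:

```julia-repl
julia> using Flux, Metal

julia> metal_device = Flux.get_device("Metal");   # "Metal" is now a valid backend string

julia> Flux.get_device("AMD", 0)   # AMDGPU.jl is not loaded, so the new fallback fires
ERROR: Unsupported backend: AMD. Try importing the corresponding package.
```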

14 files changed (+159 −154 lines)


.buildkite/pipeline.yml

Lines changed: 5 additions & 16 deletions
@@ -1,18 +1,5 @@
 steps:
-  # - label: "GPU integration with julia v1.9"
-  #   plugins:
-  #     - JuliaCI/julia#v1:
-  #         # Drop default "registries" directory, so it is not persisted from execution to execution
-  #         # Taken from https://github.com/JuliaLang/julia/blob/v1.7.2/.buildkite/pipelines/main/platforms/package_linux.yml#L11-L12
-  #         persist_depot_dirs: packages,artifacts,compiled
-  #         version: "1.9"
-  #     - JuliaCI/julia-test#v1: ~
-  #   agents:
-  #     queue: "juliagpu"
-  #     cuda: "*"
-  #   timeout_in_minutes: 60
-
-  - label: "GPU integration with julia v1"
+  - label: "CUDA GPU with julia v1"
     plugins:
       - JuliaCI/julia#v1:
           version: "1"
@@ -24,6 +11,7 @@ steps:
       cuda: "*"
     env:
       JULIA_CUDA_USE_BINARYBUILDER: "true"
+      FLUX_TEST_CUDA: "true"
       FLUX_TEST_CPU: "false"
     timeout_in_minutes: 60

@@ -36,6 +24,7 @@ steps:
   #     queue: "juliagpu"
   #     cuda: "*"
   #   timeout_in_minutes: 60
+
   - label: "Metal with julia {{matrix.julia}}"
     plugins:
       - JuliaCI/julia#v1:
@@ -57,7 +46,7 @@ steps:
     if: build.message !~ /\[skip tests\]/
     timeout_in_minutes: 60
     env:
-      FLUX_TEST_METAL: 'true'
+      FLUX_TEST_METAL: "true"
       FLUX_TEST_CPU: "false"
     matrix:
       setup:
@@ -84,7 +73,7 @@ steps:
       JULIA_AMDGPU_CORE_MUST_LOAD: "1"
       JULIA_AMDGPU_HIP_MUST_LOAD: "1"
       JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
-      FLUX_TEST_AMDGPU: true
+      FLUX_TEST_AMDGPU: "true"
      FLUX_TEST_CPU: "false"
       JULIA_NUM_THREADS: 4
 env:

docs/src/gpu.md

Lines changed: 2 additions & 2 deletions
@@ -327,7 +327,7 @@ CUDA.DeviceIterator() for 3 devices:
 
 ```
 
-Then, let's select the device with ordinal `0`:
+Then, let's select the device with id `0`:
 
 ```julia-repl
 julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMD"
@@ -354,7 +354,7 @@ CuDevice(0): GeForce RTX 2080 Ti
 
 ```
 
-Next, we'll get a handle to the device with ordinal `1`, and move `dense_model` to that device:
+Next, we'll get a handle to the device with id `1`, and move `dense_model` to that device:
 
 ```julia-repl
 julia> device1 = Flux.get_device("CUDA", 1)
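The surrounding documentation (only partly shown in these hunks) walks through moving a model between specific devices by id. A minimal sketch of that flow, assuming at least two visible CUDA devices; outputs are omitted:

```julia-repl
julia> using Flux, CUDA

julia> dense_model = Dense(2 => 3);            # initially lives on the CPU

julia> device0 = Flux.get_device("CUDA", 0);

julia> dense_model = dense_model |> device0;   # parameters now live on device 0

julia> device1 = Flux.get_device("CUDA", 1);

julia> dense_model = dense_model |> device1;   # and can be moved on to device 1
```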

ext/FluxAMDGPUExt/functor.jl

Lines changed: 13 additions & 13 deletions
@@ -1,6 +1,6 @@
 # Convert Float64 to Float32, but preserve Float16.
 function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
-    if to.ordinal === nothing
+    if to.id === nothing
         if (typeof(x) <: AbstractArray{Float16, N} where N)
             N = length(size(x))
             return isbits(x) ? x : ROCArray{Float16, N}(x)
@@ -12,10 +12,10 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
         end
     end
 
-    old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0
+    old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0
 
     if !(x isa ROCArray)
-        AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1]) # adding 1 because ordinals start from 0
+        AMDGPU.device!(AMDGPU.devices()[to.id + 1]) # adding 1 because ids start from 0
         if (typeof(x) <: AbstractArray{Float16, N} where N)
             N = length(size(x))
             x_new = isbits(x) ? x : ROCArray{Float16, N}(x)
@@ -25,14 +25,14 @@ function adapt_storage(to::FluxAMDAdaptor, x::AbstractArray)
         else
             x_new = isbits(x) ? x : ROCArray(x)
         end
-        AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
+        AMDGPU.device!(AMDGPU.devices()[old_id + 1])
         return x_new
-    elseif AMDGPU.device_id(AMDGPU.device(x)) == to.ordinal
+    elseif AMDGPU.device_id(AMDGPU.device(x)) == to.id
         return x
     else
-        AMDGPU.device!(AMDGPU.devices()[to.ordinal + 1])
+        AMDGPU.device!(AMDGPU.devices()[to.id + 1])
         x_new = copy(x)
-        AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
+        AMDGPU.device!(AMDGPU.devices()[old_id + 1])
         return x_new
     end
 end
@@ -76,10 +76,10 @@ Flux._isleaf(::AMD_CONV) = true
 _exclude(x) = Flux._isleaf(x)
 _exclude(::CPU_CONV) = true
 
-function _amd(ordinal::Union{Nothing, Int}, x)
+function _amd(id::Union{Nothing, Int}, x)
     check_use_amdgpu()
     USE_AMDGPU[] || return x
-    fmap(x -> Adapt.adapt(FluxAMDAdaptor(ordinal), x), x; exclude=_exclude)
+    fmap(x -> Adapt.adapt(FluxAMDAdaptor(id), x), x; exclude=_exclude)
 end
 
 # CPU -> GPU
@@ -106,10 +106,10 @@ function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMD_CONV)
         Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
 end
 
-function Flux.get_device(::Val{:AMD}, ordinal::Int) # ordinal should start from 0
-    old_ordinal = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ordinals start from 0
-    AMDGPU.device!(AMDGPU.devices()[ordinal + 1]) # adding 1 because ordinals start from 0
+function Flux.get_device(::Val{:AMD}, id::Int) # id should start from 0
+    old_id = AMDGPU.device_id(AMDGPU.device()) - 1 # subtracting 1 because ids start from 0
+    AMDGPU.device!(AMDGPU.devices()[id + 1]) # adding 1 because ids start from 0
     device = Flux.FluxAMDDevice(AMDGPU.device())
-    AMDGPU.device!(AMDGPU.devices()[old_ordinal + 1])
+    AMDGPU.device!(AMDGPU.devices()[old_id + 1])
     return device
 end
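The AMD method keeps the existing save-switch-restore pattern; only `ordinal` has been renamed to `id` (still 0-based, while `AMDGPU.devices()` is 1-indexed, hence the `± 1` adjustments). A brief usage sketch, assuming a functional AMDGPU setup:

```julia-repl
julia> using Flux, AMDGPU

julia> device = Flux.get_device("AMD", 0);   # 0-based device id

julia> m = Dense(2 => 3) |> device;          # parameters are now ROCArrays on that device
```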

ext/FluxCUDAExt/functor.jl

Lines changed: 15 additions & 13 deletions
@@ -1,24 +1,26 @@
 adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
+
 function adapt_storage(to::FluxCUDAAdaptor, x::AbstractArray)
-    to.ordinal === nothing && return CUDA.cu(x)
+    to.id === nothing && return CUDA.cu(x)
 
     # remember current device
-    old_ordinal = CUDA.device().handle
+    old_id = CUDA.device().handle
 
     if !(x isa CuArray)
-        CUDA.device!(to.ordinal)
+        CUDA.device!(to.id)
         x_new = CUDA.cu(x)
-        CUDA.device!(old_ordinal)
+        CUDA.device!(old_id)
         return x_new
-    elseif CUDA.device(x).handle == to.ordinal
+    elseif CUDA.device(x).handle == to.id
         return x
     else
-        CUDA.device!(to.ordinal)
+        CUDA.device!(to.id)
         x_new = copy(x)
-        CUDA.device!(old_ordinal)
+        CUDA.device!(old_id)
         return x_new
     end
 end
+
 adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
 adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
 adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
@@ -44,16 +46,16 @@ ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) =
 ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) =
     adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ)))
 
-function _cuda(ordinal::Union{Nothing, Int}, x)
+function _cuda(id::Union{Nothing, Int}, x)
     check_use_cuda()
     USE_CUDA[] || return x
-    fmap(x -> Adapt.adapt(FluxCUDAAdaptor(ordinal), x), x; exclude=Flux._isleaf)
+    fmap(x -> Adapt.adapt(FluxCUDAAdaptor(id), x), x; exclude=Flux._isleaf)
 end
 
-function Flux.get_device(::Val{:CUDA}, ordinal::Int)
-    old_ordinal = CUDA.device().handle
-    CUDA.device!(ordinal)
+function Flux.get_device(::Val{:CUDA}, id::Int)
+    old_id = CUDA.device().handle
+    CUDA.device!(id)
     device = Flux.FluxCUDADevice(CUDA.device())
-    CUDA.device!(old_ordinal)
+    CUDA.device!(old_id)
     return device
 end
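Because the id travels on the adaptor, moving data to a specific CUDA device does not change which device is globally selected; the previously active device is restored after the copy. A sketch of the observable behaviour, assuming two CUDA devices are present:

```julia-repl
julia> using Flux, CUDA

julia> CUDA.device!(0);                       # make device 0 the active device

julia> device1 = Flux.get_device("CUDA", 1);

julia> x = rand(Float32, 4, 4) |> device1;    # x is copied to device 1 ...

julia> CUDA.device().handle == 0              # ... but device 0 is still active
true
```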

ext/FluxMetalExt/functor.jl

Lines changed: 6 additions & 0 deletions
@@ -32,3 +32,9 @@ function _metal(x)
     USE_METAL[] || return x
     fmap(x -> Adapt.adapt(FluxMetalAdaptor(), x), x; exclude=_isleaf)
 end
+
+function Flux.get_device(::Val{:Metal}, id::Int)
+    @assert id == 0 "Metal backend only supports one device at the moment"
+    return Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["Metal"]]
+end
+
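With this method defined, the call from the commit title works end to end; since Metal currently exposes a single device, only id `0` is accepted. A hedged sketch, assuming Metal.jl is functional on the host:

```julia-repl
julia> using Flux, Metal

julia> device = Flux.get_device("Metal");   # same as Flux.get_device("Metal", 0)

julia> m = Dense(2 => 3) |> device;         # parameters become MtlArrays

julia> Flux.get_device("Metal", 1)
ERROR: AssertionError: Metal backend only supports one device at the moment
```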

src/functor.jl

Lines changed: 14 additions & 9 deletions
@@ -333,14 +333,14 @@ trainable(c::Cholesky) = ()
 # CUDA extension. ########
 
 Base.@kwdef struct FluxCUDAAdaptor
-    ordinal::Union{Nothing, Int} = nothing
+    id::Union{Nothing, Int} = nothing
 end
 
 const CUDA_LOADED = Ref{Bool}(false)
 
 function gpu(to::FluxCUDAAdaptor, x)
     if CUDA_LOADED[]
-        return _cuda(to.ordinal, x)
+        return _cuda(to.id, x)
     else
         @info """
         The CUDA functionality is being called but
@@ -356,14 +356,14 @@ function _cuda end
 # AMDGPU extension. ########
 
 Base.@kwdef struct FluxAMDAdaptor
-    ordinal::Union{Nothing, Int} = nothing
+    id::Union{Nothing, Int} = nothing
 end
 
 const AMDGPU_LOADED = Ref{Bool}(false)
 
 function gpu(to::FluxAMDAdaptor, x)
     if AMDGPU_LOADED[]
-        return _amd(to.ordinal, x)
+        return _amd(to.id, x)
     else
         @info """
         The AMDGPU functionality is being called but
@@ -650,10 +650,10 @@ function get_device(; verbose=false)::AbstractDevice
 end
 
 """
-    Flux.get_device(backend::String, ordinal::Int = 0)::Flux.AbstractDevice
+    Flux.get_device(backend::String, idx::Int = 0)::Flux.AbstractDevice
 
-Get a device object for a backend specified by the string `backend` and `ordinal`. The currently supported values
-of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `ordinal` must be an integer value between `0` and the number of available devices.
+Get a device object for a backend specified by the string `backend` and `idx`. The currently supported values
+of `backend` are `"CUDA"`, `"AMD"` and `"CPU"`. `idx` must be an integer value between `0` and the number of available devices.
 
 # Examples
 
@@ -683,10 +683,15 @@ julia> cpu_device = Flux.get_device("CPU")
 
 ```
 """
-function get_device(backend::String, ordinal::Int = 0)
+function get_device(backend::String, idx::Int = 0)
     if backend == "CPU"
         return FluxCPUDevice()
     else
-        return get_device(Val(Symbol(backend)), ordinal)
+        return get_device(Val(Symbol(backend)), idx)
    end
 end
+
+# Fallback
+function get_device(::Val{D}, idx) where D
+    error("Unsupported backend: $(D). Try importing the corresponding package.")
+end
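The new fallback is what produces the user-facing error mentioned in the commit message: requesting a GPU backend before its package has been imported now fails with an explicit hint rather than a missing-method error. A sketch of both paths, with illustrative output:

```julia-repl
julia> using Flux   # no CUDA.jl / AMDGPU.jl / Metal.jl loaded here

julia> cpu_device = Flux.get_device("CPU");   # the "CPU" branch never reaches the Val dispatch

julia> Flux.get_device("CUDA", 0)             # extension not loaded, so the fallback errors
ERROR: Unsupported backend: CUDA. Try importing the corresponding package.
```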

test/ext_amdgpu/get_devices.jl

Lines changed: 31 additions & 37 deletions
@@ -3,47 +3,41 @@ amd_device = Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMD"]]
 # should pass, whether or not AMDGPU is functional
 @test typeof(amd_device) <: Flux.FluxAMDDevice
 
-if AMDGPU.functional()
-    @test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice
-else
-    @test typeof(amd_device.deviceID) <: Nothing
-end
+@test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice
 
 # testing get_device
 dense_model = Dense(2 => 3) # initially lives on CPU
 weight = copy(dense_model.weight) # store the weight
 bias = copy(dense_model.bias) # store the bias
-if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
-    amd_device = Flux.get_device()
-
-    @test typeof(amd_device) <: Flux.FluxAMDDevice
-    @test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice
-    @test Flux._get_device_name(amd_device) in Flux.supported_devices()
-
-    # correctness of data transfer
-    x = randn(5, 5)
-    cx = x |> amd_device
-    @test cx isa AMDGPU.ROCArray
-    @test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amd_device.deviceID)
-
-    # moving models to specific NVIDIA devices
-    for ordinal in 0:(length(AMDGPU.devices()) - 1)
-        current_amd_device = Flux.get_device("AMD", ordinal)
-        @test typeof(current_amd_device.deviceID) <: AMDGPU.HIPDevice
-        @test AMDGPU.device_id(current_amd_device.deviceID) == ordinal + 1
-
-        global dense_model = dense_model |> current_amd_device
-        @test dense_model.weight isa AMDGPU.ROCArray
-        @test dense_model.bias isa AMDGPU.ROCArray
-        @test AMDGPU.device_id(AMDGPU.device(dense_model.weight)) == ordinal + 1
-        @test AMDGPU.device_id(AMDGPU.device(dense_model.bias)) == ordinal + 1
-        @test isequal(Flux.cpu(dense_model.weight), weight)
-        @test isequal(Flux.cpu(dense_model.bias), bias)
-    end
-    # finally move to CPU, and see if things work
-    cpu_device = Flux.get_device("CPU")
-    dense_model = cpu_device(dense_model)
-    @test dense_model.weight isa Matrix
-    @test dense_model.bias isa Vector
 
+amd_device = Flux.get_device()
+
+@test typeof(amd_device) <: Flux.FluxAMDDevice
+@test typeof(amd_device.deviceID) <: AMDGPU.HIPDevice
+@test Flux._get_device_name(amd_device) in Flux.supported_devices()
+
+# correctness of data transfer
+x = randn(5, 5)
+cx = x |> amd_device
+@test cx isa AMDGPU.ROCArray
+@test AMDGPU.device_id(AMDGPU.device(cx)) == AMDGPU.device_id(amd_device.deviceID)
+
+# moving models to specific NVIDIA devices
+for id in 0:(length(AMDGPU.devices()) - 1)
+    current_amd_device = Flux.get_device("AMD", id)
+    @test typeof(current_amd_device.deviceID) <: AMDGPU.HIPDevice
+    @test AMDGPU.device_id(current_amd_device.deviceID) == id + 1
+
+    global dense_model = dense_model |> current_amd_device
+    @test dense_model.weight isa AMDGPU.ROCArray
+    @test dense_model.bias isa AMDGPU.ROCArray
+    @test AMDGPU.device_id(AMDGPU.device(dense_model.weight)) == id + 1
+    @test AMDGPU.device_id(AMDGPU.device(dense_model.bias)) == id + 1
+    @test isequal(Flux.cpu(dense_model.weight), weight)
+    @test isequal(Flux.cpu(dense_model.bias), bias)
 end
+# finally move to CPU, and see if things work
+cpu_device = Flux.get_device("CPU")
+dense_model = cpu_device(dense_model)
+@test dense_model.weight isa Matrix
+@test dense_model.bias isa Vector

test/ext_amdgpu/runtests.jl

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,10 @@ AMDGPU.allowscalar(false)
 include("../test_utils.jl")
 include("test_utils.jl")
 
+@testset "get_devices" begin
+    include("get_devices.jl")
+end
+
 @testset "Basic" begin
     include("basic.jl")
 end
