Skip to content

Commit e09446f

Browse files
vchuravysimonbyrne
andauthored
Transition GPU support to weak-deps and extensions (#714)
Co-authored-by: Simon Byrne <simonbyrne@gmail.com>
1 parent 92bc4ac commit e09446f

File tree

5 files changed

+42
-6
lines changed

5 files changed

+42
-6
lines changed

Project.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,21 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1616
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1717
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
1818

19+
[extras]
20+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
21+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
22+
23+
[weakdeps]
24+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
25+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
26+
27+
[extensions]
28+
AMDGPUExt = "AMDGPU"
29+
CUDAExt = "CUDA"
30+
1931
[compat]
32+
AMDGPU = "0.3, 0.4"
33+
CUDA = "3, 4"
2034
DocStringExtensions = "0.8, 0.9"
2135
MPIPreferences = "0.1.6"
2236
Requires = "~0.5, 1.0"

src/rocm.jl renamed to ext/AMDGPUExt.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
import .AMDGPU
1+
module AMDGPUExt
2+
3+
import MPI
4+
isdefined(Base, :get_extension) ? (import AMDGPU) : (import ..AMDGPU)
5+
import MPI: MPIPtr, Buffer, Datatype
6+
27

38
function Base.cconvert(::Type{MPIPtr}, A::AMDGPU.ROCArray{T}) where T
49
A
@@ -19,3 +24,5 @@ end
1924
function Buffer(arr::AMDGPU.ROCArray)
2025
Buffer(arr, Cint(length(arr)), Datatype(eltype(arr)))
2126
end
27+
28+
end # AMDGPUExt

src/cuda.jl renamed to ext/CUDAExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
import .CUDA
1+
module CUDAExt
2+
3+
import MPI
4+
isdefined(Base, :get_extension) ? (import CUDA) : (import ..CUDA)
5+
import MPI: MPIPtr, Buffer, Datatype
26

37
function Base.cconvert(::Type{MPIPtr}, buf::CUDA.CuArray{T}) where T
48
Base.cconvert(CUDA.CuPtr{T}, buf) # returns DeviceBuffer
@@ -19,3 +23,5 @@ end
1923
function Buffer(arr::CUDA.CuArray)
2024
Buffer(arr, Cint(length(arr)), Datatype(eltype(arr)))
2125
end
26+
27+
end #CUDAExt

src/MPI.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
module MPI
22

33
using Libdl, Serialization
4-
using Requires
54
using DocStringExtensions
65
import MPIPreferences
76

@@ -80,6 +79,10 @@ include("misc.jl")
8079

8180
include("deprecated.jl")
8281

82+
if !isdefined(Base, :get_extension)
83+
using Requires
84+
end
85+
8386
function __init__()
8487
MPIPreferences.check_unchanged()
8588

@@ -136,8 +139,10 @@ function __init__()
136139

137140
run_load_time_hooks()
138141

139-
@require AMDGPU="21141c5a-9bdb-4563-92ae-f87d6854732e" include("rocm.jl")
140-
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
142+
@static if !isdefined(Base, :get_extension)
143+
@require AMDGPU="21141c5a-9bdb-4563-92ae-f87d6854732e" include("../ext/AMDGPUExt.jl")
144+
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("../ext/CUDAExt.jl")
145+
end
141146
end
142147

143148
end

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ using MPIPreferences
66
using DoubleFloats
77
if get(ENV, "JULIA_MPI_TEST_ARRAYTYPE", "") == "CuArray"
88
import CUDA
9-
CUDA.version()
9+
if isdefined(CUDA, :versioninfo)
10+
CUDA.versioninfo()
11+
else
12+
CUDA.version()
13+
end
1014
CUDA.precompile_runtime()
1115
ArrayType = CUDA.CuArray
1216
elseif get(ENV,"JULIA_MPI_TEST_ARRAYTYPE","") == "ROCArray"

0 commit comments

Comments
 (0)