File tree Expand file tree Collapse file tree 5 files changed +42
-6
lines changed Expand file tree Collapse file tree 5 files changed +42
-6
lines changed Original file line number Diff line number Diff line change @@ -16,7 +16,21 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
16
16
Serialization = " 9e88b42a-f829-5b0c-bbe9-9e923198166b"
17
17
Sockets = " 6462fe0b-24de-5631-8697-dd941f90decc"
18
18
19
+ [extras ]
20
+ AMDGPU = " 21141c5a-9bdb-4563-92ae-f87d6854732e"
21
+ CUDA = " 052768ef-5323-5732-b1bb-66c8b64840ba"
22
+
23
+ [weakdeps ]
24
+ AMDGPU = " 21141c5a-9bdb-4563-92ae-f87d6854732e"
25
+ CUDA = " 052768ef-5323-5732-b1bb-66c8b64840ba"
26
+
27
+ [extensions ]
28
+ AMDGPUExt = " AMDGPU"
29
+ CUDAExt = " CUDA"
30
+
19
31
[compat ]
32
+ AMDGPU = " 0.3, 0.4"
33
+ CUDA = " 3, 4"
20
34
DocStringExtensions = " 0.8, 0.9"
21
35
MPIPreferences = " 0.1.6"
22
36
Requires = " ~0.5, 1.0"
Original file line number Diff line number Diff line change 1
- import . AMDGPU
1
+ module AMDGPUExt
2
+
3
+ import MPI
4
+ isdefined (Base, :get_extension ) ? (import AMDGPU) : (import .. AMDGPU)
5
+ import MPI: MPIPtr, Buffer, Datatype
6
+
2
7
3
8
function Base. cconvert (:: Type{MPIPtr} , A:: AMDGPU.ROCArray{T} ) where T
4
9
A
19
24
function Buffer (arr:: AMDGPU.ROCArray )
20
25
Buffer (arr, Cint (length (arr)), Datatype (eltype (arr)))
21
26
end
27
+
28
+ end # AMDGPUExt
Original file line number Diff line number Diff line change 1
- import . CUDA
1
+ module CUDAExt
2
+
3
+ import MPI
4
+ isdefined (Base, :get_extension ) ? (import CUDA) : (import .. CUDA)
5
+ import MPI: MPIPtr, Buffer, Datatype
2
6
3
7
function Base. cconvert (:: Type{MPIPtr} , buf:: CUDA.CuArray{T} ) where T
4
8
Base. cconvert (CUDA. CuPtr{T}, buf) # returns DeviceBuffer
19
23
function Buffer (arr:: CUDA.CuArray )
20
24
Buffer (arr, Cint (length (arr)), Datatype (eltype (arr)))
21
25
end
26
+
27
+ end # CUDAExt
Original file line number Diff line number Diff line change 1
1
module MPI
2
2
3
3
using Libdl, Serialization
4
- using Requires
5
4
using DocStringExtensions
6
5
import MPIPreferences
7
6
@@ -80,6 +79,10 @@ include("misc.jl")
80
79
81
80
include (" deprecated.jl" )
82
81
82
+ if ! isdefined (Base, :get_extension )
83
+ using Requires
84
+ end
85
+
83
86
function __init__ ()
84
87
MPIPreferences. check_unchanged ()
85
88
@@ -136,8 +139,10 @@ function __init__()
136
139
137
140
run_load_time_hooks ()
138
141
139
- @require AMDGPU= " 21141c5a-9bdb-4563-92ae-f87d6854732e" include (" rocm.jl" )
140
- @require CUDA= " 052768ef-5323-5732-b1bb-66c8b64840ba" include (" cuda.jl" )
142
+ @static if ! isdefined (Base, :get_extension )
143
+ @require AMDGPU= " 21141c5a-9bdb-4563-92ae-f87d6854732e" include (" ../ext/AMDGPUExt.jl" )
144
+ @require CUDA= " 052768ef-5323-5732-b1bb-66c8b64840ba" include (" ../ext/CUDAExt.jl" )
145
+ end
141
146
end
142
147
143
148
end
Original file line number Diff line number Diff line change @@ -6,7 +6,11 @@ using MPIPreferences
6
6
using DoubleFloats
7
7
if get (ENV , " JULIA_MPI_TEST_ARRAYTYPE" , " " ) == " CuArray"
8
8
import CUDA
9
- CUDA. version ()
9
+ if isdefined (CUDA, :versioninfo )
10
+ CUDA. versioninfo ()
11
+ else
12
+ CUDA. version ()
13
+ end
10
14
CUDA. precompile_runtime ()
11
15
ArrayType = CUDA. CuArray
12
16
elseif get (ENV ," JULIA_MPI_TEST_ARRAYTYPE" ," " ) == " ROCArray"
You can’t perform that action at this time.
0 commit comments