Skip to content

Commit 5c1444f

Browse files
oschulz and vchuravy authored
Add function get_device (#269)
Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent 7e6371a commit 5c1444f

File tree

12 files changed

+59
-16
lines changed

12 files changed

+59
-16
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
docs/build/
2+
3+
.vscode
4+

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ version = "0.8.0-dev"
66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
9+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1012
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1113
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
1214

examples/matmul.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ function matmul!(a, b, c)
2424
println("Matrix size mismatch!")
2525
return nothing
2626
end
27-
if isa(a, Array)
28-
kernel! = matmul_kernel!(CPU(),4)
29-
else
30-
kernel! = matmul_kernel!(CUDADevice(),256)
31-
end
27+
device = KernelAbstractions.get_device(a)
28+
n = device isa GPU ? 256 : 4
29+
kernel! = matmul_kernel!(device, n)
3230
kernel!(a, b, c, ndrange=size(c))
3331
end
3432

examples/naive_transpose.jl

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,9 @@ function naive_transpose!(a, b)
1717
println("Matrix size mismatch!")
1818
return nothing
1919
end
20-
21-
if isa(a, Array)
22-
kernel! = naive_transpose_kernel!(CPU(), 4)
23-
elseif isa(a, CuArray)
24-
kernel! = naive_transpose_kernel!(CUDADevice(), 256)
25-
elseif isa(a, ROCArray)
26-
kernel! = naive_transpose_kernel!(ROCDevice(), 256)
27-
else
28-
println("Unrecognized array type!")
29-
end
30-
20+
device = KernelAbstractions.get_device(a)
21+
n = device isa GPU ? 256 : 4
22+
kernel! = naive_transpose_kernel!(device, n)
3123
kernel!(a, b, ndrange=size(a))
3224
end
3325

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import KernelAbstractions
88

99
export CUDADevice
1010

11+
# `get_device` for any `CUDA.CuArray` returns a `CUDADevice` instance.
KernelAbstractions.get_device(::CUDA.CuArray) = CUDADevice()
12+
# `get_device` for any `CUDA.CUSPARSE.AbstractCuSparseArray` also returns a `CUDADevice` instance.
KernelAbstractions.get_device(::CUDA.CUSPARSE.AbstractCuSparseArray) = CUDADevice()
13+
1114
const FREE_STREAMS = CUDA.CuStream[]
1215
const STREAMS = CUDA.CuStream[]
1316
const STREAM_GC_THRESHOLD = Ref{Int}(16)

lib/CUDAKernels/test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
44
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
55
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
66
KernelGradients = "e5faadeb-7f6c-408e-9747-a7a26e81c66a"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
10+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
911
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1012
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1113
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

lib/CUDAKernels/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Enzyme
44
using CUDA
55
using CUDAKernels
66
using Test
7+
using SparseArrays
78

89
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
910
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))

lib/ROCKernels/src/ROCKernels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import KernelAbstractions
99

1010
export ROCDevice
1111

12+
# `get_device` for any `AMDGPU.ROCArray` returns a `ROCDevice` instance.
KernelAbstractions.get_device(::AMDGPU.ROCArray) = ROCDevice()
13+
14+
1215
const FREE_QUEUES = HSAQueue[]
1316
const QUEUES = HSAQueue[]
1417
const QUEUE_GC_THRESHOLD = Ref{Int}(16)

lib/ROCKernels/test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
44
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
55
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
66
KernelGradients = "e5faadeb-7f6c-408e-9747-a7a26e81c66a"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
10+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
911
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1012
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1113
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/KernelAbstractions.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ export Device, GPU, CPU, Event, MultiEvent, NoneEvent
66
export async_copy!
77

88

9+
using LinearAlgebra
910
using MacroTools
11+
using SparseArrays
1012
using StaticArrays
1113
using Adapt
1214

@@ -336,6 +338,23 @@ abstract type GPU <: Device end
336338

337339
struct CPU <: Device end
338340

341+
342+
"""
343+
KernelAbstractions.get_device(A::AbstractArray)::KernelAbstractions.Device
344+
345+
Get a `KernelAbstractions.Device` instance suitable for array `A`.
346+
"""
347+
function get_device end
348+
349+
# Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.:
350+
# Generic fallback: treat `A` as a wrapper and ask the array returned by `parent(A)`
# for its device.
#
# Throws an `ArgumentError` when `parent(A)` is `A` itself, i.e. when `A` is not a
# wrapper and no specialised `get_device` method exists for its type. Without this
# guard the call would recurse on the same object until a `StackOverflowError`.
function get_device(A::AbstractArray)
    P = parent(A)
    # Base's `parent` returns its argument unchanged for non-wrapper arrays.
    if P === A
        throw(ArgumentError(
            "no `get_device` method is implemented for arrays of type $(typeof(A))",
        ))
    end
    return get_device(P)
end
351+
352+
# Sparse arrays: forward to the device of the array returned by `rowvals(A)`.
function get_device(A::AbstractSparseArray)
    indices = rowvals(A)
    return get_device(indices)
end
353+
# `Diagonal`: forward to the device of the array stored in the `diag` field.
function get_device(A::Diagonal)
    diagonal = A.diag
    return get_device(diagonal)
end
354+
# `Tridiagonal`: forward to the device of the array stored in the `d` field.
function get_device(A::Tridiagonal)
    main_diagonal = A.d
    return get_device(main_diagonal)
end
355+
356+
# Plain `Array`s are assigned a `CPU` device instance.
function get_device(::Array)
    return CPU()
end
357+
339358
include("nditeration.jl")
340359
using .NDIteration
341360
import .NDIteration: get

0 commit comments

Comments
 (0)