Commit fc151cd

CuArrayBackend -> CUDABackend

1 parent: 9590be3

3 files changed: +16 -12 lines

src/CUDA.jl

Lines changed: 8 additions & 5 deletions
```diff
@@ -83,6 +83,11 @@ include("compiler/execution.jl")
 include("compiler/exceptions.jl")
 include("compiler/reflection.jl")

+# KernelAbstractions
+include("CUDAKernels.jl")
+import .CUDAKernels: CUDABackend, KA
+export CUDABackend
+
 # array implementation
 include("gpuarrays.jl")
 include("utilities.jl")
@@ -111,6 +116,9 @@ export CUBLAS, CUSPARSE, CUSOLVER, CUFFT, CURAND
 const has_cusolvermg = CUSOLVER.has_cusolvermg
 export has_cusolvermg

+# KA Backend Definition
+KA.get_backend(::CUSPARSE.AbstractCuSparseArray) = CUDABackend()
+
 # random depends on CURAND
 include("random.jl")

@@ -119,11 +127,6 @@ include("../lib/nvml/NVML.jl")
 const has_nvml = NVML.has_nvml
 export NVML, has_nvml

-# KernelAbstractions
-include("CUDAKernels.jl")
-import .CUDAKernels: CUDABackend
-export CUDABackend
-
 # StaticArrays is still a direct dependency, so directly include the extension
 include("../ext/StaticArraysExt.jl")

```
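The net effect in src/CUDA.jl: the KernelAbstractions backend is wired up earlier in the package, and `CUDABackend` remains exported. A minimal usage sketch, not part of this commit (the `scale!` kernel and array sizes are illustrative):

```julia
using CUDA, KernelAbstractions
const KA = KernelAbstractions

# an illustrative KA kernel, not from this commit
@kernel function scale!(y, @Const(x), a)
    i = @index(Global)
    @inbounds y[i] = a * x[i]
end

backend = CUDABackend()                  # exported by CUDA.jl itself after this change
x = KA.ones(backend, Float32, (1024,))   # KA.ones/KA.zeros methods live in src/CUDAKernels.jl
y = KA.zeros(backend, Float32, (1024,))
scale!(backend)(y, x, 2f0; ndrange=length(y))
KA.synchronize(backend)
```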

src/CUDAKernels.jl

Lines changed: 0 additions & 1 deletion
```diff
@@ -25,7 +25,6 @@ KA.zeros(::CUDABackend, ::Type{T}, dims::Tuple) where T = CUDA.zeros(T, dims)
 KA.ones(::CUDABackend, ::Type{T}, dims::Tuple) where T = CUDA.ones(T, dims)

 KA.get_backend(::CuArray) = CUDABackend()
-KA.get_backend(::CUSPARSE.AbstractCuSparseArray) = CUDABackend()
 KA.synchronize(::CUDABackend) = synchronize()

 Adapt.adapt_storage(::CUDABackend, a::Array) = Adapt.adapt(CuArray, a)
```
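The deleted method didn't vanish: it moved to src/CUDA.jl (see above), since src/CUDAKernels.jl is now included before the CUSPARSE submodule exists. A quick sketch of what these `KA.get_backend` methods buy you, assuming a working CUDA setup:

```julia
using CUDA, CUDA.CUSPARSE, KernelAbstractions
const KA = KernelAbstractions

A = CUDA.rand(Float32, 4, 4)
S = CuSparseMatrixCSR(A)   # dense-to-sparse conversion

KA.get_backend(A)   # CUDABackend(), via the method kept in this file
KA.get_backend(S)   # CUDABackend(), via the method relocated to src/CUDA.jl
```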

src/gpuarrays.jl

Lines changed: 8 additions & 6 deletions
```diff
@@ -1,20 +1,22 @@
 # GPUArrays.jl interface

-import KernelAbstractions
-import KernelAbstractions: Backend
-
 #
 # Device functionality
 #


 ## execution

-struct CuArrayBackend <: Backend end

-@inline function GPUArrays.launch_heuristic(::CuArrayBackend, f::F, args::Vararg{Any,N};
+@inline function GPUArrays.launch_heuristic(::CUDABackend, f::F, args::Vararg{Any,N};
                                             elements::Int, elements_per_thread::Int) where {F,N}
-    kernel = @cuda launch=false f(CuKernelContext(), args...)
+
+    ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, nothing,
+                                                                  nothing)
+
+    # this might not be the final context, since we may tune the workgroupsize
+    ctx = KA.mkcontext(obj, ndrange, iterspace)
+    kernel = @cuda launch=false f(ctx, args...)

     # launching many large blocks) lowers performance, as observed with broadcast, so cap
     # the block size if we don't have a grid-stride kernel (which would keep the grid small)
```
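`launch_heuristic` now builds a KernelAbstractions context instead of GPUArrays' old `CuKernelContext` (`KA.launch_config` and `KA.mkcontext` are KernelAbstractions internals; `obj`, a KA kernel object, is bound outside the lines shown). The `@cuda launch=false` idiom it rests on is public CUDA.jl API: compile the kernel without launching, then ask the occupancy API for a good configuration. A self-contained sketch of that pattern, with a hypothetical `vadd!` kernel:

```julia
using CUDA

# hypothetical device function, for illustration only
function vadd!(c, a, b)
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if i <= length(c)
        @inbounds c[i] = a[i] + b[i]
    end
    return nothing
end

a = CUDA.rand(Float32, 10_000)
b = CUDA.rand(Float32, 10_000)
c = similar(a)

# compile without launching, then query the occupancy API for a launch configuration
kernel = @cuda launch=false vadd!(c, a, b)
config = launch_configuration(kernel.fun)
threads = min(length(c), config.threads)
blocks = cld(length(c), threads)
kernel(c, a, b; threads, blocks)
```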
