Skip to content

Commit 25cf3e0

Browse files
maleadt and leios authored
Adapt to GPUArrays.jl transition to KernelAbstractions.jl. (#475)
Co-authored-by: James Schloss <jrs.schloss@gmail.com>
1 parent 90a44be commit 25cf3e0

File tree

2 files changed

+1
-66
lines changed

2 files changed

+1
-66
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
3030
Adapt = "4"
3131
CEnum = "0.4, 0.5"
3232
ExprTools = "0.1"
33-
GPUArrays = "10"
33+
GPUArrays = "11"
3434
GPUCompiler = "0.23, 0.24, 0.25, 0.26, 0.27, 1"
3535
KernelAbstractions = "0.9.1"
3636
LLVM = "6, 7, 8, 9"

src/gpuarrays.jl

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,5 @@
11
# GPUArrays.jl interface
22

3-
4-
#
5-
# Device functionality
6-
#
7-
8-
9-
## execution
10-
11-
struct oneArrayBackend <: AbstractGPUBackend end
12-
13-
struct oneKernelContext <: AbstractKernelContext end
14-
15-
@inline function GPUArrays.launch_heuristic(::oneArrayBackend, f::F, args::Vararg{Any,N};
16-
elements::Int, elements_per_thread::Int) where {F,N}
17-
kernel = @oneapi launch=false f(oneKernelContext(), args...)
18-
19-
items = launch_configuration(kernel)
20-
# XXX: how many groups is a good number? the API doesn't tell us.
21-
# measured on a low-end IGP, 32 blocks seems like a good sweet spot.
22-
# note that this only matters for grid-stride kernels, like broadcast.
23-
return (threads=items, blocks=32)
24-
end
25-
26-
function GPUArrays.gpu_call(::oneArrayBackend, f, args, threads::Int, blocks::Int;
27-
name::Union{String,Nothing})
28-
@oneapi items=threads groups=blocks name=name f(oneKernelContext(), args...)
29-
end
30-
31-
32-
## on-device
33-
34-
# indexing
35-
36-
GPUArrays.blockidx(ctx::oneKernelContext) = oneAPI.get_group_id(0)
37-
GPUArrays.blockdim(ctx::oneKernelContext) = oneAPI.get_local_size(0)
38-
GPUArrays.threadidx(ctx::oneKernelContext) = oneAPI.get_local_id(0)
39-
GPUArrays.griddim(ctx::oneKernelContext) = oneAPI.get_num_groups(0)
40-
41-
# math
42-
43-
@inline GPUArrays.cos(ctx::oneKernelContext, x) = oneAPI.cos(x)
44-
@inline GPUArrays.sin(ctx::oneKernelContext, x) = oneAPI.sin(x)
45-
@inline GPUArrays.sqrt(ctx::oneKernelContext, x) = oneAPI.sqrt(x)
46-
@inline GPUArrays.log(ctx::oneKernelContext, x) = oneAPI.log(x)
47-
48-
# memory
49-
50-
@inline function GPUArrays.LocalMemory(::oneKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
51-
) where {T, dims, id}
52-
ptr = oneAPI.emit_localmemory(Val(id), T, Val(prod(dims)))
53-
oneDeviceArray(dims, LLVMPtr{T, onePI.AS.Local}(ptr))
54-
end
55-
56-
# synchronization
57-
58-
@inline GPUArrays.synchronize_threads(::oneKernelContext) = oneAPI.barrier()
59-
60-
61-
62-
#
63-
# Host abstractions
64-
#
65-
66-
GPUArrays.backend(::Type{<:oneArray}) = oneArrayBackend()
67-
683
const GLOBAL_RNGs = Dict{ZeDevice,GPUArrays.RNG}()
694
function GPUArrays.default_rng(::Type{<:oneArray})
705
dev = device()

0 commit comments

Comments
 (0)