Skip to content

Commit d357bce

Browse files
committed
Add the changes needed for the GPUArrays transition to KernelAbstractions (KA)
1 parent c3fc211 commit d357bce

File tree

1 file changed

+3
-45
lines changed

1 file changed

+3
-45
lines changed

src/gpuarrays.jl

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# GPUArrays.jl interface
22

3+
import KernelAbstractions
4+
import KernelAbstractions: Backend
35

46
#
57
# Device functionality
@@ -8,9 +10,7 @@
810

911
## execution
1012

11-
struct oneArrayBackend <: AbstractGPUBackend end
12-
13-
struct oneKernelContext <: AbstractKernelContext end
13+
struct oneArrayBackend <: Backend end
1414

1515
@inline function GPUArrays.launch_heuristic(::oneArrayBackend, f::F, args::Vararg{Any,N};
1616
elements::Int, elements_per_thread::Int) where {F,N}
@@ -23,48 +23,6 @@ struct oneKernelContext <: AbstractKernelContext end
2323
return (threads=items, blocks=32)
2424
end
2525

26-
function GPUArrays.gpu_call(::oneArrayBackend, f, args, threads::Int, blocks::Int;
27-
name::Union{String,Nothing})
28-
@oneapi items=threads groups=blocks name=name f(oneKernelContext(), args...)
29-
end
30-
31-
32-
## on-device
33-
34-
# indexing
35-
36-
GPUArrays.blockidx(ctx::oneKernelContext) = oneAPI.get_group_id(0)
37-
GPUArrays.blockdim(ctx::oneKernelContext) = oneAPI.get_local_size(0)
38-
GPUArrays.threadidx(ctx::oneKernelContext) = oneAPI.get_local_id(0)
39-
GPUArrays.griddim(ctx::oneKernelContext) = oneAPI.get_num_groups(0)
40-
41-
# math
42-
43-
@inline GPUArrays.cos(ctx::oneKernelContext, x) = oneAPI.cos(x)
44-
@inline GPUArrays.sin(ctx::oneKernelContext, x) = oneAPI.sin(x)
45-
@inline GPUArrays.sqrt(ctx::oneKernelContext, x) = oneAPI.sqrt(x)
46-
@inline GPUArrays.log(ctx::oneKernelContext, x) = oneAPI.log(x)
47-
48-
# memory
49-
50-
@inline function GPUArrays.LocalMemory(::oneKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
51-
) where {T, dims, id}
52-
ptr = oneAPI.emit_localmemory(Val(id), T, Val(prod(dims)))
53-
oneDeviceArray(dims, LLVMPtr{T, onePI.AS.Local}(ptr))
54-
end
55-
56-
# synchronization
57-
58-
@inline GPUArrays.synchronize_threads(::oneKernelContext) = oneAPI.barrier()
59-
60-
61-
62-
#
63-
# Host abstractions
64-
#
65-
66-
GPUArrays.backend(::Type{<:oneArray}) = oneArrayBackend()
67-
6826
const GLOBAL_RNGs = Dict{ZeDevice,GPUArrays.RNG}()
6927
function GPUArrays.default_rng(::Type{<:oneArray})
7028
dev = device()

0 commit comments

Comments
 (0)