Skip to content

Commit e7c51ef

Browse files
committed
adding necessary changes for KA transition for gpuarrays
1 parent fadcd8d commit e7c51ef

File tree

1 file changed

+3
-45
lines changed

1 file changed

+3
-45
lines changed

src/gpuarrays.jl

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# GPUArrays.jl interface
22

3+
import KernelAbstractions
4+
import KernelAbstractions: Backend
35

46
#
57
# Device functionality
@@ -8,9 +10,7 @@
810

911
## execution
1012

11-
struct oneArrayBackend <: AbstractGPUBackend end
12-
13-
struct oneKernelContext <: AbstractKernelContext end
13+
struct oneArrayBackend <: Backend end
1414

1515
@inline function GPUArrays.launch_heuristic(::oneArrayBackend, f::F, args::Vararg{Any,N};
1616
elements::Int, elements_per_thread::Int) where {F,N}
@@ -25,48 +25,6 @@ struct oneKernelContext <: AbstractKernelContext end
2525
return (threads=items, blocks=32)
2626
end
2727

28-
function GPUArrays.gpu_call(::oneArrayBackend, f, args, threads::Int, blocks::Int;
29-
name::Union{String,Nothing})
30-
@oneapi items=threads groups=blocks name=name f(oneKernelContext(), args...)
31-
end
32-
33-
34-
## on-device
35-
36-
# indexing
37-
38-
GPUArrays.blockidx(ctx::oneKernelContext) = oneAPI.get_group_id(0)
39-
GPUArrays.blockdim(ctx::oneKernelContext) = oneAPI.get_local_size(0)
40-
GPUArrays.threadidx(ctx::oneKernelContext) = oneAPI.get_local_id(0)
41-
GPUArrays.griddim(ctx::oneKernelContext) = oneAPI.get_num_groups(0)
42-
43-
# math
44-
45-
@inline GPUArrays.cos(ctx::oneKernelContext, x) = oneAPI.cos(x)
46-
@inline GPUArrays.sin(ctx::oneKernelContext, x) = oneAPI.sin(x)
47-
@inline GPUArrays.sqrt(ctx::oneKernelContext, x) = oneAPI.sqrt(x)
48-
@inline GPUArrays.log(ctx::oneKernelContext, x) = oneAPI.log(x)
49-
50-
# memory
51-
52-
@inline function GPUArrays.LocalMemory(::oneKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
53-
) where {T, dims, id}
54-
ptr = oneAPI.emit_localmemory(Val(id), T, Val(prod(dims)))
55-
oneDeviceArray(dims, LLVMPtr{T, onePI.AS.Local}(ptr))
56-
end
57-
58-
# synchronization
59-
60-
@inline GPUArrays.synchronize_threads(::oneKernelContext) = oneAPI.barrier()
61-
62-
63-
64-
#
65-
# Host abstractions
66-
#
67-
68-
GPUArrays.backend(::Type{<:oneArray}) = oneArrayBackend()
69-
7028
const GLOBAL_RNGs = Dict{ZeDevice,GPUArrays.RNG}()
7129
function GPUArrays.default_rng(::Type{<:oneArray})
7230
dev = device()

0 commit comments

Comments
 (0)