1
1
# GPUArrays.jl interface
2
2
3
+ import KernelAbstractions
4
+ import KernelAbstractions: Backend
3
5
4
6
#
5
7
# Device functionality
8
10
9
11
## execution
10
12
11
- struct oneArrayBackend <: AbstractGPUBackend end
12
-
13
- struct oneKernelContext <: AbstractKernelContext end
13
"""
    oneArrayBackend

Field-less type that subtypes `Backend`. Methods in this file, such as
`GPUArrays.launch_heuristic`, take an instance of it as their first argument
to select the implementation by dispatch.
"""
struct oneArrayBackend <: Backend
end
14
14
15
15
@inline function GPUArrays. launch_heuristic (:: oneArrayBackend , f:: F , args:: Vararg{Any,N} ;
16
16
elements:: Int , elements_per_thread:: Int ) where {F,N}
@@ -25,48 +25,6 @@ struct oneKernelContext <: AbstractKernelContext end
25
25
return (threads= items, blocks= 32 )
26
26
end
27
27
28
- function GPUArrays. gpu_call (:: oneArrayBackend , f, args, threads:: Int , blocks:: Int ;
29
- name:: Union{String,Nothing} )
30
- @oneapi items= threads groups= blocks name= name f (oneKernelContext (), args... )
31
- end
32
-
33
-
34
- # # on-device
35
-
36
- # indexing
37
-
38
- GPUArrays. blockidx (ctx:: oneKernelContext ) = oneAPI. get_group_id (0 )
39
- GPUArrays. blockdim (ctx:: oneKernelContext ) = oneAPI. get_local_size (0 )
40
- GPUArrays. threadidx (ctx:: oneKernelContext ) = oneAPI. get_local_id (0 )
41
- GPUArrays. griddim (ctx:: oneKernelContext ) = oneAPI. get_num_groups (0 )
42
-
43
- # math
44
-
45
- @inline GPUArrays. cos (ctx:: oneKernelContext , x) = oneAPI. cos (x)
46
- @inline GPUArrays. sin (ctx:: oneKernelContext , x) = oneAPI. sin (x)
47
- @inline GPUArrays. sqrt (ctx:: oneKernelContext , x) = oneAPI. sqrt (x)
48
- @inline GPUArrays. log (ctx:: oneKernelContext , x) = oneAPI. log (x)
49
-
50
- # memory
51
-
52
- @inline function GPUArrays. LocalMemory (:: oneKernelContext , :: Type{T} , :: Val{dims} , :: Val{id}
53
- ) where {T, dims, id}
54
- ptr = oneAPI. emit_localmemory (Val (id), T, Val (prod (dims)))
55
- oneDeviceArray (dims, LLVMPtr {T, onePI.AS.Local} (ptr))
56
- end
57
-
58
- # synchronization
59
-
60
- @inline GPUArrays. synchronize_threads (:: oneKernelContext ) = oneAPI. barrier ()
61
-
62
-
63
-
64
- #
65
- # Host abstractions
66
- #
67
-
68
- GPUArrays. backend (:: Type{<:oneArray} ) = oneArrayBackend ()
69
-
70
28
# Module-level table keyed by `ZeDevice` with `GPUArrays.RNG` values; starts empty.
# NOTE(review): presumably a per-device cache used by `GPUArrays.default_rng`
# below -- the code that reads and fills it is outside this excerpt; confirm.
const GLOBAL_RNGs = Dict{ZeDevice,GPUArrays.RNG}()
71
29
function GPUArrays. default_rng (:: Type{<:oneArray} )
72
30
dev = device ()
0 commit comments