1
1
# GPUArrays.jl interface
2
2
3
+ import KernelAbstractions
4
+ import KernelAbstractions: Backend
3
5
4
6
#
5
7
# Device functionality
8
10
9
11
# # execution
10
12
11
- struct oneArrayBackend <: AbstractGPUBackend end
12
-
13
- struct oneKernelContext <: AbstractKernelContext end
13
+ struct oneArrayBackend <: Backend end
14
14
15
15
@inline function GPUArrays. launch_heuristic (:: oneArrayBackend , f:: F , args:: Vararg{Any,N} ;
16
16
elements:: Int , elements_per_thread:: Int ) where {F,N}
@@ -23,48 +23,6 @@ struct oneKernelContext <: AbstractKernelContext end
23
23
return (threads= items, blocks= 32 )
24
24
end
25
25
26
- function GPUArrays. gpu_call (:: oneArrayBackend , f, args, threads:: Int , blocks:: Int ;
27
- name:: Union{String,Nothing} )
28
- @oneapi items= threads groups= blocks name= name f (oneKernelContext (), args... )
29
- end
30
-
31
-
32
- # # on-device
33
-
34
- # indexing
35
-
36
- GPUArrays. blockidx (ctx:: oneKernelContext ) = oneAPI. get_group_id (0 )
37
- GPUArrays. blockdim (ctx:: oneKernelContext ) = oneAPI. get_local_size (0 )
38
- GPUArrays. threadidx (ctx:: oneKernelContext ) = oneAPI. get_local_id (0 )
39
- GPUArrays. griddim (ctx:: oneKernelContext ) = oneAPI. get_num_groups (0 )
40
-
41
- # math
42
-
43
- @inline GPUArrays. cos (ctx:: oneKernelContext , x) = oneAPI. cos (x)
44
- @inline GPUArrays. sin (ctx:: oneKernelContext , x) = oneAPI. sin (x)
45
- @inline GPUArrays. sqrt (ctx:: oneKernelContext , x) = oneAPI. sqrt (x)
46
- @inline GPUArrays. log (ctx:: oneKernelContext , x) = oneAPI. log (x)
47
-
48
- # memory
49
-
50
- @inline function GPUArrays. LocalMemory (:: oneKernelContext , :: Type{T} , :: Val{dims} , :: Val{id}
51
- ) where {T, dims, id}
52
- ptr = oneAPI. emit_localmemory (Val (id), T, Val (prod (dims)))
53
- oneDeviceArray (dims, LLVMPtr {T, onePI.AS.Local} (ptr))
54
- end
55
-
56
- # synchronization
57
-
58
- @inline GPUArrays. synchronize_threads (:: oneKernelContext ) = oneAPI. barrier ()
59
-
60
-
61
-
62
- #
63
- # Host abstractions
64
- #
65
-
66
- GPUArrays. backend (:: Type{<:oneArray} ) = oneArrayBackend ()
67
-
68
26
const GLOBAL_RNGs = Dict {ZeDevice,GPUArrays.RNG} ()
69
27
function GPUArrays. default_rng (:: Type{<:oneArray} )
70
28
dev = device ()
0 commit comments