# GPUArrays.jl interface

-
-#
-# Device functionality
-#
-
-
-## execution
-
-struct oneArrayBackend <: AbstractGPUBackend end
-
-struct oneKernelContext <: AbstractKernelContext end
-
-@inline function GPUArrays.launch_heuristic(::oneArrayBackend, f::F, args::Vararg{Any,N};
-                                            elements::Int, elements_per_thread::Int) where {F,N}
-    kernel = @oneapi launch=false f(oneKernelContext(), args...)
-
-    items = launch_configuration(kernel)
-    # XXX: how many groups is a good number? the API doesn't tell us.
-    #      measured on a low-end IGP, 32 blocks seems like a good sweet spot.
-    #      note that this only matters for grid-stride kernels, like broadcast.
-    return (threads=items, blocks=32)
-end
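
# Illustrative sketch (an assumption for this write-up, not part of the file above):
# the blocks=32 heuristic only matters for grid-stride kernels, where every thread
# strides over the whole array instead of handling one element. Written against the
# ctx indexing functions defined further down (1-based, per GPUArrays convention),
# such a kernel would look roughly like this:
function grid_stride_copy(ctx, dst, src)
    i      = (GPUArrays.blockidx(ctx) - 1) * GPUArrays.blockdim(ctx) + GPUArrays.threadidx(ctx)
    stride = GPUArrays.blockdim(ctx) * GPUArrays.griddim(ctx)
    while i <= length(dst)
        @inbounds dst[i] = src[i]
        i += stride
    end
    return
end
# On the host it would be launched through gpu_call; the `elements` keyword name is
# assumed here: GPUArrays.gpu_call(grid_stride_copy, dst, src; elements=length(dst))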
-
-function GPUArrays.gpu_call(::oneArrayBackend, f, args, threads::Int, blocks::Int;
-                            name::Union{String,Nothing})
-    @oneapi items=threads groups=blocks name=name f(oneKernelContext(), args...)
-end
-
-
-## on-device
-
-# indexing
-
-GPUArrays.blockidx(ctx::oneKernelContext) = oneAPI.get_group_id(0)
-GPUArrays.blockdim(ctx::oneKernelContext) = oneAPI.get_local_size(0)
-GPUArrays.threadidx(ctx::oneKernelContext) = oneAPI.get_local_id(0)
-GPUArrays.griddim(ctx::oneKernelContext) = oneAPI.get_num_groups(0)
-
-# math
-
-@inline GPUArrays.cos(ctx::oneKernelContext, x) = oneAPI.cos(x)
-@inline GPUArrays.sin(ctx::oneKernelContext, x) = oneAPI.sin(x)
-@inline GPUArrays.sqrt(ctx::oneKernelContext, x) = oneAPI.sqrt(x)
-@inline GPUArrays.log(ctx::oneKernelContext, x) = oneAPI.log(x)
-
-# memory
-
-@inline function GPUArrays.LocalMemory(::oneKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
-                                       ) where {T, dims, id}
-    ptr = oneAPI.emit_localmemory(Val(id), T, Val(prod(dims)))
-    oneDeviceArray(dims, LLVMPtr{T, oneAPI.AS.Local}(ptr))
-end
-
-# synchronization
-
-@inline GPUArrays.synchronize_threads(::oneKernelContext) = oneAPI.barrier()
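
# Illustrative sketch (an assumption for this write-up, not part of the file above):
# a work-group tree reduction that exercises the LocalMemory and synchronize_threads
# hooks defined above. The explicit Val(1) allocation id is normally generated by
# GPUArrays, and the fixed 256-item group size is an assumption of this example.
function groupsum_kernel(ctx, a, out)
    tid    = GPUArrays.threadidx(ctx)
    shared = GPUArrays.LocalMemory(ctx, Float32, Val((256,)), Val(1))
    @inbounds shared[tid] = tid <= length(a) ? a[tid] : 0f0
    GPUArrays.synchronize_threads(ctx)
    # halve the active range each step, accumulating into the lower half
    offset = GPUArrays.blockdim(ctx) ÷ 2
    while offset > 0
        if tid <= offset
            @inbounds shared[tid] += shared[tid + offset]
        end
        GPUArrays.synchronize_threads(ctx)
        offset ÷= 2
    end
    if tid == 1
        @inbounds out[1] = shared[1]
    end
    return
end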
-
-
-
-#
-# Host abstractions
-#
-
-GPUArrays.backend(::Type{<:oneArray}) = oneArrayBackend()
-
const GLOBAL_RNGs = Dict{ZeDevice,GPUArrays.RNG}()
function GPUArrays.default_rng(::Type{<:oneArray})
    dev = device()