Skip to content

Commit 1cc6089

Browse files
committed
updating ctx for launch_heuristic
1 parent 9e50975 commit 1cc6089

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/gpuarrays.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@ import KernelAbstractions: Backend
1010

1111
## execution
1212

13-
struct oneArrayBackend <: Backend end
14-
15-
@inline function GPUArrays.launch_heuristic(::oneArrayBackend, f::F, args::Vararg{Any,N};
13+
@inline function GPUArrays.launch_heuristic(::oneAPIBackend, f::F, args::Vararg{Any,N};
1614
elements::Int, elements_per_thread::Int) where {F,N}
17-
kernel = @oneapi launch=false f(oneKernelContext(), args...)
15+
ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, nothing,
16+
nothing)
17+
18+
# this might not be the final context, since we may tune the workgroupsize
19+
ctx = KA.mkcontext(obj, ndrange, iterspace)
20+
21+
kernel = @oneapi launch=false f(ctx, args...)
1822

1923
items = launch_configuration(kernel)
2024
# XXX: how many groups is a good number? the API doesn't tell us.

src/oneAPI.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ end
6767
# integrations and specialized functionality
6868
include("broadcast.jl")
6969
include("mapreduce.jl")
70-
include("gpuarrays.jl")
71-
include("random.jl")
72-
include("utils.jl")
73-
7470
include("oneAPIKernels.jl")
7571
import .oneAPIKernels: oneAPIBackend
7672
export oneAPIBackend
7773

74+
include("gpuarrays.jl")
75+
include("random.jl")
76+
include("utils.jl")
77+
7878
function __init__()
7979
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
8080
precompiling && return

0 commit comments

Comments
 (0)