
Commit fd03f9a

Roll our own launch configuration API.
1 parent e0303f2 commit fd03f9a

4 files changed: +30 -6 lines

src/compiler/execution.jl

Lines changed: 27 additions & 0 deletions
@@ -161,6 +161,33 @@ struct HostKernel{F,TT} <: AbstractKernel{F,TT}
     fun::ZeKernel
 end
 
+function launch_configuration(kernel::HostKernel{F,TT}) where {F,TT}
+    # XXX: have the user pass in a global size to clamp against
+    #      maxGroupSizeX/Y/Z?
+
+    # XXX: shrink until a multiple of preferredGroupSize?
+
+    # once the MAX_GROUP_SIZE extension is implemented, we can use it here
+    kernel_props = oneL0.properties(kernel.fun)
+    if kernel_props.maxGroupSize !== missing
+        return kernel_props.maxGroupSize
+    end
+
+    # otherwise, we'd use `zeKernelSuggestGroupSize`, but it has been observed
+    # to return really bad configs (JuliaGPU/oneAPI.jl#430)
+
+    # so instead, calculate it ourselves based on the device properties
+    dev = kernel.fun.mod.device
+    compute_props = oneL0.compute_properties(dev)
+    max_size = compute_props.maxTotalGroupSize
+    ## when the kernel uses many registers (which we can't query without
+    ## extensions that landed _after_ MAX_GROUP_SIZE, so don't bother),
+    ## the group size should be halved
+    group_size = max_size ÷ 2
+
+    return group_size
+end
+
 
 ## host-side API
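
For reference, this is roughly how the new API is consumed from the host side: compile with launch=false, query the configuration, then launch. A minimal sketch; the vadd kernel, the array sizes, and the cld-based group count are illustrative and not part of this commit:

    using oneAPI

    # illustrative element-wise kernel (not from this commit)
    function vadd(a, b, c)
        i = get_global_id()
        if i <= length(c)
            @inbounds c[i] = a[i] + b[i]
        end
        return
    end

    a = oneArray(rand(Float32, 4096))
    b = oneArray(rand(Float32, 4096))
    c = similar(a)

    kernel = @oneapi launch=false vadd(a, b, c)
    items = oneAPI.launch_configuration(kernel)  # group size from the code above
    groups = cld(length(c), items)               # enough groups to cover c
    kernel(a, b, c; items, groups)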

src/gpuarrays.jl

Lines changed: 1 addition & 3 deletions
@@ -16,9 +16,7 @@ struct oneKernelContext <: AbstractKernelContext end
                         elements::Int, elements_per_thread::Int) where {F,N}
     kernel = @oneapi launch=false f(oneKernelContext(), args...)
 
-    items = suggest_groupsize(kernel.fun, elements).x
-    # XXX: the z dimension of the suggested group size is often non-zero.
-    #      preserve this in GPUArrays?
+    items = launch_configuration(kernel)
     # XXX: how many groups is a good number? the API doesn't tell us.
     #      measured on a low-end IGP, 32 blocks seems like a good sweet spot.
     # note that this only matters for grid-stride kernels, like broadcast.
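
The group-count heuristic only matters because broadcast is a grid-stride kernel: each work-item strides over the whole range, so any group count produces correct results and the number only affects parallelism. A minimal sketch of the pattern, assuming oneAPI.jl's OpenCL-style work-item intrinsics; the kernel name is illustrative:

    # each work-item handles every stride-th element, so the launch covers
    # the full array no matter how many groups are chosen
    function gridstride_scale!(a, s)
        i = get_global_id()
        stride = get_global_size()
        while i <= length(a)
            @inbounds a[i] *= s
            i += stride
        end
        return
    end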

src/mapreduce.jl

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::oneWrappedArray{T},
     kernel_args = kernel_convert.(args)
     kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
     kernel = zefunction(partial_mapreduce_device, kernel_tt)
-    reduce_items = compute_items(suggest_groupsize(kernel.fun, wanted_items).x)
+    reduce_items = launch_configuration(kernel)
 
     # how many groups should we launch?
     #
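
Note that the old call passed wanted_items so the suggestion could be tailored to the problem size, while launch_configuration only reflects kernel and device limits (the clamping XXX in execution.jl hints at restoring that). A caller that needs the old behaviour could clamp manually; a hedged one-liner, not part of this commit:

    # clamp the device-wide group size against the problem size
    reduce_items = min(wanted_items, launch_configuration(kernel))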

src/oneAPIKernels.jl

Lines changed: 1 addition & 2 deletions
@@ -90,8 +90,7 @@ function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize
 
     # figure out the optimal workgroupsize automatically
     if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
-        items = oneAPI.suggest_groupsize(kernel.fun, prod(ndrange)).x
-        # XXX: the z dimension of the suggested group size is often non-zero. use this?
+        items = oneAPI.launch_configuration(kernel)
         workgroupsize = threads_to_workgroupsize(items, ndrange)
     iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
     ctx = KA.mkcontext(obj, ndrange, iterspace)
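
From a user's perspective, this branch is taken whenever a KernelAbstractions kernel is launched without an explicit workgroupsize. A minimal sketch of such a launch; the scale! kernel and the synchronize call are standard KernelAbstractions usage, not code from this commit:

    using KernelAbstractions, oneAPI

    @kernel function scale!(a, s)
        i = @index(Global)
        @inbounds a[i] *= s
    end

    a = oneArray(ones(Float32, 4096))
    backend = oneAPIBackend()

    # no workgroupsize given, so the code patched above derives one from
    # oneAPI.launch_configuration via threads_to_workgroupsize
    scale!(backend)(a, 2f0; ndrange=length(a))
    KernelAbstractions.synchronize(backend)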
