
Commit 7cb79cc

Merge pull request #431 from JuliaGPU/tb/launch_config
Roll our own launch configuration

2 parents 688a2be + fd03f9a

File tree

10 files changed: +69 −15 lines

lib/level-zero/driver.jl

Lines changed: 7 additions & 4 deletions
@@ -100,15 +100,18 @@ function ipc_properties(drv::ZeDriver)
     )
 end
 
-# FIXME: throws ZE_RESULT_ERROR_UNSUPPORTED_FEATURE
 function extension_properties(drv::ZeDriver)
     count_ref = Ref{UInt32}(0)
     zeDriverGetExtensionProperties(drv, count_ref, C_NULL)
 
     all_props = Vector{ze_driver_extension_properties_t}(undef, count_ref[])
     zeDriverGetExtensionProperties(drv, count_ref, all_props)
 
-    return [(name=String([props.name[1:findfirst(isequal(0), props.name)-1]...]),
-             version=Int(props.version),
-            ) for props in all_props[1:count_ref[]]]
+    extensions = Dict{String,VersionNumber}()
+    for prop in all_props[1:count_ref[]]
+        name = String(UInt8[prop.name[1:findfirst(isequal(0), prop.name)-1]...])
+        version = unmake_version(prop.version)
+        extensions[name] = version
+    end
+    return extensions
 end
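
Note: extension_properties(drv) now returns a Dict{String,VersionNumber} keyed by extension name, so callers can probe for a specific extension with haskey. A minimal usage sketch (drv is assumed to be a ZeDriver obtained from the usual driver enumeration; unmake_version is the internal helper used above to decode the packed version):

    exts = extension_properties(drv)
    if haskey(exts, "ZE_extension_kernel_max_group_size_properties")
        ver = exts["ZE_extension_kernel_max_group_size_properties"]
        @info "kernel max-group-size queries are available" ver
    end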

lib/level-zero/memory.jl

Lines changed: 4 additions & 4 deletions
@@ -50,11 +50,11 @@ function device_alloc(ctx::ZeContext, dev::ZeDevice, bytesize::Integer,
         flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE,
     ))
     GC.@preserve relaxed_allocation_ref begin
-        desc_ref = if bytesize > properties(dev).maxMemAllocSize
-            pNext = Base.unsafe_convert(Ptr{Cvoid}, relaxed_allocation_ref)
-            Ref(ze_device_mem_alloc_desc_t(; flags, ordinal, pNext))
+        if bytesize > properties(dev).maxMemAllocSize
+            desc_ref = Ref(ze_device_mem_alloc_desc_t(; flags, ordinal))
+            link_extensions(desc_ref, relaxed_allocation_ref)
         else
-            Ref(ze_device_mem_alloc_desc_t(; flags, ordinal))
+            desc_ref = Ref(ze_device_mem_alloc_desc_t(; flags, ordinal))
         end
 
         ptr_ref = Ref{Ptr{Cvoid}}()
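
Note: the manual pNext assignment is replaced by the new link_extensions helper (added in oneL0.jl below), which chains the relaxed-limits descriptor onto the allocation descriptor after construction. A hedged sketch of the path this enables (the full device_alloc signature beyond what is shown above is assumed):

    # allocating more than maxMemAllocSize now attaches the relaxed-limits
    # descriptor through link_extensions rather than an explicit pNext argument
    big = 2 * properties(dev).maxMemAllocSize
    buf = device_alloc(ctx, dev, big)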

lib/level-zero/module.jl

Lines changed: 13 additions & 0 deletions
@@ -234,6 +234,16 @@ export properties
 
 function properties(kernel::ZeKernel)
     props_ref = Ref(ze_kernel_properties_t())
+    preferred_group_size_props_ref = Ref(ze_kernel_preferred_group_size_properties_t())
+    link_extensions(props_ref, preferred_group_size_props_ref)
+    if haskey(oneL0.extension_properties(kernel.mod.context.driver),
+              "ZE_extension_kernel_max_group_size_properties")
+        # TODO: memoize
+        max_group_size_props_ref = Ref(ze_kernel_max_group_size_properties_ext_t())
+        link_extensions(preferred_group_size_props_ref, max_group_size_props_ref)
+    else
+        max_group_size_props_ref = nothing
+    end
     zeKernelGetProperties(kernel, props_ref)
 
     props = props_ref[]
@@ -251,6 +261,9 @@ function properties(kernel::ZeKernel)
         spillMemSize=Int(props.spillMemSize),
         kernel_uuid=Base.UUID(reinterpret(UInt128, [props.uuid.kid...])[1]),
         module_uuid=Base.UUID(reinterpret(UInt128, [props.uuid.mid...])[1]),
+        preferredGroupSize=Int(preferred_group_size_props_ref[].preferredMultiple),
+        maxGroupSize=max_group_size_props_ref === nothing ? missing :
+                     Int(max_group_size_props_ref[].maxGroupSize)
     )
 end
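
Note: properties(kernel) now additionally reports preferredGroupSize (the preferred work-group size multiple) and, when the driver exposes ZE_extension_kernel_max_group_size_properties, maxGroupSize; without the extension maxGroupSize is missing. A sketch of consuming the new fields (kern is assumed to be a compiled ZeKernel):

    props = properties(kern)
    @show props.preferredGroupSize      # preferred work-group size multiple
    if props.maxGroupSize !== missing
        @show props.maxGroupSize        # kernel-specific upper bound from the extension
    end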

lib/level-zero/oneL0.jl

Lines changed: 14 additions & 0 deletions
@@ -69,6 +69,20 @@ zeroinit(::Type{_ze_native_kernel_uuid_t}) =
 zeroinit(::Type{ze_kernel_uuid_t}) =
     ze_kernel_uuid_t(ntuple(_->zero(UInt8), 16), ntuple(_->zero(UInt8), 16))
 
+# link extension objects in pNext
+function link_extensions(refs...)
+    length(refs) >= 2 || return
+    for (parent, child) in zip(refs[1:end-1], refs[2:end])
+        pNext = Base.unsafe_convert(Ptr{Cvoid}, child)
+        typ = eltype(parent)
+        @assert fieldnames(typ)[2] == :pNext
+        field = Base.unsafe_convert(Ptr{Cvoid}, parent) + fieldoffset(typ, 2)
+        field = convert(Ptr{Ptr{Cvoid}}, field)
+        unsafe_store!(field, pNext)
+    end
+    return
+end
+
 # core wrappers
 include("error.jl")
 include("common.jl")

res/patches/.dummy

Whitespace-only changes.

src/compiler/execution.jl

Lines changed: 27 additions & 0 deletions
@@ -161,6 +161,33 @@ struct HostKernel{F,TT} <: AbstractKernel{F,TT}
     fun::ZeKernel
 end
 
+function launch_configuration(kernel::HostKernel{F,TT}) where {F,TT}
+    # XXX: have the user pass in a global size to clamp against
+    #      maxGroupSizeX/Y/Z?
+
+    # XXX: shrink until a multiple of preferredGroupSize?
+
+    # once the MAX_GROUP_SIZE extension is implemented, we can use it here
+    kernel_props = oneL0.properties(kernel.fun)
+    if kernel_props.maxGroupSize !== missing
+        return kernel_props.maxGroupSize
+    end
+
+    # otherwise, we'd use `zeKernelSuggestGroupSize` but it's been observed
+    # to return really bad configs (JuliaGPU/oneAPI.jl#430)
+
+    # so instead, calculate it ourselves based on the device properties
+    dev = kernel.fun.mod.device
+    compute_props = oneL0.compute_properties(dev)
+    max_size = compute_props.maxTotalGroupSize
+    ## when the kernel uses many registers (which we can't query without
+    ## extensions that landed _after_ MAX_GROUP_SIZE, so don't bother)
+    ## the groupsize should be halved
+    group_size = max_size ÷ 2
+
+    return group_size
+end
+
 
 ## host-side API
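
Note: launch_configuration replaces zeKernelSuggestGroupSize as the source of a default work-group size: it uses the kernel's maxGroupSize when the extension reports one, and otherwise falls back to half the device's maxTotalGroupSize. A sketch of a caller picking items and groups from it (kernel_fun, args and n are placeholders; the call sites below use their own group-count policies):

    kernel = @oneapi launch=false kernel_fun(args...)
    items  = launch_configuration(kernel)   # work-group size from the heuristic above
    groups = cld(n, items)                  # illustrative group count
    kernel(args...; items, groups)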

src/gpuarrays.jl

Lines changed: 1 addition & 3 deletions
@@ -16,9 +16,7 @@ struct oneKernelContext <: AbstractKernelContext end
                 elements::Int, elements_per_thread::Int) where {F,N}
     kernel = @oneapi launch=false f(oneKernelContext(), args...)
 
-    items = suggest_groupsize(kernel.fun, elements).x
-    # XXX: the z dimension of the suggested group size is often non-zero.
-    #      preserve this in GPUArrays?
+    items = launch_configuration(kernel)
     # XXX: how many groups is a good number? the API doesn't tell us.
     #      measured on a low-end IGP, 32 blocks seems like a good sweet spot.
     #      note that this only matters for grid-stride kernels, like broadcast.

src/mapreduce.jl

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ function GPUArrays.mapreducedim!(f::F, op::OP, R::oneWrappedArray{T},
     kernel_args = kernel_convert.(args)
     kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
     kernel = zefunction(partial_mapreduce_device, kernel_tt)
-    reduce_items = compute_items(suggest_groupsize(kernel.fun, wanted_items).x)
+    reduce_items = launch_configuration(kernel)
 
     # how many groups should we launch?
     #

src/oneAPIKernels.jl

Lines changed: 1 addition & 2 deletions
@@ -90,8 +90,7 @@ function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize
 
     # figure out the optimal workgroupsize automatically
     if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
-        items = oneAPI.suggest_groupsize(kernel.fun, prod(ndrange)).x
-        # XXX: the z dimension of the suggested group size is often non-zero. use this?
+        items = oneAPI.launch_configuration(kernel)
         workgroupsize = threads_to_workgroupsize(items, ndrange)
     iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
     ctx = KA.mkcontext(obj, ndrange, iterspace)

test/level-zero.jl

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ api_version(drv)
 
 properties(drv)
 ipc_properties(drv)
-#extension_properties(drv)
+extension_properties(drv)
 
 end
