Skip to content

Commit e0303f2

Browse files
committed
Inspect extended kernel properties.
1 parent 79aea1d commit e0303f2

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

lib/level-zero/memory.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ function device_alloc(ctx::ZeContext, dev::ZeDevice, bytesize::Integer,
5050
flags = ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE,
5151
))
5252
GC.@preserve relaxed_allocation_ref begin
53-
desc_ref = if bytesize > properties(dev).maxMemAllocSize
54-
pNext = Base.unsafe_convert(Ptr{Cvoid}, relaxed_allocation_ref)
55-
Ref(ze_device_mem_alloc_desc_t(; flags, ordinal, pNext))
53+
if bytesize > properties(dev).maxMemAllocSize
54+
desc_ref = Ref(ze_device_mem_alloc_desc_t(; flags, ordinal))
55+
link_extensions(desc_ref, relaxed_allocation_ref)
5656
else
57-
Ref(ze_device_mem_alloc_desc_t(; flags, ordinal))
57+
desc_ref = Ref(ze_device_mem_alloc_desc_t(; flags, ordinal))
5858
end
5959

6060
ptr_ref = Ref{Ptr{Cvoid}}()

lib/level-zero/module.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,16 @@ export properties
234234

235235
function properties(kernel::ZeKernel)
236236
props_ref = Ref(ze_kernel_properties_t())
237+
preferred_group_size_props_ref = Ref(ze_kernel_preferred_group_size_properties_t())
238+
link_extensions(props_ref, preferred_group_size_props_ref)
239+
if haskey(oneL0.extension_properties(kernel.mod.context.driver),
240+
"ZE_extension_kernel_max_group_size_properties")
241+
# TODO: memoize
242+
max_group_size_props_ref = Ref(ze_kernel_max_group_size_properties_ext_t())
243+
link_extensions(preferred_group_size_props_ref, max_group_size_props_ref)
244+
else
245+
max_group_size_props_ref = nothing
246+
end
237247
zeKernelGetProperties(kernel, props_ref)
238248

239249
props = props_ref[]
@@ -251,6 +261,9 @@ function properties(kernel::ZeKernel)
251261
spillMemSize=Int(props.spillMemSize),
252262
kernel_uuid=Base.UUID(reinterpret(UInt128, [props.uuid.kid...])[1]),
253263
module_uuid=Base.UUID(reinterpret(UInt128, [props.uuid.mid...])[1]),
264+
preferredGroupSize=Int(preferred_group_size_props_ref[].preferredMultiple),
265+
maxGroupSize=max_group_size_props_ref === nothing ? missing :
266+
Int(max_group_size_props_ref[].maxGroupSize)
254267
)
255268
end
256269

lib/level-zero/oneL0.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ zeroinit(::Type{_ze_native_kernel_uuid_t}) =
6969
zeroinit(::Type{ze_kernel_uuid_t}) =
7070
ze_kernel_uuid_t(ntuple(_->zero(UInt8), 16), ntuple(_->zero(UInt8), 16))
7171

72+
# link extension objects in pNext
73+
function link_extensions(refs...)
74+
length(refs) >= 2 || return
75+
for (parent, child) in zip(refs[1:end-1], refs[2:end])
76+
pNext = Base.unsafe_convert(Ptr{Cvoid}, child)
77+
typ = eltype(parent)
78+
@assert fieldnames(typ)[2] == :pNext
79+
field = Base.unsafe_convert(Ptr{Cvoid}, parent) + fieldoffset(typ, 2)
80+
field = convert(Ptr{Ptr{Cvoid}}, field)
81+
unsafe_store!(field, pNext)
82+
end
83+
return
84+
end
85+
7286
# core wrappers
7387
include("error.jl")
7488
include("common.jl")

0 commit comments

Comments
 (0)