Skip to content

Commit e8b4bd1

Browse files
committed
Metal: Emit the type name for indexing arguments.
Otherwise validation fails a range check.
1 parent 27f9fcf commit e8b4bd1

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

src/metal.jl

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -339,33 +339,41 @@ for intr in [
339339
"dispatch_quadgroups_per_threadgroup", "dispatch_simdgroups_per_threadgroup",
340340
"quadgroup_index_in_threadgroup", "quadgroups_per_threadgroup",
341341
"simdgroup_index_in_threadgroup", "simdgroups_per_threadgroup",
342-
"thread_index_in_quadgroup", "thread_index_in_simdgroup", "thread_index_in_threadgroup",
343-
"thread_execution_width", "threads_per_simdgroup"],
344-
(intr_typ, air_typ, julia_typ) in [
345-
("i32", "uint", UInt32),
346-
("i16", "ushort", UInt16),
342+
"thread_index_in_quadgroup", "thread_index_in_simdgroup",
343+
"thread_index_in_threadgroup", "thread_execution_width", "threads_per_simdgroup"],
344+
(llvm_typ, julia_typ) in [
345+
("i32", UInt32),
346+
("i16", UInt16),
347347
]
348-
push!(kernel_intrinsics,
349-
"julia.air.$intr.$intr_typ" =>
350-
(air_intr="$intr.$air_typ", air_typ, air_name=intr, julia_typ))
348+
push!(kernel_intrinsics, "julia.air.$intr.$llvm_typ" => (name=intr, typ=julia_typ))
351349
end
352350
for intr in [
353351
"dispatch_threads_per_threadgroup",
354352
"grid_origin", "grid_size",
355353
"thread_position_in_grid", "thread_position_in_threadgroup",
356354
"threadgroup_position_in_grid", "threadgroups_per_grid",
357355
"threads_per_grid", "threads_per_threadgroup"],
358-
(intr_typ, air_typ, julia_typ) in [
359-
("i32", "uint", UInt32),
360-
("v2i32", "uint2", NTuple{2, VecElement{UInt32}}),
361-
("v3i32", "uint3", NTuple{3, VecElement{UInt32}}),
362-
("i16", "ushort", UInt16),
363-
("v2i16", "ushort2", NTuple{2, VecElement{UInt16}}),
364-
("v3i16", "ushort3", NTuple{3, VecElement{UInt16}}),
356+
(llvm_typ, julia_typ) in [
357+
("i32", UInt32),
358+
("v2i32", NTuple{2, VecElement{UInt32}}),
359+
("v3i32", NTuple{3, VecElement{UInt32}}),
360+
("i16", UInt16),
361+
("v2i16", NTuple{2, VecElement{UInt16}}),
362+
("v3i16", NTuple{3, VecElement{UInt16}}),
365363
]
366-
push!(kernel_intrinsics,
367-
"julia.air.$intr.$intr_typ" =>
368-
(air_intr="$intr.$air_typ", air_typ, air_name=intr, julia_typ))
364+
push!(kernel_intrinsics, "julia.air.$intr.$llvm_typ" => (name=intr, typ=julia_typ))
365+
end
366+
367+
function argument_type_name(typ)
368+
if typ isa LLVM.IntegerType && width(typ) == 16
369+
"ushort"
370+
elseif typ isa LLVM.IntegerType && width(typ) == 32
371+
"uint"
372+
elseif typ isa LLVM.VectorType
373+
argument_type_name(eltype(typ)) * string(Int(size(typ)))
374+
else
375+
error("Cannot encode unknown type `$typ`")
376+
end
369377
end
370378

371379
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
@@ -414,7 +422,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
414422
new_param_types = LLVMType[parameters(ft)...]
415423

416424
for intr_fn in used_intrinsics
417-
llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].julia_typ; ctx)
425+
llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ; ctx)
418426
push!(new_param_types, llvm_typ)
419427
end
420428
new_ft = LLVM.FunctionType(LLVM.return_type(ft), new_param_types)
@@ -424,7 +432,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
424432
LLVM.name!(new_arg, LLVM.name(arg))
425433
end
426434
for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
427-
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].air_name)
435+
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
428436
end
429437

430438
workmap[f] = new_f
@@ -591,10 +599,15 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
591599

592600
# Create metadata for argument intrinsics last
593601
for intr_arg in parameters(entry)[i:end]
602+
intr_fn = LLVM.name(intr_arg)
603+
594604
arg_info = Metadata[]
595605

596606
push!(arg_info, Metadata(ConstantInt(Int32(length(parameters(entry))-i); ctx)))
597-
push!(arg_info, MDString("air." * LLVM.name(intr_arg); ctx))
607+
push!(arg_info, MDString("air.$intr_fn" ; ctx))
608+
609+
push!(arg_info, MDString("air.arg_type_name" ; ctx))
610+
push!(arg_info, MDString(argument_type_name(llvmtype(intr_arg)); ctx))
598611

599612
arg_info = MDNode(arg_info; ctx)
600613
push!(arg_infos, arg_info)

0 commit comments

Comments
 (0)