@@ -339,33 +339,41 @@ for intr in [
339
339
" dispatch_quadgroups_per_threadgroup" , " dispatch_simdgroups_per_threadgroup" ,
340
340
" quadgroup_index_in_threadgroup" , " quadgroups_per_threadgroup" ,
341
341
" simdgroup_index_in_threadgroup" , " simdgroups_per_threadgroup" ,
342
- " thread_index_in_quadgroup" , " thread_index_in_simdgroup" , " thread_index_in_threadgroup " ,
343
- " thread_execution_width" , " threads_per_simdgroup" ],
344
- (intr_typ, air_typ , julia_typ) in [
345
- (" i32" , " uint " , UInt32),
346
- (" i16" , " ushort " , UInt16),
342
+ " thread_index_in_quadgroup" , " thread_index_in_simdgroup" ,
343
+ " thread_index_in_threadgroup " , " thread_execution_width" , " threads_per_simdgroup" ],
344
+ (llvm_typ , julia_typ) in [
345
+ (" i32" , UInt32),
346
+ (" i16" , UInt16),
347
347
]
348
- push! (kernel_intrinsics,
349
- " julia.air.$intr .$intr_typ " =>
350
- (air_intr= " $intr .$air_typ " , air_typ, air_name= intr, julia_typ))
348
+ push! (kernel_intrinsics, " julia.air.$intr .$llvm_typ " => (name= intr, typ= julia_typ))
351
349
end
352
350
for intr in [
353
351
" dispatch_threads_per_threadgroup" ,
354
352
" grid_origin" , " grid_size" ,
355
353
" thread_position_in_grid" , " thread_position_in_threadgroup" ,
356
354
" threadgroup_position_in_grid" , " threadgroups_per_grid" ,
357
355
" threads_per_grid" , " threads_per_threadgroup" ],
358
- (intr_typ, air_typ , julia_typ) in [
359
- (" i32" , " uint " , UInt32),
360
- (" v2i32" , " uint2 " , NTuple{2 , VecElement{UInt32}}),
361
- (" v3i32" , " uint3 " , NTuple{3 , VecElement{UInt32}}),
362
- (" i16" , " ushort " , UInt16),
363
- (" v2i16" , " ushort2 " , NTuple{2 , VecElement{UInt16}}),
364
- (" v3i16" , " ushort3 " , NTuple{3 , VecElement{UInt16}}),
356
+ (llvm_typ , julia_typ) in [
357
+ (" i32" , UInt32),
358
+ (" v2i32" , NTuple{2 , VecElement{UInt32}}),
359
+ (" v3i32" , NTuple{3 , VecElement{UInt32}}),
360
+ (" i16" , UInt16),
361
+ (" v2i16" , NTuple{2 , VecElement{UInt16}}),
362
+ (" v3i16" , NTuple{3 , VecElement{UInt16}}),
365
363
]
366
- push! (kernel_intrinsics,
367
- " julia.air.$intr .$intr_typ " =>
368
- (air_intr= " $intr .$air_typ " , air_typ, air_name= intr, julia_typ))
364
+ push! (kernel_intrinsics, " julia.air.$intr .$llvm_typ " => (name= intr, typ= julia_typ))
365
+ end
366
+
367
+ function argument_type_name (typ)
368
+ if typ isa LLVM. IntegerType && width (typ) == 16
369
+ " ushort"
370
+ elseif typ isa LLVM. IntegerType && width (typ) == 32
371
+ " uint"
372
+ elseif typ isa LLVM. VectorType
373
+ argument_type_name (eltype (typ)) * string (Int (size (typ)))
374
+ else
375
+ error (" Cannot encode unknown type `$typ `" )
376
+ end
369
377
end
370
378
371
379
function add_input_arguments! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module ,
@@ -414,7 +422,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
414
422
new_param_types = LLVMType[parameters (ft)... ]
415
423
416
424
for intr_fn in used_intrinsics
417
- llvm_typ = convert (LLVMType, kernel_intrinsics[intr_fn]. julia_typ ; ctx)
425
+ llvm_typ = convert (LLVMType, kernel_intrinsics[intr_fn]. typ ; ctx)
418
426
push! (new_param_types, llvm_typ)
419
427
end
420
428
new_ft = LLVM. FunctionType (LLVM. return_type (ft), new_param_types)
@@ -424,7 +432,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
424
432
LLVM. name! (new_arg, LLVM. name (arg))
425
433
end
426
434
for (intr_fn, new_arg) in zip (used_intrinsics, parameters (new_f)[end - nargs+ 1 : end ])
427
- LLVM. name! (new_arg, kernel_intrinsics[intr_fn]. air_name )
435
+ LLVM. name! (new_arg, kernel_intrinsics[intr_fn]. name )
428
436
end
429
437
430
438
workmap[f] = new_f
@@ -591,10 +599,15 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
591
599
592
600
# Create metadata for argument intrinsics last
593
601
for intr_arg in parameters (entry)[i: end ]
602
+ intr_fn = LLVM. name (intr_arg)
603
+
594
604
arg_info = Metadata[]
595
605
596
606
push! (arg_info, Metadata (ConstantInt (Int32 (length (parameters (entry))- i); ctx)))
597
- push! (arg_info, MDString (" air." * LLVM. name (intr_arg); ctx))
607
+ push! (arg_info, MDString (" air.$intr_fn " ; ctx))
608
+
609
+ push! (arg_info, MDString (" air.arg_type_name" ; ctx))
610
+ push! (arg_info, MDString (argument_type_name (llvmtype (intr_arg)); ctx))
598
611
599
612
arg_info = MDNode (arg_info; ctx)
600
613
push! (arg_infos, arg_info)
0 commit comments