@@ -339,33 +339,41 @@ for intr in [
339
339
" dispatch_quadgroups_per_threadgroup" , " dispatch_simdgroups_per_threadgroup" ,
340
340
" quadgroup_index_in_threadgroup" , " quadgroups_per_threadgroup" ,
341
341
" simdgroup_index_in_threadgroup" , " simdgroups_per_threadgroup" ,
342
- " thread_index_in_quadgroup" , " thread_index_in_simdgroup" , " thread_index_in_threadgroup " ,
343
- " thread_execution_width" , " threads_per_simdgroup" ],
344
- (intr_typ, air_typ , julia_typ) in [
345
- (" i32" , " uint " , UInt32),
346
- (" i16" , " ushort " , UInt16),
342
+ " thread_index_in_quadgroup" , " thread_index_in_simdgroup" ,
343
+ " thread_index_in_threadgroup " , " thread_execution_width" , " threads_per_simdgroup" ],
344
+ (llvm_typ , julia_typ) in [
345
+ (" i32" , UInt32),
346
+ (" i16" , UInt16),
347
347
]
348
- push! (kernel_intrinsics,
349
- " julia.air.$intr .$intr_typ " =>
350
- (air_intr= " $intr .$air_typ " , air_typ, air_name= intr, julia_typ))
348
+ push! (kernel_intrinsics, " julia.air.$intr .$llvm_typ " => (name= intr, typ= julia_typ))
351
349
end
352
350
for intr in [
353
351
" dispatch_threads_per_threadgroup" ,
354
352
" grid_origin" , " grid_size" ,
355
353
" thread_position_in_grid" , " thread_position_in_threadgroup" ,
356
354
" threadgroup_position_in_grid" , " threadgroups_per_grid" ,
357
355
" threads_per_grid" , " threads_per_threadgroup" ],
358
- (intr_typ, air_typ , julia_typ) in [
359
- (" i32" , " uint " , UInt32),
360
- (" v2i32" , " uint2 " , NTuple{2 , VecElement{UInt32}}),
361
- (" v3i32" , " uint3 " , NTuple{3 , VecElement{UInt32}}),
362
- (" i16" , " ushort " , UInt16),
363
- (" v2i16" , " ushort2 " , NTuple{2 , VecElement{UInt16}}),
364
- (" v3i16" , " ushort3 " , NTuple{3 , VecElement{UInt16}}),
356
+ (llvm_typ , julia_typ) in [
357
+ (" i32" , UInt32),
358
+ (" v2i32" , NTuple{2 , VecElement{UInt32}}),
359
+ (" v3i32" , NTuple{3 , VecElement{UInt32}}),
360
+ (" i16" , UInt16),
361
+ (" v2i16" , NTuple{2 , VecElement{UInt16}}),
362
+ (" v3i16" , NTuple{3 , VecElement{UInt16}}),
365
363
]
366
- push! (kernel_intrinsics,
367
- " julia.air.$intr .$intr_typ " =>
368
- (air_intr= " $intr .$air_typ " , air_typ, air_name= intr, julia_typ))
364
+ push! (kernel_intrinsics, " julia.air.$intr .$llvm_typ " => (name= intr, typ= julia_typ))
365
+ end
366
+
367
+ function argument_type_name (typ)
368
+ if typ isa LLVM. IntegerType && width (typ) == 16
369
+ " ushort"
370
+ elseif typ isa LLVM. IntegerType && width (typ) == 32
371
+ " uint"
372
+ elseif typ isa LLVM. VectorType
373
+ argument_type_name (eltype (typ)) * string (Int (size (typ)))
374
+ else
375
+ error (" Cannot encode unknown type `$typ `" )
376
+ end
369
377
end
370
378
371
379
function add_input_arguments! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module ,
@@ -414,7 +422,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
414
422
new_param_types = LLVMType[parameters (ft)... ]
415
423
416
424
for intr_fn in used_intrinsics
417
- llvm_typ = convert (LLVMType, kernel_intrinsics[intr_fn]. julia_typ ; ctx)
425
+ llvm_typ = convert (LLVMType, kernel_intrinsics[intr_fn]. typ ; ctx)
418
426
push! (new_param_types, llvm_typ)
419
427
end
420
428
new_ft = LLVM. FunctionType (LLVM. return_type (ft), new_param_types)
@@ -424,7 +432,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
424
432
LLVM. name! (new_arg, LLVM. name (arg))
425
433
end
426
434
for (intr_fn, new_arg) in zip (used_intrinsics, parameters (new_f)[end - nargs+ 1 : end ])
427
- LLVM. name! (new_arg, kernel_intrinsics[intr_fn]. air_name )
435
+ LLVM. name! (new_arg, kernel_intrinsics[intr_fn]. name )
428
436
end
429
437
430
438
workmap[f] = new_f
@@ -557,12 +565,13 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
557
565
md = Metadata[]
558
566
559
567
# argument index
560
- push! (md, Metadata (ConstantInt (Int32 (arg. codegen. i- 1 ); ctx)))
568
+ @assert arg. codegen. i == i
569
+ push! (md, Metadata (ConstantInt (Int32 (i- 1 ); ctx)))
561
570
562
571
push! (md, MDString (" air.buffer" ; ctx))
563
572
564
573
push! (md, MDString (" air.location_index" ; ctx))
565
- push! (md, Metadata (ConstantInt (Int32 (arg . codegen . i- 1 ); ctx)))
574
+ push! (md, Metadata (ConstantInt (Int32 (i- 1 ); ctx)))
566
575
567
576
# XXX : unknown
568
577
push! (md, Metadata (ConstantInt (Int32 (1 ); ctx)))
@@ -591,13 +600,20 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
591
600
592
601
# Create metadata for argument intrinsics last
593
602
for intr_arg in parameters (entry)[i: end ]
603
+ intr_fn = LLVM. name (intr_arg)
604
+
594
605
arg_info = Metadata[]
595
606
596
- push! (arg_info, Metadata (ConstantInt (Int32 (length (parameters (entry))- i); ctx)))
597
- push! (arg_info, MDString (" air." * LLVM. name (intr_arg); ctx))
607
+ push! (arg_info, Metadata (ConstantInt (Int32 (i- 1 ); ctx)))
608
+ push! (arg_info, MDString (" air.$intr_fn " ; ctx))
609
+
610
+ push! (arg_info, MDString (" air.arg_type_name" ; ctx))
611
+ push! (arg_info, MDString (argument_type_name (llvmtype (intr_arg)); ctx))
598
612
599
613
arg_info = MDNode (arg_info; ctx)
600
614
push! (arg_infos, arg_info)
615
+
616
+ i += 1
601
617
end
602
618
arg_infos = MDNode (arg_infos; ctx)
603
619
0 commit comments