@@ -740,15 +740,10 @@ fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value
740
740
741
741
742
742
743
- fn gen_define_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , offload_entry_ty : & ' ll llvm:: Type , num : i64 ) -> & ' ll llvm:: Value {
744
- // We add a pair of sizes and maptypes per offloadable function.
745
- // @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
743
+ fn gen_define_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernel : & ' ll llvm:: Value , offload_entry_ty : & ' ll llvm:: Type , num : i64 ) -> & ' ll llvm:: Value {
744
+ let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
746
745
let o_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num}" ) , & vec ! [ 8u64 , 0 , 16 , 0 ] ) ;
747
746
let o_types = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num}" ) , & vec ! [ 800u64 , 544 , 547 , 544 ] ) ;
748
- // TODO: We should add another pair per call to offloadable functions
749
- // @.offload_sizes.5 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
750
- // @.offload_maptypes.6 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
751
-
752
747
// Next: For each function, generate these three entries. A weak constant,
753
748
// the llvm.rodata entry name, and the omp_offloading_entries value
754
749
@@ -814,8 +809,9 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, s_ide
814
809
let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
815
810
let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
816
811
817
- let types = cx. func_params_types ( cx. val_ty ( kernel) ) ;
818
- let num_args = types. len ( ) ;
812
+ let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
813
+ dbg ! ( & types) ;
814
+ let num_args = types. len ( ) as u64 ;
819
815
820
816
// First we generate a few variables used for the data mappers below.
821
817
// %.offload_baseptrs = alloca [3 x ptr], align 8
@@ -930,19 +926,17 @@ pub(crate) fn run_pass_manager(
930
926
dbg ! ( "created struct" ) ;
931
927
let mut o_types = vec ! [ ] ;
932
928
for num in 0 ..9 {
933
- if !cx. get_function ( & format ! ( "kernel_{num}" ) ) . is_some ( ) {
934
- continue ;
929
+ let kernel = cx. get_function ( & format ! ( "kernel_{num}" ) ) ;
930
+ if let Some ( kernel) = kernel{
931
+ o_types. push ( gen_define_handling ( & cx, kernel, offload_entry_ty, num) ) ;
935
932
}
936
- // TODO: replace num by proper fn name
937
- o_types. push ( gen_define_handling ( & cx, offload_entry_ty, num) ) ;
938
933
}
939
934
let kernel = cx. get_function ( "kernel_1" ) . unwrap ( ) ;
935
+ dbg ! ( "gen_call_handling" ) ;
940
936
gen_call_handling ( & cx, kernel, at_one, begin, update, end, fn_ty, & o_types) ;
941
937
} else {
942
938
dbg ! ( "no marker found" ) ;
943
939
}
944
- } else {
945
- dbg ! ( "Not creating struct" ) ;
946
940
}
947
941
948
942
if cfg ! ( llvm_enzyme) && enable_ad && !thin {
0 commit comments