@@ -799,7 +799,7 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::
799
799
o_types
800
800
}
801
801
802
- fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
802
+ fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernel : & ' ll llvm :: Value , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
803
803
804
804
let main_fn = cx. get_function ( "main" ) ;
805
805
if let Some ( main_fn) = main_fn {
@@ -814,32 +814,50 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, s_ident_t: &'ll llvm::Value, be
814
814
let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
815
815
let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
816
816
817
+ let types = cx. func_params_types ( cx. val_ty ( kernel) ) ;
818
+ let num_args = types. len ( ) ;
819
+
817
820
// First we generate a few variables used for the data mappers below.
818
821
// %.offload_baseptrs = alloca [3 x ptr], align 8
819
822
// %.offload_ptrs = alloca [3 x ptr], align 8
820
823
// %.offload_mappers = alloca [3 x ptr], align 8
821
824
// %.offload_sizes = alloca [3 x i64], align 8
822
825
unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
823
- let ty = cx. type_array ( cx. type_ptr ( ) , 3 ) ;
824
- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
825
- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
826
- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_mappers" ) ;
827
- let ty = cx. type_array ( cx. type_i64 ( ) , 3 ) ;
828
- builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_sizes" ) ;
829
-
826
+ let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
827
+ let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
828
+ let a2 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
829
+ let a3 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_mappers" ) ;
830
+ let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
831
+ let a4 = builder. my_alloca2 ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
830
832
831
833
// Now we generate the __tgt_target_data calls
832
834
unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
833
835
dbg ! ( "positioned builder, ready" ) ;
834
836
837
+ // %27 = getelementptr inbounds [3 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
838
+ // %28 = getelementptr inbounds [3 x ptr], ptr %.offload_ptrs, i32 0, i32 0
839
+ // %29 = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
840
+ let i32_0 = cx. get_const_i32 ( 0 ) ;
841
+ let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
842
+ let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
843
+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
844
+
835
845
let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
836
846
let o_type = o_types[ 0 ] ;
837
- let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( 3 ) , nullptr, nullptr, nullptr, o_type, nullptr, nullptr] ;
838
- dbg ! ( & fn_ty) ;
839
- dbg ! ( & begin) ;
840
- dbg ! ( & args) ;
847
+ let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( 3 ) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
841
848
builder. call ( fn_ty, begin, & args, None ) ;
842
- dbg ! ( "called begin" ) ;
849
+
850
+ unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
851
+ dbg ! ( "re-positioned builder, ready" ) ;
852
+
853
+ let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
854
+ let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
855
+ let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
856
+
857
+ let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
858
+ let o_type = o_types[ 0 ] ;
859
+ let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
860
+ builder. call ( fn_ty, end, & args, None ) ;
843
861
844
862
// 1. set insert point before kernel call.
845
863
// 2. generate all the GEPS and stores.
@@ -907,7 +925,7 @@ pub(crate) fn run_pass_manager(
907
925
SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
908
926
if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
909
927
910
- let ( offload_entry_ty, at_one, foo , bar , baz , fn_ty) = gen_globals ( & cx) ;
928
+ let ( offload_entry_ty, at_one, begin , update , end , fn_ty) = gen_globals ( & cx) ;
911
929
912
930
dbg ! ( "created struct" ) ;
913
931
let mut o_types = vec ! [ ] ;
@@ -918,7 +936,8 @@ pub(crate) fn run_pass_manager(
918
936
// TODO: replace num by proper fn name
919
937
o_types. push ( gen_define_handling ( & cx, offload_entry_ty, num) ) ;
920
938
}
921
- gen_call_handling ( & cx, at_one, foo, bar, baz, fn_ty, & o_types) ;
939
+ let kernel = cx. get_function ( "kernel_1" ) . unwrap ( ) ;
940
+ gen_call_handling ( & cx, kernel, at_one, begin, update, end, fn_ty, & o_types) ;
922
941
} else {
923
942
dbg ! ( "no marker found" ) ;
924
943
}
0 commit comments