@@ -742,8 +742,20 @@ fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value
742
742
743
743
fn gen_define_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernel : & ' ll llvm:: Value , offload_entry_ty : & ' ll llvm:: Type , num : i64 ) -> & ' ll llvm:: Value {
744
744
let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
745
- let o_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num}" ) , & vec ! [ 8u64 , 0 , 16 , 0 ] ) ;
746
- let o_types = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num}" ) , & vec ! [ 800u64 , 544 , 547 , 544 ] ) ;
745
+ // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
746
+ // reference) types.
747
+ // Count only the pointer-typed parameters: `filter` keeps the matching entries, whereas
+ // `map(..).count()` would count every parameter regardless of its type kind.
+ let num_ptr_types = types.iter().filter(|&&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)).count();
748
+
749
+ // We do not know their size anymore at this level, so hardcode a placeholder.
750
+ // A follow-up pr will track these from the frontend, where we still have Rust types.
751
+ // Then, we will be able to figure out that e.g. `&[f32;1024]` will result in 4*1024 bytes.
752
+ // I decided that 1024 bytes is a great placeholder value for now.
753
+ let o_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num}" ) , & vec ! [ 1024 ; num_ptr_types] ) ;
754
+ // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
755
+ // or both to and from the gpu (=3). Other values shouldn't affect us for now.
756
+ // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
757
+ // will be 2. For now, everything is 3, until we have our frontend set up.
758
+ let o_types = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num}" ) , & vec ! [ 3 ; num_ptr_types] ) ;
747
759
// Next: For each function, generate these three entries. A weak constant,
748
760
// the llvm.rodata entry name, and the omp_offloading_entries value
749
761
@@ -794,11 +806,11 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, off
794
806
o_types
795
807
}
796
808
797
- fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernel : & ' ll llvm:: Value , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
809
+ fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernels : & [ & ' ll llvm:: Value ] , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
798
810
799
811
let main_fn = cx. get_function ( "main" ) ;
800
812
if let Some ( main_fn) = main_fn {
801
- let kernel_name = "kernel_1" ; //name.as_c_char_ptr(), name.len)
813
+ let kernel_name = "kernel_1" ;
802
814
let call = unsafe { llvm:: LLVMRustGetFunctionCall ( main_fn, kernel_name. as_c_char_ptr ( ) , kernel_name. len ( ) ) } ;
803
815
let kernel_call = if call. is_some ( ) {
804
816
dbg ! ( "found kernel call" ) ;
@@ -809,38 +821,36 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, s_ide
809
821
let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
810
822
let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
811
823
812
- let types = cx. func_params_types ( cx. get_type_of_global ( kernel ) ) ;
824
+ let types = cx. func_params_types ( cx. get_type_of_global ( kernels [ 0 ] ) ) ;
813
825
dbg ! ( & types) ;
814
826
let num_args = types. len ( ) as u64 ;
815
827
816
828
// First we generate a few variables used for the data mappers below.
817
- // %.offload_baseptrs = alloca [3 x ptr], align 8
818
- // %.offload_ptrs = alloca [3 x ptr], align 8
819
- // %.offload_mappers = alloca [3 x ptr], align 8
820
- // %.offload_sizes = alloca [3 x i64], align 8
821
829
unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
822
830
let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
831
+
832
+ // Baseptrs are just the input pointers to the kernel, stored in a local alloca
823
833
let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
834
+
835
+ // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
824
836
let a2 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
825
- let a3 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_mappers" ) ;
837
+
838
+ // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
826
839
let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
827
840
let a4 = builder. my_alloca2 ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
828
841
829
842
// Now we generate the __tgt_target_data calls
830
843
unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
831
844
dbg ! ( "positioned builder, ready" ) ;
832
845
833
- // %27 = getelementptr inbounds [3 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
834
- // %28 = getelementptr inbounds [3 x ptr], ptr %.offload_ptrs, i32 0, i32 0
835
- // %29 = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
836
846
let i32_0 = cx. get_const_i32 ( 0 ) ;
837
847
let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
838
848
let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
839
849
let gep3 = builder. inbounds_gep ( ty2, a4, & [ i32_0, i32_0] ) ;
840
850
841
851
let nullptr = cx. const_null ( cx. type_ptr ( ) ) ;
842
852
let o_type = o_types[ 0 ] ;
843
- let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( 3 ) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
853
+ let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args ) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
844
854
builder. call ( fn_ty, begin, & args, None ) ;
845
855
846
856
unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
@@ -925,15 +935,16 @@ pub(crate) fn run_pass_manager(
925
935
926
936
dbg ! ( "created struct" ) ;
927
937
let mut o_types = vec ! [ ] ;
938
+ let mut kernels = vec ! [ ] ;
928
939
for num in 0 ..9 {
929
940
let kernel = cx. get_function ( & format ! ( "kernel_{num}" ) ) ;
930
941
if let Some ( kernel) = kernel{
931
942
o_types. push ( gen_define_handling ( & cx, kernel, offload_entry_ty, num) ) ;
943
+ kernels. push ( kernel) ;
932
944
}
933
945
}
934
- let kernel = cx. get_function ( "kernel_1" ) . unwrap ( ) ;
935
946
dbg ! ( "gen_call_handling" ) ;
936
- gen_call_handling ( & cx, kernel , at_one, begin, update, end, fn_ty, & o_types) ;
947
+ gen_call_handling ( & cx, & kernels , at_one, begin, update, end, fn_ty, & o_types) ;
937
948
} else {
938
949
dbg ! ( "no marker found" ) ;
939
950
}
0 commit comments