Skip to content

Commit 8f8ca00

Browse files
committed
bugfix and doc improvements
1 parent ad03dcc commit 8f8ca00

File tree

1 file changed

+27
-16
lines changed
  • compiler/rustc_codegen_llvm/src/back

1 file changed

+27
-16
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -742,8 +742,20 @@ fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value
742742

743743
fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, num: i64) -> &'ll llvm::Value {
744744
let types = cx.func_params_types(cx.get_type_of_global(kernel));
745-
let o_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![8u64,0,16,0]);
746-
let o_types = add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![800u64, 544, 547, 544]);
745+
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
746+
// reference) types.
747+
let num_ptr_types = types.iter().map(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)).count();
748+
749+
// We do not know their size anymore at this level, so hardcode a placeholder.
750+
// A follow-up PR will track these from the frontend, where we still have Rust types.
751+
// Then, we will be able to figure out that e.g. `&[f32;1024]` will result in 4*1024 bytes.
752+
// I decided that 1024 bytes is a great placeholder value for now.
753+
let o_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024;num_ptr_types]);
754+
// Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
755+
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
756+
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
757+
// will be 2. For now, everything is 3, until we have our frontend set up.
758+
let o_types = add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3;num_ptr_types]);
747759
// Next: For each function, generate these three entries. A weak constant,
748760
// the llvm.rodata entry name, and the omp_offloading_entries value
749761

@@ -794,11 +806,11 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, off
794806
o_types
795807
}
796808

797-
fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
809+
fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
798810

799811
let main_fn = cx.get_function("main");
800812
if let Some(main_fn) = main_fn {
801-
let kernel_name = "kernel_1";//name.as_c_char_ptr(), name.len)
813+
let kernel_name = "kernel_1";
802814
let call = unsafe {llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())};
803815
let kernel_call = if call.is_some() {
804816
dbg!("found kernel call");
@@ -809,38 +821,36 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, s_ide
809821
let kernel_call_bb = unsafe {llvm::LLVMGetInstructionParent(kernel_call)};
810822
let mut builder = SBuilder::build(cx, kernel_call_bb);
811823

812-
let types = cx.func_params_types(cx.get_type_of_global(kernel));
824+
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
813825
dbg!(&types);
814826
let num_args = types.len() as u64;
815827

816828
// First we generate a few variables used for the data mappers below.
817-
// %.offload_baseptrs = alloca [3 x ptr], align 8
818-
// %.offload_ptrs = alloca [3 x ptr], align 8
819-
// %.offload_mappers = alloca [3 x ptr], align 8
820-
// %.offload_sizes = alloca [3 x i64], align 8
821829
unsafe{llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn)};
822830
let ty = cx.type_array(cx.type_ptr(), num_args);
831+
832+
// Baseptrs are just the input pointers to the kernel, stored in a local alloca
823833
let a1 = builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
834+
835+
// Ptrs are the result of a gep into the baseptr, at least for our trivial types.
824836
let a2 = builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs");
825-
let a3 = builder.my_alloca2(ty, Align::EIGHT, ".offload_mappers");
837+
838+
// These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
826839
let ty2 = cx.type_array(cx.type_i64(), num_args);
827840
let a4 = builder.my_alloca2(ty2, Align::EIGHT, ".offload_sizes");
828841

829842
// Now we generate the __tgt_target_data calls
830843
unsafe {llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call)};
831844
dbg!("positioned builder, ready");
832845

833-
// %27 = getelementptr inbounds [3 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
834-
// %28 = getelementptr inbounds [3 x ptr], ptr %.offload_ptrs, i32 0, i32 0
835-
// %29 = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
836846
let i32_0 = cx.get_const_i32(0);
837847
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
838848
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
839849
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
840850

841851
let nullptr = cx.const_null(cx.type_ptr());
842852
let o_type = o_types[0];
843-
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(3), gep1, gep2, gep3, o_type, nullptr, nullptr];
853+
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
844854
builder.call(fn_ty, begin, &args, None);
845855

846856
unsafe {llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call)};
@@ -925,15 +935,16 @@ pub(crate) fn run_pass_manager(
925935

926936
dbg!("created struct");
927937
let mut o_types = vec![];
938+
let mut kernels = vec![];
928939
for num in 0..9 {
929940
let kernel = cx.get_function(&format!("kernel_{num}"));
930941
if let Some(kernel) = kernel{
931942
o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
943+
kernels.push(kernel);
932944
}
933945
}
934-
let kernel = cx.get_function("kernel_1").unwrap();
935946
dbg!("gen_call_handling");
936-
gen_call_handling(&cx, kernel, at_one, begin, update, end, fn_ty, &o_types);
947+
gen_call_handling(&cx, &kernels, at_one, begin, update, end, fn_ty, &o_types);
937948
} else {
938949
dbg!("no marker found");
939950
}

0 commit comments

Comments
 (0)