Skip to content

Commit 3a8e8a6

Browse files
committed
last steps for mem handling
1 parent 8f8ca00 commit 3a8e8a6

File tree

2 files changed

+40
-18
lines changed

2 files changed

+40
-18
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use llvm::Linkage::*;
2828
use crate::back::write::{
2929
self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode,
3030
};
31-
use crate::builder::SBuilder;
31+
use crate::builder::{SBuilder, UNNAMED};
3232
use crate::errors::{
3333
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
3434
};
@@ -806,6 +806,27 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, off
806806
o_types
807807
}
808808

809+
810+
// For each kernel *call*, we now use some of our previous declared globals to move data to and from
811+
// the gpu. We don't have a proper frontend yet, so we assume that every call to a kernel function
812+
// from main is intended to run on the GPU. For now, we only handle the data transfer part of it.
813+
// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
814+
// Since in our frontend users (by default) don't have to specify data transfer, this is something
815+
// we should optimize in the future! We also assume that everything should be copied back and forth,
816+
// but sometimes we can directly zero-allocate on the device and only move back, or if something is
817+
// immutable, we might only copy it to the device, but not back.
818+
//
819+
// Current steps:
820+
// 0. Alloca some variables for the following steps
821+
// 1. set insert point before kernel call.
822+
// 2. generate all the GEPS and stores, to be used in 3)
823+
// 3. generate __tgt_target_data_begin calls to move data to the GPU
824+
//
825+
// unchanged: keep kernel call. Later move the kernel to the GPU
826+
//
827+
// 4. set insert point after kernel call.
828+
// 5. generate all the GEPS and stores, to be used in 6)
829+
// 6. generate __tgt_target_data_end calls to move data from the GPU
809830
fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
810831

811832
let main_fn = cx.get_function("main");
@@ -819,30 +840,39 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
819840
return;
820841
};
821842
let kernel_call_bb = unsafe {llvm::LLVMGetInstructionParent(kernel_call)};
843+
let called = unsafe {llvm::LLVMGetCalledValue(kernel_call)};
822844
let mut builder = SBuilder::build(cx, kernel_call_bb);
823845

824-
let types = cx.func_params_types(cx.get_type_of_global(kernels[0]));
846+
let types = cx.func_params_types(cx.get_type_of_global(called));
825847
dbg!(&types);
826848
let num_args = types.len() as u64;
849+
let mut names: Vec<&llvm::Value> = Vec::with_capacity(num_args);
827850

828-
// First we generate a few variables used for the data mappers below.
851+
// Step 0)
829852
unsafe{llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn)};
830853
let ty = cx.type_array(cx.type_ptr(), num_args);
831-
832854
// Baseptr are just the input pointer to the kernel, stored in a local alloca
833855
let a1 = builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
834-
835856
// Ptrs are the result of a gep into the baseptr, at least for our trivial types.
836857
let a2 = builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs");
837-
838858
// These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
839859
let ty2 = cx.type_array(cx.type_i64(), num_args);
840860
let a4 = builder.my_alloca2(ty2, Align::EIGHT, ".offload_sizes");
861+
// Now we allocate once per function param, a copy to be passed to one of our maps.
862+
for (index, in_ty) in types.iter().enumerate() {
863+
// Todo:
864+
let p = llvm::get_param(called, index as u32);
865+
let name = llvm::get_value_name(p);
866+
let arg_name = format!("{name}.addr");
867+
let alloca = unsafe {llvm::LLVMBuildAlloca(builder.llbuilder, in_ty, arg_name)};
868+
// get function arg, store it into the alloca, and read it.
869+
}
841870

842-
// Now we generate the __tgt_target_data calls
871+
872+
// Step 1)
843873
unsafe {llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call)};
844-
dbg!("positioned builder, ready");
845874

875+
// Step 2)
846876
let i32_0 = cx.get_const_i32(0);
847877
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
848878
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
@@ -853,8 +883,8 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
853883
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
854884
builder.call(fn_ty, begin, &args, None);
855885

886+
// Step 4)
856887
unsafe {llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call)};
857-
dbg!("re-positioned builder, ready");
858888

859889
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
860890
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
@@ -865,15 +895,6 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
865895
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
866896
builder.call(fn_ty, end, &args, None);
867897

868-
// 1. set insert point before kernel call.
869-
// 2. generate all the GEPS and stores.
870-
// 3. generate __tgt_target_data calls.
871-
//
872-
// unchanged: keep kernel call.
873-
//
874-
// 4. generate all the GEPS and stores.
875-
// 5. generate __tgt_target_data calls
876-
877898
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
878899
// call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
879900
// call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,7 @@ unsafe extern "C" {
12151215

12161216
// Operations on instructions
12171217
pub(crate) fn LLVMGetInstructionParent(Inst: &Value) -> &BasicBlock;
1218+
pub(crate) fn LLVMGetCalledValue(CallInst: &Value) -> Option<&Value>;
12181219
pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>;
12191220
pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock;
12201221
pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>;

0 commit comments

Comments
 (0)