Skip to content

Commit 78fabee

Browse files
committed
generate most of the caller code
1 parent d1ef755 commit 78fabee

File tree

4 files changed

+63
-17
lines changed

4 files changed

+63
-17
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,7 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::
799799
o_types
800800
}
801801

802-
fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
802+
fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
803803

804804
let main_fn = cx.get_function("main");
805805
if let Some(main_fn) = main_fn {
@@ -814,32 +814,50 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, s_ident_t: &'ll llvm::Value, be
814814
let kernel_call_bb = unsafe {llvm::LLVMGetInstructionParent(kernel_call)};
815815
let mut builder = SBuilder::build(cx, kernel_call_bb);
816816

817+
let types = cx.func_params_types(cx.val_ty(kernel));
818+
let num_args = types.len();
819+
817820
// First we generate a few variables used for the data mappers below.
818821
// %.offload_baseptrs = alloca [3 x ptr], align 8
819822
// %.offload_ptrs = alloca [3 x ptr], align 8
820823
// %.offload_mappers = alloca [3 x ptr], align 8
821824
// %.offload_sizes = alloca [3 x i64], align 8
822825
unsafe{llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn)};
823-
let ty = cx.type_array(cx.type_ptr(), 3);
824-
builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
825-
builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs");
826-
builder.my_alloca2(ty, Align::EIGHT, ".offload_mappers");
827-
let ty = cx.type_array(cx.type_i64(), 3);
828-
builder.my_alloca2(ty, Align::EIGHT, ".offload_sizes");
829-
826+
let ty = cx.type_array(cx.type_ptr(), num_args);
827+
let a1 = builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
828+
let a2 = builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs");
829+
let a3 = builder.my_alloca2(ty, Align::EIGHT, ".offload_mappers");
830+
let ty2 = cx.type_array(cx.type_i64(), num_args);
831+
let a4 = builder.my_alloca2(ty2, Align::EIGHT, ".offload_sizes");
830832

831833
// Now we generate the __tgt_target_data calls
832834
unsafe {llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call)};
833835
dbg!("positioned builder, ready");
834836

837+
// %27 = getelementptr inbounds [3 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
838+
// %28 = getelementptr inbounds [3 x ptr], ptr %.offload_ptrs, i32 0, i32 0
839+
// %29 = getelementptr inbounds [3 x i64], ptr %.offload_sizes, i32 0, i32 0
840+
let i32_0 = cx.get_const_i32(0);
841+
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
842+
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
843+
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
844+
835845
let nullptr = cx.const_null(cx.type_ptr());
836846
let o_type = o_types[0];
837-
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(3), nullptr, nullptr, nullptr, o_type, nullptr, nullptr];
838-
dbg!(&fn_ty);
839-
dbg!(&begin);
840-
dbg!(&args);
847+
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(3), gep1, gep2, gep3, o_type, nullptr, nullptr];
841848
builder.call(fn_ty, begin, &args, None);
842-
dbg!("called begin");
849+
850+
unsafe {llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call)};
851+
dbg!("re-positioned builder, ready");
852+
853+
let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
854+
let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
855+
let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
856+
857+
let nullptr = cx.const_null(cx.type_ptr());
858+
let o_type = o_types[0];
859+
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
860+
builder.call(fn_ty, end, &args, None);
843861

844862
// 1. set insert point before kernel call.
845863
// 2. generate all the GEPS and stores.
@@ -907,7 +925,7 @@ pub(crate) fn run_pass_manager(
907925
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
908926
if cx.get_function("gen_tgt_offload").is_some() {
909927

910-
let (offload_entry_ty, at_one, foo, bar, baz, fn_ty) = gen_globals(&cx);
928+
let (offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals(&cx);
911929

912930
dbg!("created struct");
913931
let mut o_types = vec![];
@@ -918,7 +936,8 @@ pub(crate) fn run_pass_manager(
918936
// TODO: replace num by proper fn name
919937
o_types.push(gen_define_handling(&cx, offload_entry_ty, num));
920938
}
921-
gen_call_handling(&cx, at_one, foo, bar, baz, fn_ty, &o_types);
939+
let kernel = cx.get_function("kernel_1").unwrap();
940+
gen_call_handling(&cx, kernel, at_one, begin, update, end, fn_ty, &o_types);
922941
} else {
923942
dbg!("no marker found");
924943
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,26 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
133133
}
134134
val
135135
}
136+
137+
pub(crate) fn inbounds_gep(
138+
&mut self,
139+
ty: &'ll Type,
140+
ptr: &'ll Value,
141+
indices: &[&'ll Value],
142+
) -> &'ll Value {
143+
unsafe {
144+
llvm::LLVMBuildGEPWithNoWrapFlags(
145+
self.llbuilder,
146+
ty,
147+
ptr,
148+
indices.as_ptr(),
149+
indices.len() as c_uint,
150+
UNNAMED,
151+
GEPNoWrapFlags::InBounds,
152+
)
153+
}
154+
}
155+
136156
}
137157

138158
/// Empty string, to be used where LLVM expects an instruction name, indicating

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ unsafe extern "C" {
3333
kind: AttributeKind,
3434
);
3535
pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
36+
pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
3637
pub(crate) fn LLVMRustGetFunctionCall(
3738
F: &Value,
3839
name: *const c_char,

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,13 +1602,19 @@ extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B,
16021602
unwrap(B)->SetInsertPoint(unwrap(BB), Point);
16031603
}
16041604

1605-
extern "C" void LLVMRustPositionBefore(LLVMBuilderRef B,
1606-
LLVMValueRef Instr) {
1605+
extern "C" void LLVMRustPositionBefore(LLVMBuilderRef B, LLVMValueRef Instr) {
16071606
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
16081607
unwrap(B)->SetInsertPoint(I);
16091608
}
16101609
}
16111610

1611+
extern "C" void LLVMRustPositionAfter(LLVMBuilderRef B, LLVMValueRef Instr) {
1612+
if (auto I = dyn_cast<Instruction>(unwrap<Value>(Instr))) {
1613+
auto J = I->getNextNonDebugInstruction();
1614+
unwrap(B)->SetInsertPoint(J);
1615+
}
1616+
}
1617+
16121618
extern "C" LLVMValueRef
16131619
LLVMRustGetFunctionCall(LLVMValueRef Fn, const char *Name, size_t NameLen) {
16141620
auto targetName = StringRef(Name, NameLen);

0 commit comments

Comments
 (0)