Skip to content

Commit 0d7e89e

Browse files
committed
start generating new module
1 parent 8e00b7d commit 0d7e89e

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ pub(crate) fn run_pass_manager(
678678
if cfg!(llvm_enzyme) && enable_gpu && !thin {
679679
let cx =
680680
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
681-
crate::builder::gpu_offload::handle_gpu_code(&cx);
681+
crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
682682
}
683683

684684
if cfg!(llvm_enzyme) && enable_ad && !thin {

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,62 @@ use crate::common::AsCCharPtr;
55
use crate::llvm::AttributePlace::Function;
66
use crate::llvm::{self, Linkage, build_string};
77
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};
8+
use rustc_codegen_ssa::back::write::{CodegenContext, FatLtoInput};
89

910
use llvm::Linkage::*;
1011
use rustc_abi::Align;
1112
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
1213

14+
// We don't copy types from other functions because we generate a new module and context.
15+
// Bringing in types from other contexts would likely cause issues.
16+
pub(crate) fn gen_image_wrapper_module<'ll>(
17+
cgcx: &CodegenContext<LlvmCodegenBackend>,
18+
old_cx: &SimpleCx<'ll>,
19+
) {
20+
unsafe {
21+
let llcx = llvm::LLVMRustContextCreate(false);
22+
let module_name = CString::new("offload.wrapper.module").unwrap();
23+
let llmod = llvm::LLVMModuleCreateWithNameInContext(module_name.as_ptr(), llcx);
24+
let cx = SimpleCx::new(llmod, llcx, cgcx.pointer_size);
25+
let tptr = cx.type_ptr();
26+
let ti64 = cx.type_i64();
27+
let ti32 = cx.type_i32();
28+
let ti16 = cx.type_i16();
29+
let ti8 = cx.type_i8();
30+
let dl_cstr = llvm::LLVMGetDataLayoutStr(old_cx.llmod);
31+
llvm::LLVMSetDataLayout(llmod, dl_cstr);
32+
// target triple = "x86_64-unknown-linux-gnu"
33+
34+
let mut entry_fields = [ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
35+
36+
let entry_struct_name = CString::new("__tgt_offload_entry").unwrap();
37+
let entry_struct = llvm::LLVMStructCreateNamed(llcx, entry_struct_name.as_ptr());
38+
llvm::LLVMStructSetBody(
39+
entry_struct,
40+
entry_fields.as_mut_ptr(),
41+
entry_fields.len() as u32,
42+
0,
43+
);
44+
45+
llvm::LLVMPrintModuleToFile(
46+
llmod,
47+
CString::new("rustmagic.openmp.image.wrapper.ll").unwrap().as_ptr(),
48+
std::ptr::null_mut(),
49+
);
50+
51+
// Clean up
52+
llvm::LLVMDisposeModule(llmod);
53+
llvm::LLVMContextDispose(llcx);
54+
}
55+
}
56+
1357
// first we need to add all the fun to the host module
1458
// %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
1559
// %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
16-
pub(crate) fn handle_gpu_code<'ll>(cx: &'ll SimpleCx<'_>) {
60+
pub(crate) fn handle_gpu_code<'ll>(
61+
cgcx: &CodegenContext<LlvmCodegenBackend>,
62+
cx: &'ll SimpleCx<'_>,
63+
) {
1764
if cx.get_function("gen_tgt_offload").is_some() {
1865
let (offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals(&cx);
1966

@@ -29,7 +76,7 @@ pub(crate) fn handle_gpu_code<'ll>(cx: &'ll SimpleCx<'_>) {
2976
}
3077
dbg!("gen_call_handling");
3178
gen_call_handling(&cx, &kernels, at_one, begin, update, end, fn_ty, &o_types);
32-
gen_image_wrapper_module();
79+
gen_image_wrapper_module(&cgcx, &cx);
3380
} else {
3481
dbg!("no marker found");
3582
}
@@ -413,5 +460,3 @@ fn gen_call_handling<'ll>(
413460
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
414461
}
415462
}
416-
417-
pub(crate) fn gen_image_wrapper_module() {}

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1003,7 +1003,13 @@ unsafe extern "C" {
10031003
SLen: c_uint,
10041004
) -> MetadataKindId;
10051005

1006-
// Create modules.
1006+
// Create, print, and destroy modules.
1007+
pub(crate) fn LLVMPrintModuleToFile(
1008+
M: &Module,
1009+
Name: *const c_char,
1010+
Error_message: *mut c_char,
1011+
);
1012+
pub(crate) fn LLVMDisposeModule(M: &Module);
10071013
pub(crate) fn LLVMModuleCreateWithNameInContext(
10081014
ModuleID: *const c_char,
10091015
C: &Context,

0 commit comments

Comments
 (0)