Commit 8e00b7d

extract into own gpu builder module

1 parent a786d6b commit 8e00b7d

4 files changed: +423 -311 lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 3 additions & 309 deletions
@@ -633,292 +633,6 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
     llvm::set_rust_rules(true);
 }
 
-fn gen_globals<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
-    let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
-    let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
-    let tptr = cx.type_ptr();
-    let ti64 = cx.type_i64();
-    let ti32 = cx.type_i32();
-    let ti16 = cx.type_i16();
-    let ti8 = cx.type_i8();
-    let tarr = cx.type_array(ti32, 3);
-
-    // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
-    let unknown_txt = ";unknown;unknown;0;0;;";
-    let c_entry_name = CString::new(unknown_txt).unwrap();
-    let c_val = c_entry_name.as_bytes_with_nul();
-    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
-    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
-    llvm::set_alignment(at_zero, Align::ONE);
-
-    // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
-    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
-    let struct_elems: Vec<&llvm::Value> = vec![cx.get_const_i32(0), cx.get_const_i32(2), cx.get_const_i32(0), cx.get_const_i32(22), at_zero];
-    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
-    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
-    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
-    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
-    llvm::set_alignment(at_one, Align::EIGHT);
-
-    // Copied from LLVM:
-    // typedef struct {
-    //     uint64_t Reserved;
-    //     uint16_t Version;
-    //     uint16_t Kind;
-    //     uint32_t Flags;
-    //     void *Address;
-    //     char *SymbolName;
-    //     uint64_t Size;
-    //     uint64_t Data;
-    //     void *AuxAddr;
-    // } __tgt_offload_entry;
-    let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
-    let kernel_elements = vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
-
-    cx.set_struct_body(offload_entry_ty, &entry_elements, false);
-    cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
-    let global = cx.declare_global("my_struct_global", offload_entry_ty);
-    let global = cx.declare_global("my_struct_global2", kernel_arguments_ty);
-    // @my_struct_global = external global %struct.__tgt_offload_entry
-    // @my_struct_global2 = external global %struct.__tgt_kernel_arguments
-    dbg!(&offload_entry_ty);
-    dbg!(&kernel_arguments_ty);
-    // LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy};
-    // LLVMStructSetBody(structTy, elements, 9, 0);
-
-    // New, to test memtransfers:
-    // ; Function Attrs: nounwind
-    // declare void @__tgt_target_data_begin_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3
-    //
-    // ; Function Attrs: nounwind
-    // declare void @__tgt_target_data_update_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3
-    //
-    // ; Function Attrs: nounwind
-    // declare void @__tgt_target_data_end_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3
-
-    let mapper_begin = "__tgt_target_data_begin_mapper";
-    let mapper_update = String::from("__tgt_target_data_update_mapper");
-    let mapper_end = String::from("__tgt_target_data_end_mapper");
-    let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
-    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
-    let foo = crate::declare::declare_simple_fn(&cx, &mapper_begin, llvm::CallConv::CCallConv, llvm::UnnamedAddr::No, llvm::Visibility::Default, mapper_fn_ty);
-    let bar = crate::declare::declare_simple_fn(&cx, &mapper_update, llvm::CallConv::CCallConv, llvm::UnnamedAddr::No, llvm::Visibility::Default, mapper_fn_ty);
-    let baz = crate::declare::declare_simple_fn(&cx, &mapper_end, llvm::CallConv::CCallConv, llvm::UnnamedAddr::No, llvm::Visibility::Default, mapper_fn_ty);
-    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
-    attributes::apply_to_llfn(foo, Function, &[nounwind]);
-    attributes::apply_to_llfn(bar, Function, &[nounwind]);
-    attributes::apply_to_llfn(baz, Function, &[nounwind]);
-
-    (offload_entry_ty, at_one, foo, bar, baz, mapper_fn_ty)
-}
-
-fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
-    let ti64 = cx.type_i64();
-    let size_ty = cx.type_array(ti64, vals.len() as u64);
-    let mut size_val = Vec::with_capacity(vals.len());
-    for &val in vals {
-        size_val.push(cx.get_const_i64(val));
-    }
-    let initializer = cx.const_array(ti64, &size_val);
-    add_unnamed_global(cx, name, initializer, PrivateLinkage)
-}
-
-fn add_unnamed_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value, l: Linkage) -> &'ll llvm::Value {
-    let llglobal = add_global(cx, name, initializer, l);
-    unsafe { llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global) };
-    llglobal
-}
-
-fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value, l: Linkage) -> &'ll llvm::Value {
-    let c_name = CString::new(name).unwrap();
-    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
-    llvm::set_global_constant(llglobal, true);
-    llvm::set_linkage(llglobal, l);
-    llvm::set_initializer(llglobal, initializer);
-    llglobal
-}
-
-
-
-fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, num: i64) -> &'ll llvm::Value {
-    let types = cx.func_params_types(cx.get_type_of_global(kernel));
-    // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
-    // reference) types.
-    let num_ptr_types = types.iter().map(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)).count();
-
-    // We no longer know their size at this level, so hardcode a placeholder.
-    // A follow-up PR will track these from the frontend, where we still have Rust types.
-    // Then we will be able to figure out that e.g. `&[f32; 1024]` results in 4*1024 bytes.
-    // For now, 1024 bytes is the placeholder value.
-    let o_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
-    // Here we figure out whether something needs to be copied to the GPU (=1), from the GPU (=2),
-    // or both to and from the GPU (=3). Other values shouldn't affect us for now.
-    // A non-mutable reference or pointer will be 1; an array that's not read, but fully
-    // overwritten, will be 2. For now, everything is 3, until we have our frontend set up.
-    let o_types = add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
-    // Next: for each function, generate these three entries: a weak constant,
-    // the llvm.rodata entry name, and the omp_offloading_entries value.
-
-    // @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
-    // @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
-    let name = format!(".kernel_{num}.region_id");
-    let initializer = cx.get_const_i8(0);
-    let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
-
-    let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
-    let c_val = c_entry_name.as_bytes_with_nul();
-    let foo = format!(".offloading.entry_name.{num}");
-
-    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
-    let llglobal = add_unnamed_global(&cx, &foo, initializer, InternalLinkage);
-    llvm::set_alignment(llglobal, Align::ONE);
-    let c_section_name = CString::new(".llvm.rodata.offloading").unwrap();
-    llvm::set_section(llglobal, &c_section_name);
-
-
-    // Not actively used yet; for calling real kernels.
-    let name = format!(".offloading.entry.kernel_{num}");
-    let ci64_0 = cx.get_const_i64(0);
-    let ci16_1 = cx.get_const_i16(1);
-    let elems: Vec<&llvm::Value> = vec![ci64_0, ci16_1, ci16_1, cx.get_const_i32(0), region_id, llglobal, ci64_0, ci64_0, cx.const_null(cx.type_ptr())];
-
-    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
-    let c_name = CString::new(name).unwrap();
-    let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
-    llvm::set_global_constant(llglobal, true);
-    llvm::set_linkage(llglobal, WeakAnyLinkage);
-    llvm::set_initializer(llglobal, initializer);
-    llvm::set_alignment(llglobal, Align::ONE);
-    let c_section_name = CString::new(".omp_offloading_entries").unwrap();
-    llvm::set_section(llglobal, &c_section_name);
-    // rustc:
-    // @.offloading.entry.kernel_3 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_3.region_id, ptr @.offloading.entry_name.3, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1
-    // clang:
-    // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
-
-
-    //
-    // 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
-    // 2. @.offload_maptypes
-    // 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
-    // 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
-    // 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
-    o_types
-}
-
-
-// For each kernel *call*, we now use some of our previously declared globals to move data to and
-// from the GPU. We don't have a proper frontend yet, so we assume that every call to a kernel
-// function from main is intended to run on the GPU. For now, we only handle the data-transfer part.
-// If two consecutive kernels use the same memory, we still move it to the host and back to the GPU.
-// Since in our frontend users (by default) don't have to specify data transfers, this is something
-// we should optimize in the future! We also assume that everything should be copied back and forth,
-// but sometimes we could directly zero-allocate on the device and only move back, or if something
-// is immutable, we might only copy it to the device, but not back.
-//
-// Current steps:
-// 0. Alloca some variables for the following steps.
-// 1. Set the insert point before the kernel call.
-// 2. Generate all the GEPs and stores, to be used in 3).
-// 3. Generate __tgt_target_data_begin calls to move data to the GPU.
-//
-// Unchanged: keep the kernel call. Later, move the kernel itself to the GPU.
-//
-// 4. Set the insert point after the kernel call.
-// 5. Generate all the GEPs and stores, to be used in 6).
-// 6. Generate __tgt_target_data_end calls to move data from the GPU.
-fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
-
-    let main_fn = cx.get_function("main");
-    if let Some(main_fn) = main_fn {
-        let kernel_name = "kernel_1";
-        let call = unsafe { llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len()) };
-        let kernel_call = if call.is_some() {
-            dbg!("found kernel call");
-            call.unwrap()
-        } else {
-            return;
-        };
-        let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
-        let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
-        let mut builder = SBuilder::build(cx, kernel_call_bb);
-
-        let types = cx.func_params_types(cx.get_type_of_global(called));
-        dbg!(&types);
-        let num_args = types.len() as u64;
-        let mut names: Vec<&llvm::Value> = Vec::with_capacity(num_args as usize);
-
-        // Step 0)
-        unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
-        let ty = cx.type_array(cx.type_ptr(), num_args);
-        // Baseptrs are just the input pointers to the kernel, stored in a local alloca.
-        let a1 = builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
-        // Ptrs are the result of a GEP into the baseptr, at least for our trivial types.
-        let a2 = builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs");
-        // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
-        let ty2 = cx.type_array(cx.type_i64(), num_args);
-        let a4 = builder.my_alloca2(ty2, Align::EIGHT, ".offload_sizes");
-        // Now we allocate, once per function param, a copy to be passed to one of our maps.
-        let mut vals = vec![];
-        let mut geps = vec![];
-        let i32_0 = cx.get_const_i32(0);
-        for (index, in_ty) in types.iter().enumerate() {
-            // Get the function arg, store it into the alloca, and read it back.
-            let p = llvm::get_param(called, index as u32);
-            let name = llvm::get_value_name(p);
-            let name = str::from_utf8(name).unwrap();
-            let arg_name = CString::new(format!("{name}.addr")).unwrap();
-            let alloca = unsafe { llvm::LLVMBuildAlloca(builder.llbuilder, in_ty, arg_name.as_ptr()) };
-            builder.store(p, alloca, Align::EIGHT);
-            let val = builder.load(in_ty, alloca, Align::EIGHT);
-            let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
-            vals.push(val);
-            geps.push(gep);
-        }
-
-
-        // Step 1)
-        unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
-        for i in 0..num_args {
-            let idx = cx.get_const_i32(i);
-            let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
-            builder.store(vals[i as usize], gep1, Align::EIGHT);
-            let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
-            builder.store(geps[i as usize], gep2, Align::EIGHT);
-            let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
-            builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
-        }
-
-        // Step 2)
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
-        builder.call(fn_ty, begin, &args, None);
-
-        // Step 4)
-        unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
-
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
-        builder.call(fn_ty, end, &args, None);
-
-        // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
-        // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
-        // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
-        // What is @1? Random but fixed:
-        // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
-        // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
-    }
-}
 
 pub(crate) fn run_pass_manager(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -945,7 +659,8 @@ pub(crate) fn run_pass_manager(
     // We then run the llvm_optimize function a second time, to optimize the code which we generated
     // in the enzyme differentiation pass.
     let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
-    let enable_gpu = true;//config.offload.contains(&config::Offload::Enable);
+    let enable_gpu = config.offload.contains(&config::Offload::Enable);
+    dbg!(&enable_gpu);
     let stage = if thin {
         write::AutodiffStage::PreAD
     } else {
@@ -961,30 +676,9 @@ pub(crate) fn run_pass_manager(
     }
 
     if cfg!(llvm_enzyme) && enable_gpu && !thin {
-        // First we need to add all the functions to the host module.
-        // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
-        // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
         let cx =
             SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
-        if cx.get_function("gen_tgt_offload").is_some() {
-
-            let (offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals(&cx);
-
-            dbg!("created struct");
-            let mut o_types = vec![];
-            let mut kernels = vec![];
-            for num in 0..9 {
-                let kernel = cx.get_function(&format!("kernel_{num}"));
-                if let Some(kernel) = kernel {
-                    o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
-                    kernels.push(kernel);
-                }
-            }
-            dbg!("gen_call_handling");
-            gen_call_handling(&cx, &kernels, at_one, begin, update, end, fn_ty, &o_types);
-        } else {
-            dbg!("no marker found");
-        }
+        crate::builder::gpu_offload::handle_gpu_code(&cx);
     }
 
     if cfg!(llvm_enzyme) && enable_ad && !thin {
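
The deleted comments above pin down LLVM's offloading ABI: the __tgt_offload_entry record and the map-type values written into .offload_maptypes.{num}. As a standalone reference (a sketch for illustration, not part of the commit; the Rust field names are assumed from the quoted C typedef), the layout and encoding correspond to:

// Standalone sketch, not part of the commit: a Rust mirror of the
// __tgt_offload_entry typedef quoted in the deleted comments, plus the
// map-type encoding used when filling .offload_maptypes.{num}.
use std::ffi::{c_char, c_void};

#[repr(C)]
struct TgtOffloadEntry {
    reserved: u64,            // uint64_t Reserved
    version: u16,             // uint16_t Version
    kind: u16,                // uint16_t Kind
    flags: u32,               // uint32_t Flags
    address: *mut c_void,     // void *Address
    symbol_name: *mut c_char, // char *SymbolName
    size: u64,                // uint64_t Size
    data: u64,                // uint64_t Data
    aux_addr: *mut c_void,    // void *AuxAddr
}

// Map-type values: copy to the device (1), from the device (2), or both
// ways (3 = 1 | 2). This commit currently maps every pointer argument as 3.
const MAP_TO: u64 = 1;
const MAP_FROM: u64 = 2;
const MAP_TOFROM: u64 = MAP_TO | MAP_FROM;

fn main() {
    assert_eq!(MAP_TOFROM, 3);
    // 8 + (2 + 2 + 4) + 8 + 8 + 8 + 8 + 8 = 56 bytes on a 64-bit target.
    assert_eq!(std::mem::size_of::<TgtOffloadEntry>(), 56);
}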

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ use std::ops::Deref;
 use std::{iter, ptr};
 
 pub(crate) mod autodiff;
+pub(crate) mod gpu_offload;
 
 use libc::{c_char, c_uint, size_t};
 use rustc_abi as abi;
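
The gpu_offload module registered here is one of the four changed files, but its contents are not shown in this excerpt. Judging from the call-site logic deleted from lto.rs above, its entry point presumably looks roughly like the following sketch (reconstructed from the removed code under the assumption that gen_globals, gen_define_handling, and gen_call_handling moved over unchanged; not the verbatim contents of the new file):

// Sketch of compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs,
// reconstructed from the code removed from lto.rs; the signature and body
// are assumptions based on the deleted call site.
pub(crate) fn handle_gpu_code<'ll>(cx: &'ll SimpleCx<'_>) {
    // The "gen_tgt_offload" function acts as a marker that this module
    // opted in to offloading.
    if cx.get_function("gen_tgt_offload").is_some() {
        let (offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals(cx);
        let mut o_types = vec![];
        let mut kernels = vec![];
        // Kernels are currently discovered by their placeholder names kernel_0..kernel_8.
        for num in 0..9 {
            if let Some(kernel) = cx.get_function(&format!("kernel_{num}")) {
                o_types.push(gen_define_handling(cx, kernel, offload_entry_ty, num));
                kernels.push(kernel);
            }
        }
        gen_call_handling(cx, &kernels, at_one, begin, update, end, fn_ty, &o_types);
    }
}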
