Skip to content

Commit d1ef755

Browse files
committed
also generate allocas in the caller of kernels
1 parent bbc56b4 commit d1ef755

File tree

4 files changed

+111
-74
lines changed

4 files changed

+111
-74
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::Arc;
77
use std::{io, iter, slice};
88

99
use object::read::archive::ArchiveFile;
10+
use rustc_abi::{Align, Size};
1011
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule, ThinShared};
1112
use rustc_codegen_ssa::back::symbol_export;
1213
use rustc_codegen_ssa::back::write::{CodegenContext, FatLtoInput};
@@ -648,7 +649,7 @@ fn gen_globals<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, &'ll llvm::Value
648649
let c_val = c_entry_name.as_bytes_with_nul();
649650
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
650651
let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
651-
llvm::set_alignment(at_zero, rustc_abi::Align::ONE);
652+
llvm::set_alignment(at_zero, Align::ONE);
652653

653654
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
654655
let struct_ident_ty = cx.type_named_struct("struct.ident_t");
@@ -657,7 +658,7 @@ fn gen_globals<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, &'ll llvm::Value
657658
let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
658659
cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
659660
let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
660-
llvm::set_alignment(at_one, rustc_abi::Align::EIGHT);
661+
llvm::set_alignment(at_one, Align::EIGHT);
661662

662663
// coppied from LLVM
663664
// typedef struct {
@@ -711,7 +712,7 @@ fn gen_globals<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, &'ll llvm::Value
711712
(offload_entry_ty, at_one, foo, bar, baz, mapper_fn_ty)
712713
}
713714

714-
fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value{
715+
fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
715716
let ti64 = cx.type_i64();
716717
let size_ty = cx.type_array(ti64, vals.len() as u64);
717718
let mut size_val = Vec::with_capacity(vals.len());
@@ -739,7 +740,7 @@ fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value
739740

740741

741742

742-
fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::Type, num: i64) {
743+
fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::Type, num: i64) -> &'ll llvm::Value {
743744
// We add a pair of sizes and maptypes per offloadable function.
744745
// @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
745746
let o_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![8u64,0,16,0]);
@@ -763,7 +764,7 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::
763764

764765
let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
765766
let llglobal = add_unnamed_global(&cx, &foo, initializer, InternalLinkage);
766-
llvm::set_alignment(llglobal, rustc_abi::Align::ONE);
767+
llvm::set_alignment(llglobal, Align::ONE);
767768
let c_section_name = CString::new(".llvm.rodata.offloading").unwrap();
768769
llvm::set_section(llglobal, &c_section_name);
769770

@@ -780,7 +781,7 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::
780781
llvm::set_global_constant(llglobal, true);
781782
llvm::set_linkage(llglobal, WeakAnyLinkage);
782783
llvm::set_initializer(llglobal, initializer);
783-
llvm::set_alignment(llglobal, rustc_abi::Align::ONE);
784+
llvm::set_alignment(llglobal, Align::ONE);
784785
let c_section_name = CString::new(".omp_offloading_entries").unwrap();
785786
llvm::set_section(llglobal, &c_section_name);
786787
// rustc
@@ -795,9 +796,10 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, offload_entry_ty: &'ll llvm::
795796
// 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
796797
// 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
797798
// 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
799+
o_types
798800
}
799801

800-
fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type) {
802+
fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
801803

802804
let main_fn = cx.get_function("main");
803805
if let Some(main_fn) = main_fn {
@@ -811,20 +813,33 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, s_ident_t: &'ll llvm::Value, be
811813
};
812814
let kernel_call_bb = unsafe {llvm::LLVMGetInstructionParent(kernel_call)};
813815
let mut builder = SBuilder::build(cx, kernel_call_bb);
816+
817+
// First we generate a few variables used for the data mappers below.
818+
// %.offload_baseptrs = alloca [3 x ptr], align 8
819+
// %.offload_ptrs = alloca [3 x ptr], align 8
820+
// %.offload_mappers = alloca [3 x ptr], align 8
821+
// %.offload_sizes = alloca [3 x i64], align 8
822+
unsafe{llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn)};
823+
let ty = cx.type_array(cx.type_ptr(), 3);
824+
builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
825+
builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs");
826+
builder.my_alloca2(ty, Align::EIGHT, ".offload_mappers");
827+
let ty = cx.type_array(cx.type_i64(), 3);
828+
builder.my_alloca2(ty, Align::EIGHT, ".offload_sizes");
829+
830+
831+
// Now we generate the __tgt_target_data calls
814832
unsafe {llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call)};
815833
dbg!("positioned builder, ready");
816834

817835
let nullptr = cx.const_null(cx.type_ptr());
818-
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(3), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr];
836+
let o_type = o_types[0];
837+
let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(3), nullptr, nullptr, nullptr, o_type, nullptr, nullptr];
819838
dbg!(&fn_ty);
820839
dbg!(&begin);
821840
dbg!(&args);
822841
builder.call(fn_ty, begin, &args, None);
823842
dbg!("called begin");
824-
//llty: &'ll Type,
825-
//llfn: &'ll Value,
826-
//args: &[&'ll Value],
827-
//funclet: Option<&Funclet<'ll>>,
828843

829844
// 1. set insert point before kernel call.
830845
// 2. generate all the GEPS and stores.
@@ -895,14 +910,15 @@ pub(crate) fn run_pass_manager(
895910
let (offload_entry_ty, at_one, foo, bar, baz, fn_ty) = gen_globals(&cx);
896911

897912
dbg!("created struct");
913+
let mut o_types = vec![];
898914
for num in 0..9 {
899915
if !cx.get_function(&format!("kernel_{num}")).is_some() {
900916
continue;
901917
}
902918
// TODO: replace num by proper fn name
903-
gen_define_handling(&cx, offload_entry_ty, num);
919+
o_types.push(gen_define_handling(&cx, offload_entry_ty, num));
904920
}
905-
gen_call_handling(&cx, at_one, foo, bar, baz, fn_ty);
921+
gen_call_handling(&cx, at_one, foo, bar, baz, fn_ty, &o_types);
906922
} else {
907923
dbg!("no marker found");
908924
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 75 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
8888
};
8989
call
9090
}
91+
9192
}
9293

9394
impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
@@ -118,6 +119,20 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
118119
}
119120
bx
120121
}
122+
123+
pub(crate) fn my_alloca2(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
124+
let val = unsafe {
125+
let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
126+
llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
127+
// Cast to default addrspace if necessary
128+
llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
129+
};
130+
if name != "" {
131+
let name = std::ffi::CString::new(name).unwrap();
132+
unsafe {llvm::set_value_name(val, &name.as_bytes())};
133+
}
134+
val
135+
}
121136
}
122137

123138
/// Empty string, to be used where LLVM expects an instruction name, indicating
@@ -1261,7 +1276,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
12611276
unsafe {
12621277
llvm::LLVMBuildCleanupRet(self.llbuilder, funclet.cleanuppad(), unwind)
12631278
.expect("LLVM does not have support for cleanupret");
1264-
}
1279+
}
12651280
}
12661281

12671282
fn catch_pad(&mut self, parent: &'ll Value, args: &[&'ll Value]) -> Funclet<'ll> {
@@ -1631,14 +1646,14 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16311646
debug!(
16321647
"type mismatch in function call of {:?}. \
16331648
Expected {:?} for param {}, got {:?}; injecting bitcast",
1634-
llfn, expected_ty, i, actual_ty
1649+
llfn, expected_ty, i, actual_ty
16351650
);
16361651
self.bitcast(actual_val, expected_ty)
16371652
} else {
16381653
actual_val
16391654
}
16401655
})
1641-
.collect();
1656+
.collect();
16421657

16431658
Cow::Owned(casted_args)
16441659
}
@@ -1791,48 +1806,48 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17911806
let is_indirect_call = unsafe { llvm::LLVMRustIsNonGVFunctionPointerTy(llfn) };
17921807
if self.tcx.sess.is_sanitizer_cfi_enabled()
17931808
&& let Some(fn_abi) = fn_abi
1794-
&& is_indirect_call
1795-
{
1796-
if let Some(fn_attrs) = fn_attrs
1797-
&& fn_attrs.no_sanitize.contains(SanitizerSet::CFI)
1809+
&& is_indirect_call
17981810
{
1799-
return;
1800-
}
1811+
if let Some(fn_attrs) = fn_attrs
1812+
&& fn_attrs.no_sanitize.contains(SanitizerSet::CFI)
1813+
{
1814+
return;
1815+
}
18011816

1802-
let mut options = cfi::TypeIdOptions::empty();
1803-
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
1804-
options.insert(cfi::TypeIdOptions::GENERALIZE_POINTERS);
1805-
}
1806-
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
1807-
options.insert(cfi::TypeIdOptions::NORMALIZE_INTEGERS);
1808-
}
1817+
let mut options = cfi::TypeIdOptions::empty();
1818+
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
1819+
options.insert(cfi::TypeIdOptions::GENERALIZE_POINTERS);
1820+
}
1821+
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
1822+
options.insert(cfi::TypeIdOptions::NORMALIZE_INTEGERS);
1823+
}
18091824

1810-
let typeid = if let Some(instance) = instance {
1811-
cfi::typeid_for_instance(self.tcx, instance, options)
1812-
} else {
1813-
cfi::typeid_for_fnabi(self.tcx, fn_abi, options)
1814-
};
1815-
let typeid_metadata = self.cx.typeid_metadata(typeid).unwrap();
1816-
let dbg_loc = self.get_dbg_loc();
1817-
1818-
// Test whether the function pointer is associated with the type identifier.
1819-
let cond = self.type_test(llfn, typeid_metadata);
1820-
let bb_pass = self.append_sibling_block("type_test.pass");
1821-
let bb_fail = self.append_sibling_block("type_test.fail");
1822-
self.cond_br(cond, bb_pass, bb_fail);
1823-
1824-
self.switch_to_block(bb_fail);
1825-
if let Some(dbg_loc) = dbg_loc {
1826-
self.set_dbg_loc(dbg_loc);
1827-
}
1828-
self.abort();
1829-
self.unreachable();
1825+
let typeid = if let Some(instance) = instance {
1826+
cfi::typeid_for_instance(self.tcx, instance, options)
1827+
} else {
1828+
cfi::typeid_for_fnabi(self.tcx, fn_abi, options)
1829+
};
1830+
let typeid_metadata = self.cx.typeid_metadata(typeid).unwrap();
1831+
let dbg_loc = self.get_dbg_loc();
1832+
1833+
// Test whether the function pointer is associated with the type identifier.
1834+
let cond = self.type_test(llfn, typeid_metadata);
1835+
let bb_pass = self.append_sibling_block("type_test.pass");
1836+
let bb_fail = self.append_sibling_block("type_test.fail");
1837+
self.cond_br(cond, bb_pass, bb_fail);
1838+
1839+
self.switch_to_block(bb_fail);
1840+
if let Some(dbg_loc) = dbg_loc {
1841+
self.set_dbg_loc(dbg_loc);
1842+
}
1843+
self.abort();
1844+
self.unreachable();
18301845

1831-
self.switch_to_block(bb_pass);
1832-
if let Some(dbg_loc) = dbg_loc {
1833-
self.set_dbg_loc(dbg_loc);
1846+
self.switch_to_block(bb_pass);
1847+
if let Some(dbg_loc) = dbg_loc {
1848+
self.set_dbg_loc(dbg_loc);
1849+
}
18341850
}
1835-
}
18361851
}
18371852

18381853
// Emits KCFI operand bundles.
@@ -1847,31 +1862,31 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
18471862
let kcfi_bundle = if self.tcx.sess.is_sanitizer_kcfi_enabled()
18481863
&& let Some(fn_abi) = fn_abi
18491864
&& is_indirect_call
1850-
{
1851-
if let Some(fn_attrs) = fn_attrs
1852-
&& fn_attrs.no_sanitize.contains(SanitizerSet::KCFI)
18531865
{
1854-
return None;
1855-
}
1866+
if let Some(fn_attrs) = fn_attrs
1867+
&& fn_attrs.no_sanitize.contains(SanitizerSet::KCFI)
1868+
{
1869+
return None;
1870+
}
18561871

1857-
let mut options = kcfi::TypeIdOptions::empty();
1858-
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
1859-
options.insert(kcfi::TypeIdOptions::GENERALIZE_POINTERS);
1860-
}
1861-
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
1862-
options.insert(kcfi::TypeIdOptions::NORMALIZE_INTEGERS);
1863-
}
1872+
let mut options = kcfi::TypeIdOptions::empty();
1873+
if self.tcx.sess.is_sanitizer_cfi_generalize_pointers_enabled() {
1874+
options.insert(kcfi::TypeIdOptions::GENERALIZE_POINTERS);
1875+
}
1876+
if self.tcx.sess.is_sanitizer_cfi_normalize_integers_enabled() {
1877+
options.insert(kcfi::TypeIdOptions::NORMALIZE_INTEGERS);
1878+
}
18641879

1865-
let kcfi_typeid = if let Some(instance) = instance {
1866-
kcfi::typeid_for_instance(self.tcx, instance, options)
1880+
let kcfi_typeid = if let Some(instance) = instance {
1881+
kcfi::typeid_for_instance(self.tcx, instance, options)
1882+
} else {
1883+
kcfi::typeid_for_fnabi(self.tcx, fn_abi, options)
1884+
};
1885+
1886+
Some(llvm::OperandBundleBox::new("kcfi", &[self.const_u32(kcfi_typeid)]))
18671887
} else {
1868-
kcfi::typeid_for_fnabi(self.tcx, fn_abi, options)
1888+
None
18691889
};
1870-
1871-
Some(llvm::OperandBundleBox::new("kcfi", &[self.const_u32(kcfi_typeid)]))
1872-
} else {
1873-
None
1874-
};
18751890
kcfi_bundle
18761891
}
18771892

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2560,6 +2560,7 @@ unsafe extern "C" {
25602560

25612561
pub(crate) fn LLVMRustSetDataLayoutFromTargetMachine<'a>(M: &'a Module, TM: &'a TargetMachine);
25622562

2563+
pub(crate) fn LLVMRustPositionBuilderPastAllocas<'a>(B: &Builder<'a>, Fn: &'a Value);
25632564
pub(crate) fn LLVMRustPositionBuilderAtStart<'a>(B: &Builder<'a>, BB: &'a BasicBlock);
25642565

25652566
pub(crate) fn LLVMRustSetModulePICLevel(M: &Module);

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,6 +1591,11 @@ extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B, LLVMValueRef Dst,
15911591
MaybeAlign(DstAlign), IsVolatile));
15921592
}
15931593

1594+
extern "C" void LLVMRustPositionBuilderPastAllocas(LLVMBuilderRef B,
1595+
LLVMValueRef Fn) {
1596+
Function *F = unwrap<Function>(Fn);
1597+
unwrap(B)->SetInsertPointPastAllocas(F);
1598+
}
15941599
extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B,
15951600
LLVMBasicBlockRef BB) {
15961601
auto Point = unwrap(BB)->getFirstInsertionPt();

0 commit comments

Comments
 (0)