Skip to content

Commit 55581c4

Browse files
committed
wip
1 parent 2398bd6 commit 55581c4

File tree

6 files changed

+186
-39
lines changed

6 files changed

+186
-39
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ pub(crate) fn run_pass_manager(
653653
// We then run the llvm_optimize function a second time, to optimize the code which we generated
654654
// in the enzyme differentiation pass.
655655
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
656+
let enable_gpu = true;//config.offload.contains(&config::Offload::Enable);
656657
let stage = if thin {
657658
write::AutodiffStage::PreAD
658659
} else {
@@ -667,6 +668,114 @@ pub(crate) fn run_pass_manager(
667668
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
668669
}
669670

671+
if cfg!(llvm_enzyme) && enable_gpu && !thin {
672+
// First, we need to add the offload struct types and their globals to the host module:
673+
// %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
674+
// %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
675+
let cx =
676+
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
677+
if cx.get_function("gen_tgt_offload").is_some() {
678+
let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
679+
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
680+
let tptr = cx.type_ptr();
681+
let ti64 = cx.type_i64();
682+
let ti32 = cx.type_i32();
683+
let ti16 = cx.type_i16();
684+
let tarr = cx.type_array(ti32, 3);
685+
686+
let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
687+
let kernel_elements = vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
688+
689+
cx.set_struct_body(offload_entry_ty, &entry_elements, false);
690+
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
691+
let global = cx.declare_global("my_struct_global", offload_entry_ty);
692+
let global = cx.declare_global("my_struct_global2", kernel_arguments_ty);
693+
dbg!(&offload_entry_ty);
694+
dbg!(&kernel_arguments_ty);
695+
//LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy};
696+
//LLVMStructSetBody(structTy, elements, 9, 0);
697+
dbg!("created struct");
698+
for num in 0..5 {
699+
if !cx.get_function(&format!("kernel_{num}")).is_some() {
700+
continue;
701+
}
702+
//for function in cx.get_functions() {
703+
//if !attributes::has_attr(function, Function, llvm::AttributeKind::OptimizeForSize) {
704+
// dbg!("skipping minsize fnc");
705+
// dbg!(&function);
706+
// // print fnc name
707+
// let enzyme_marker = "minsize";
708+
// if attributes::has_string_attr(function, enzyme_marker) {
709+
// dbg!("found minsize str");
710+
// }
711+
// continue;
712+
713+
let size_name = format!(".offload_sizes.{num}");
714+
let size_ty = cx.type_array(ti64, 4);
715+
//let size_val = vec![8i64,0,16,0];
716+
let c_val_8 = cx.get_const_i64(8);
717+
let c_val_0 = cx.get_const_i64(0);
718+
let c_val_16 = cx.get_const_i64(16);
719+
let size_val = vec![c_val_8, c_val_0, c_val_16, c_val_0];
720+
721+
//let val = cx.define_global(&size_name, size_ty).unwrap();
722+
//dbg!(&val);
723+
//let section_var = cx
724+
// .define_global(section_var_name, llvm_type)
725+
// .unwrap_or_else(|| bug!("symbol `{}` is already defined", section_var_name));
726+
//llvm::set_section(section_var, c".debug_gdb_scripts");
727+
//llvm::set_initializer(section_var, cx.const_bytes(section_contents));
728+
//llvm::LLVMSetGlobalConstant(section_var, llvm::True);
729+
//llvm::set_linkage(section_var, llvm::Linkage::LinkOnceODRLinkage);
730+
//// This should make sure that the whole section is not larger than
731+
//// the string it contains. Otherwise we get a warning from GDB.
732+
//llvm::LLVMSetAlignment(section_var, 1);
733+
//llvm::set_initializer(val, cx.const_bytes(size_val.as_slice()));
734+
let initializer = cx.const_array(ti64, &size_val);
735+
let name = format!(".offload_sizes.{num}");
736+
let c_name = CString::new(name).unwrap();
737+
let array = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name );
738+
llvm::set_global_constant(array, true);
739+
unsafe {llvm::LLVMSetUnnamedAddress(array, llvm::UnnamedAddr::Global)};
740+
llvm::set_linkage(array, llvm::Linkage::PrivateLinkage);
741+
llvm::set_initializer(array, initializer);
742+
dbg!(&array);
743+
// 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
744+
// 2. @.offload_maptypes
745+
// 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
746+
// 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
747+
// 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
748+
}
749+
// @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
750+
// @.offload_sizes = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
751+
// @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
752+
// @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id = weak constant i8 0
753+
// @.offload_sizes.1 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
754+
// @.offload_maptypes.2 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
755+
// @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id = weak constant i8 0
756+
// @.offload_sizes.3 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
757+
// @.offload_maptypes.4 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
758+
// @.offload_sizes.5 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
759+
// @.offload_maptypes.6 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
760+
// @_ZSt4cout = external global %"class.std::basic_ostream", align 8
761+
// @.str = private unnamed_addr constant [3 x i8] c"hi\00", align 1
762+
// @.offload_sizes.7 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
763+
// @.offload_maptypes.8 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
764+
// @.str.9 = private unnamed_addr constant [3 x i8] c"ho\00", align 1
765+
// @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
766+
// @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
767+
// @.offloading.entry_name.10 = internal unnamed_addr constant [67 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13\00", section ".llvm.rodata.offloading", align 1
768+
// @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id, ptr @.offloading.entry_name.10, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
769+
// @.offloading.entry_name.11 = internal unnamed_addr constant [69 x i8] c"__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19\00", section ".llvm.rodata.offloading", align 1
770+
// @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id, ptr @.offloading.entry_name.11, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
771+
// @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr @_GLOBAL__sub_I_zaxpy.cpp, ptr null }]
772+
} else {
773+
dbg!("no marker found");
774+
}
775+
} else {
776+
dbg!("Not creating struct");
777+
}
778+
670779
if cfg!(llvm_enzyme) && enable_ad && !thin {
671780
let cx =
672781
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);

compiler/rustc_codegen_llvm/src/common.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BackendTypes for GenericCx<'ll, CX> {
9999
type DIVariable = &'ll llvm::debuginfo::DIVariable;
100100
}
101101

102-
impl<'ll> CodegenCx<'ll, '_> {
102+
impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
103103
pub(crate) fn const_array(&self, ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value {
104104
let len = u64::try_from(elts.len()).expect("LLVMConstArray2 elements len overflow");
105105
unsafe { llvm::LLVMConstArray2(ty, elts.as_ptr(), len) }
106106
}
107107

108108
pub(crate) fn const_bytes(&self, bytes: &[u8]) -> &'ll Value {
109-
bytes_in_context(self.llcx, bytes)
109+
bytes_in_context(self.llcx(), bytes)
110110
}
111111

112112
pub(crate) fn const_get_elt(&self, v: &'ll Value, idx: u64) -> &'ll Value {

compiler/rustc_codegen_llvm/src/declare.rs

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,43 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
9999
)
100100
}
101101
}
102+
103+
/// Gets declared value by name.
104+
pub(crate) fn get_declared_value(&self, name: &str) -> Option<&'ll Value> {
105+
debug!("get_declared_value(name={:?})", name);
106+
unsafe { llvm::LLVMRustGetNamedValue(self.llmod(), name.as_c_char_ptr(), name.len()) }
107+
}
108+
109+
/// Gets defined or externally defined (AvailableExternally linkage) value by
110+
/// name.
111+
pub(crate) fn get_defined_value(&self, name: &str) -> Option<&'ll Value> {
112+
self.get_declared_value(name).and_then(|val| {
113+
let declaration = llvm::is_declaration(val);
114+
if !declaration { Some(val) } else { None }
115+
})
116+
}
117+
118+
/// Declare a global with an intention to define it.
119+
///
120+
/// Use this function when you intend to define a global. This function will
121+
/// return `None` if the name already has a definition associated with it. In that
122+
/// case an error should be reported to the user, because it usually happens due
123+
/// to user’s fault (e.g., misuse of `#[no_mangle]` or `#[export_name]` attributes).
124+
pub(crate) fn define_global(&self, name: &str, ty: &'ll Type) -> Option<&'ll Value> {
125+
if self.get_defined_value(name).is_some() {
126+
None
127+
} else {
128+
Some(self.declare_global(name, ty))
129+
}
130+
}
131+
132+
/// Declare a private global
133+
///
134+
/// Use this function when you intend to define a global without a name.
135+
pub(crate) fn define_private_global(&self, ty: &'ll Type) -> &'ll Value {
136+
unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod(), ty) }
137+
}
138+
102139
}
103140

104141
impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
@@ -215,40 +252,4 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
215252

216253
llfn
217254
}
218-
219-
/// Declare a global with an intention to define it.
220-
///
221-
/// Use this function when you intend to define a global. This function will
222-
/// return `None` if the name already has a definition associated with it. In that
223-
/// case an error should be reported to the user, because it usually happens due
224-
/// to user’s fault (e.g., misuse of `#[no_mangle]` or `#[export_name]` attributes).
225-
pub(crate) fn define_global(&self, name: &str, ty: &'ll Type) -> Option<&'ll Value> {
226-
if self.get_defined_value(name).is_some() {
227-
None
228-
} else {
229-
Some(self.declare_global(name, ty))
230-
}
231-
}
232-
233-
/// Declare a private global
234-
///
235-
/// Use this function when you intend to define a global without a name.
236-
pub(crate) fn define_private_global(&self, ty: &'ll Type) -> &'ll Value {
237-
unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod, ty) }
238-
}
239-
240-
/// Gets declared value by name.
241-
pub(crate) fn get_declared_value(&self, name: &str) -> Option<&'ll Value> {
242-
debug!("get_declared_value(name={:?})", name);
243-
unsafe { llvm::LLVMRustGetNamedValue(self.llmod, name.as_c_char_ptr(), name.len()) }
244-
}
245-
246-
/// Gets defined or externally defined (AvailableExternally linkage) value by
247-
/// name.
248-
pub(crate) fn get_defined_value(&self, name: &str) -> Option<&'ll Value> {
249-
self.get_declared_value(name).and_then(|val| {
250-
let declaration = llvm::is_declaration(val);
251-
if !declaration { Some(val) } else { None }
252-
})
253-
}
254255
}

compiler/rustc_codegen_ssa/src/back/write.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ pub struct ModuleConfig {
121121
pub emit_lifetime_markers: bool,
122122
pub llvm_plugins: Vec<String>,
123123
pub autodiff: Vec<config::AutoDiff>,
124+
pub offload: Vec<config::Offload>,
124125
}
125126

126127
impl ModuleConfig {
@@ -270,6 +271,7 @@ impl ModuleConfig {
270271
emit_lifetime_markers: sess.emit_lifetime_markers(),
271272
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
272273
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
274+
offload: if_regular!(sess.opts.unstable_opts.offload.clone(), vec![]),
273275
}
274276
}
275277

compiler/rustc_session/src/config.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,13 @@ pub enum CoverageLevel {
226226
Mcdc,
227227
}
228228

229+
/// The different settings that the `-Z offload` flag can have.
230+
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
231+
pub enum Offload {
232+
/// Enable the llvm offload pipeline
233+
Enable,
234+
}
235+
229236
/// The different settings that the `-Z autodiff` flag can have.
230237
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
231238
pub enum AutoDiff {
@@ -3061,7 +3068,7 @@ pub(crate) mod dep_tracking {
30613068
};
30623069

30633070
use super::{
3064-
AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
3071+
AutoDiff, Offload, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions,
30653072
CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn,
30663073
InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail,
30673074
LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName,
@@ -3110,6 +3117,7 @@ pub(crate) mod dep_tracking {
31103117

31113118
impl_dep_tracking_hash_via_hash!(
31123119
AutoDiff,
3120+
Offload,
31133121
bool,
31143122
usize,
31153123
NonZero<usize>,

compiler/rustc_session/src/options.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ mod desc {
712712
pub(crate) const parse_list_with_polarity: &str =
713713
"a comma-separated list of strings, with elements beginning with + or -";
714714
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
715+
pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`";
715716
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
716717
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
717718
pub(crate) const parse_number: &str = "a number";
@@ -1343,6 +1344,27 @@ pub mod parse {
13431344
}
13441345
}
13451346

1347+
pub(crate) fn parse_offload(slot: &mut Vec<Offload>, v: Option<&str>) -> bool {
1348+
let Some(v) = v else {
1349+
*slot = vec![];
1350+
return true;
1351+
};
1352+
let mut v: Vec<&str> = v.split(",").collect();
1353+
v.sort_unstable();
1354+
for &val in v.iter() {
1355+
let variant = match val {
1356+
"Enable" => Offload::Enable,
1357+
_ => {
1358+
// FIXME(ZuseZ4): print an error saying which value is not recognized
1359+
return false;
1360+
}
1361+
};
1362+
slot.push(variant);
1363+
}
1364+
1365+
true
1366+
}
1367+
13461368
pub(crate) fn parse_autodiff(slot: &mut Vec<AutoDiff>, v: Option<&str>) -> bool {
13471369
let Some(v) = v else {
13481370
*slot = vec![];
@@ -2372,6 +2394,11 @@ options! {
23722394
"do not use unique names for text and data sections when -Z function-sections is used"),
23732395
normalize_docs: bool = (false, parse_bool, [TRACKED],
23742396
"normalize associated items in rustdoc when generating documentation"),
2397+
offload: Vec<crate::config::Offload> = (Vec::new(), parse_offload, [TRACKED],
2398+
"a list of offload flags to enable
2399+
Mandatory setting:
2400+
`=Enable`
2401+
Currently the only option available"),
23752402
on_broken_pipe: OnBrokenPipe = (OnBrokenPipe::Default, parse_on_broken_pipe, [TRACKED],
23762403
"behavior of std::io::ErrorKind::BrokenPipe (SIGPIPE)"),
23772404
oom: OomStrategy = (OomStrategy::Abort, parse_oom_strategy, [TRACKED],

0 commit comments

Comments
 (0)