@@ -653,6 +653,7 @@ pub(crate) fn run_pass_manager(
653
653
// We then run the llvm_optimize function a second time, to optimize the code which we generated
654
654
// in the enzyme differentiation pass.
655
655
let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
656
+ let enable_gpu = true ; //config.offload.contains(&config::Offload::Enable);
656
657
let stage = if thin {
657
658
write:: AutodiffStage :: PreAD
658
659
} else {
@@ -667,6 +668,114 @@ pub(crate) fn run_pass_manager(
667
668
write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
668
669
}
669
670
671
+ if cfg ! ( llvm_enzyme) && enable_gpu && !thin {
672
+ // first we need to add all the fun to the host module
673
+ // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
674
+ // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
675
+ let cx =
676
+ SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
677
+ if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
678
+ let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
679
+ let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
680
+ let tptr = cx. type_ptr ( ) ;
681
+ let ti64 = cx. type_i64 ( ) ;
682
+ let ti32 = cx. type_i32 ( ) ;
683
+ let ti16 = cx. type_i16 ( ) ;
684
+ let tarr = cx. type_array ( ti32, 3 ) ;
685
+
686
+ let entry_elements = vec ! [ ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr] ;
687
+ let kernel_elements = vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
688
+
689
+ cx. set_struct_body ( offload_entry_ty, & entry_elements, false ) ;
690
+ cx. set_struct_body ( kernel_arguments_ty, & kernel_elements, false ) ;
691
+ let global = cx. declare_global ( "my_struct_global" , offload_entry_ty) ;
692
+ let global = cx. declare_global ( "my_struct_global2" , kernel_arguments_ty) ;
693
+ dbg ! ( & offload_entry_ty) ;
694
+ dbg ! ( & kernel_arguments_ty) ;
695
+ //LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy};
696
+ //LLVMStructSetBody(structTy, elements, 9, 0);
697
+ dbg ! ( "created struct" ) ;
698
+ for num in 0 ..5 {
699
+ if !cx. get_function ( & format ! ( "kernel_{num}" ) ) . is_some ( ) {
700
+ continue ;
701
+ }
702
+ //for function in cx.get_functions() {
703
+ //if !attributes::has_attr(function, Function, llvm::AttributeKind::OptimizeForSize) {
704
+ // dbg!("skipping minsize fnc");
705
+ // dbg!(&function);
706
+ // // print fnc name
707
+ // let enzyme_marker = "minsize";
708
+ // if attributes::has_string_attr(function, enzyme_marker) {
709
+ // dbg!("found minsize str");
710
+ // }
711
+ // continue;
712
+
713
+ let size_name = format ! ( ".offload_sizes.{num}" ) ;
714
+ let size_ty = cx. type_array ( ti64, 4 ) ;
715
+ //let size_val = vec![8i64,0,16,0];
716
+ let c_val_8 = cx. get_const_i64 ( 8 ) ;
717
+ let c_val_0 = cx. get_const_i64 ( 0 ) ;
718
+ let c_val_16 = cx. get_const_i64 ( 16 ) ;
719
+ let size_val = vec ! [ c_val_8, c_val_0, c_val_16, c_val_0] ;
720
+
721
+ //let val = cx.define_global(&size_name, size_ty).unwrap();
722
+ //dbg!(&val);
723
+ //let section_var = cx
724
+ // .define_global(section_var_name, llvm_type)
725
+ // .unwrap_or_else(|| bug!("symbol `{}` is already defined", section_var_name));
726
+ //llvm::set_section(section_var, c".debug_gdb_scripts");
727
+ //llvm::set_initializer(section_var, cx.const_bytes(section_contents));
728
+ //llvm::LLVMSetGlobalConstant(section_var, llvm::True);
729
+ //llvm::set_linkage(section_var, llvm::Linkage::LinkOnceODRLinkage);
730
+ //// This should make sure that the whole section is not larger than
731
+ //// the string it contains. Otherwise we get a warning from GDB.
732
+ //llvm::LLVMSetAlignment(section_var, 1);
733
+ //llvm::set_initializer(val, cx.const_bytes(size_val.as_slice()));
734
+ let initializer = cx. const_array ( ti64, & size_val) ;
735
+ let name = format ! ( ".offload_sizes.{num}" ) ;
736
+ let c_name = CString :: new ( name) . unwrap ( ) ;
737
+ let array = llvm:: add_global ( cx. llmod , cx. val_ty ( initializer) , & c_name ) ;
738
+ llvm:: set_global_constant ( array, true ) ;
739
+ unsafe { llvm:: LLVMSetUnnamedAddress ( array, llvm:: UnnamedAddr :: Global ) } ;
740
+ llvm:: set_linkage ( array, llvm:: Linkage :: PrivateLinkage ) ;
741
+ llvm:: set_initializer ( array, initializer) ;
742
+ dbg ! ( & array) ;
743
+ // 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
744
+ // 2. @.offload_maptypes
745
+ // 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
746
+ // 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
747
+ // 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
748
+ }
749
+ // @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
750
+ // @.offload_sizes = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
751
+ // @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
752
+ // @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id = weak constant i8 0
753
+ // @.offload_sizes.1 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
754
+ // @.offload_maptypes.2 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
755
+ // @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id = weak constant i8 0
756
+ // @.offload_sizes.3 = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
757
+ // @.offload_maptypes.4 = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
758
+ // @.offload_sizes.5 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
759
+ // @.offload_maptypes.6 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
760
+ // @_ZSt4cout = external global %"class.std::basic_ostream", align 8
761
+ // @.str = private unnamed_addr constant [3 x i8] c"hi\00", align 1
762
+ // @.offload_sizes.7 = private unnamed_addr constant [2 x i64] [i64 16384, i64 16384]
763
+ // @.offload_maptypes.8 = private unnamed_addr constant [2 x i64] [i64 1, i64 3]
764
+ // @.str.9 = private unnamed_addr constant [3 x i8] c"ho\00", align 1
765
+ // @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
766
+ // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
767
+ // @.offloading.entry_name.10 = internal unnamed_addr constant [67 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13\00", section ".llvm.rodata.offloading", align 1
768
+ // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3barPSt7complexIdES1_S0_m_l13.region_id, ptr @.offloading.entry_name.10, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
769
+ // @.offloading.entry_name.11 = internal unnamed_addr constant [69 x i8] c"__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19\00", section ".llvm.rodata.offloading", align 1
770
+ // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z5zaxpyPSt7complexIdES1_S0_m_l19.region_id, ptr @.offloading.entry_name.11, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
771
+ // @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr @_GLOBAL__sub_I_zaxpy.cpp, ptr null }]
772
+ } else {
773
+ dbg ! ( "no marker found" ) ;
774
+ }
775
+ } else {
776
+ dbg ! ( "Not creating struct" ) ;
777
+ }
778
+
670
779
if cfg ! ( llvm_enzyme) && enable_ad && !thin {
671
780
let cx =
672
781
SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
0 commit comments