@@ -630,53 +630,7 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
630
630
llvm:: set_rust_rules ( true ) ;
631
631
}
632
632
633
- pub ( crate ) fn run_pass_manager (
634
- cgcx : & CodegenContext < LlvmCodegenBackend > ,
635
- dcx : DiagCtxtHandle < ' _ > ,
636
- module : & mut ModuleCodegen < ModuleLlvm > ,
637
- thin : bool ,
638
- ) -> Result < ( ) , FatalError > {
639
- let _timer = cgcx. prof . generic_activity_with_arg ( "LLVM_lto_optimize" , & * module. name ) ;
640
- let config = cgcx. config ( module. kind ) ;
641
-
642
- // Now we have one massive module inside of llmod. Time to run the
643
- // LTO-specific optimization passes that LLVM provides.
644
- //
645
- // This code is based off the code found in llvm's LTO code generator:
646
- // llvm/lib/LTO/LTOCodeGenerator.cpp
647
- debug ! ( "running the pass manager" ) ;
648
- let opt_stage = if thin { llvm:: OptStage :: ThinLTO } else { llvm:: OptStage :: FatLTO } ;
649
- let opt_level = config. opt_level . unwrap_or ( config:: OptLevel :: No ) ;
650
-
651
- // The PostAD behavior is the same that we would have if no autodiff was used.
652
- // It will run the default optimization pipeline. If AD is enabled we select
653
- // the DuringAD stage, which will disable vectorization and loop unrolling, and
654
- // schedule two autodiff optimization + differentiation passes.
655
- // We then run the llvm_optimize function a second time, to optimize the code which we generated
656
- // in the enzyme differentiation pass.
657
- let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
658
- let enable_gpu = true ; //config.offload.contains(&config::Offload::Enable);
659
- let stage = if thin {
660
- write:: AutodiffStage :: PreAD
661
- } else {
662
- if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD }
663
- } ;
664
-
665
- if enable_ad {
666
- enable_autodiff_settings ( & config. autodiff ) ;
667
- }
668
-
669
- unsafe {
670
- write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
671
- }
672
-
673
- if cfg ! ( llvm_enzyme) && enable_gpu && !thin {
674
- // first we need to add all the fun to the host module
675
- // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
676
- // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
677
- let cx =
678
- SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
679
- if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
633
+ fn gen_globals < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> & ' ll llvm:: Type {
680
634
let offload_entry_ty = cx. type_named_struct ( "struct.__tgt_offload_entry" ) ;
681
635
let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
682
636
let tptr = cx. type_ptr ( ) ;
@@ -686,6 +640,23 @@ pub(crate) fn run_pass_manager(
686
640
let ti8 = cx. type_i8 ( ) ;
687
641
let tarr = cx. type_array ( ti32, 3 ) ;
688
642
643
+ // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
644
+ let unknown_txt = ";unknown;unknown;0;0;;" ;
645
+ let c_entry_name = CString :: new ( unknown_txt) . unwrap ( ) ;
646
+ let c_val = c_entry_name. as_bytes_with_nul ( ) ;
647
+ let initializer = crate :: common:: bytes_in_context ( cx. llcx , c_val) ;
648
+ let at_zero = add_unnamed_global ( & cx, & "" , initializer, PrivateLinkage ) ;
649
+ llvm:: set_alignment ( at_zero, rustc_abi:: Align :: ONE ) ;
650
+
651
+ // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
652
+ let struct_ident_ty = cx. type_named_struct ( "struct.ident_t" ) ;
653
+ let struct_elems: Vec < & llvm:: Value > = vec ! [ cx. get_const_i32( 0 ) , cx. get_const_i32( 2 ) , cx. get_const_i32( 0 ) , cx. get_const_i32( 22 ) , at_zero] ;
654
+ let struct_elems_ty: Vec < _ > = struct_elems. iter ( ) . map ( |& x| cx. val_ty ( x) ) . collect ( ) ;
655
+ let initializer = crate :: common:: named_struct ( struct_ident_ty, & struct_elems) ;
656
+ cx. set_struct_body ( struct_ident_ty, & struct_elems_ty, false ) ;
657
+ let at_one = add_unnamed_global ( & cx, & "" , initializer, PrivateLinkage ) ;
658
+ llvm:: set_alignment ( at_one, rustc_abi:: Align :: EIGHT ) ;
659
+
689
660
// coppied from LLVM
690
661
// typedef struct {
691
662
// uint64_t Reserved;
@@ -735,38 +706,38 @@ pub(crate) fn run_pass_manager(
735
706
attributes:: apply_to_llfn ( bar, Function , & [ nounwind] ) ;
736
707
attributes:: apply_to_llfn ( baz, Function , & [ nounwind] ) ;
737
708
738
- dbg ! ( "created struct" ) ;
739
- for num in 0 ..9 {
740
- if !cx. get_function ( & format ! ( "kernel_{num}" ) ) . is_some ( ) {
741
- continue ;
742
- }
709
+ offload_entry_ty
710
+ }
743
711
744
- fn add_priv_unnamed_arr < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , vals : & [ u64 ] ) -> & ' ll llvm:: Value {
745
- let ti64 = cx. type_i64 ( ) ;
746
- let size_ty = cx. type_array ( ti64, vals. len ( ) as u64 ) ;
747
- let mut size_val = Vec :: with_capacity ( vals. len ( ) ) ;
748
- for & val in vals {
749
- size_val. push ( cx. get_const_i64 ( val) ) ;
750
- }
751
- let initializer = cx. const_array ( ti64, & size_val) ;
752
- add_unnamed_global ( cx, name, initializer, PrivateLinkage )
753
- }
712
+ fn add_priv_unnamed_arr < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , vals : & [ u64 ] ) -> & ' ll llvm:: Value {
713
+ let ti64 = cx. type_i64 ( ) ;
714
+ let size_ty = cx. type_array ( ti64, vals. len ( ) as u64 ) ;
715
+ let mut size_val = Vec :: with_capacity ( vals. len ( ) ) ;
716
+ for & val in vals {
717
+ size_val. push ( cx. get_const_i64 ( val) ) ;
718
+ }
719
+ let initializer = cx. const_array ( ti64, & size_val) ;
720
+ add_unnamed_global ( cx, name, initializer, PrivateLinkage )
721
+ }
754
722
755
- fn add_global < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , initializer : & ' ll llvm:: Value , l : Linkage ) -> & ' ll llvm:: Value {
756
- let c_name = CString :: new ( name) . unwrap ( ) ;
757
- let llglobal: & ' ll llvm:: Value = llvm:: add_global ( cx. llmod , cx. val_ty ( initializer) , & c_name) ;
758
- llvm:: set_global_constant ( llglobal, true ) ;
759
- llvm:: set_linkage ( llglobal, l) ;
760
- llvm:: set_initializer ( llglobal, initializer) ;
761
- llglobal
762
- }
723
+ fn add_unnamed_global < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , initializer : & ' ll llvm:: Value , l : Linkage ) -> & ' ll llvm:: Value {
724
+ let llglobal = add_global ( cx, name, initializer, l) ;
725
+ unsafe { llvm:: LLVMSetUnnamedAddress ( llglobal, llvm:: UnnamedAddr :: Global ) } ;
726
+ llglobal
727
+ }
763
728
764
- fn add_unnamed_global < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , initializer : & ' ll llvm:: Value , l : Linkage ) -> & ' ll llvm:: Value {
765
- let llglobal = add_global ( cx, name, initializer, l) ;
766
- unsafe { llvm:: LLVMSetUnnamedAddress ( llglobal, llvm:: UnnamedAddr :: Global ) } ;
767
- llglobal
768
- }
729
+ fn add_global < ' ll > ( cx : & SimpleCx < ' ll > , name : & str , initializer : & ' ll llvm:: Value , l : Linkage ) -> & ' ll llvm:: Value {
730
+ let c_name = CString :: new ( name) . unwrap ( ) ;
731
+ let llglobal: & ' ll llvm:: Value = llvm:: add_global ( cx. llmod , cx. val_ty ( initializer) , & c_name) ;
732
+ llvm:: set_global_constant ( llglobal, true ) ;
733
+ llvm:: set_linkage ( llglobal, l) ;
734
+ llvm:: set_initializer ( llglobal, initializer) ;
735
+ llglobal
736
+ }
769
737
738
+
739
+
740
+ fn gen_define_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , offload_entry_ty : & ' ll llvm:: Type , num : i64 ) {
770
741
// We add a pair of sizes and maptypes per offloadable function.
771
742
// @.offload_maptypes = private unnamed_addr constant [4 x i64] [i64 800, i64 544, i64 547, i64 544]
772
743
let o_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{num}" ) , & vec ! [ 8u64 , 0 , 16 , 0 ] ) ;
@@ -822,7 +793,77 @@ pub(crate) fn run_pass_manager(
822
793
// 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
823
794
// 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
824
795
// 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
796
+ }
797
+
798
+ fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) {
799
+ // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
800
+ // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
801
+ // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
802
+ // What is @1? Random but fixed:
803
+ // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
804
+ // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
805
+
806
+ }
807
+
808
+ pub ( crate ) fn run_pass_manager (
809
+ cgcx : & CodegenContext < LlvmCodegenBackend > ,
810
+ dcx : DiagCtxtHandle < ' _ > ,
811
+ module : & mut ModuleCodegen < ModuleLlvm > ,
812
+ thin : bool ,
813
+ ) -> Result < ( ) , FatalError > {
814
+ let _timer = cgcx. prof . generic_activity_with_arg ( "LLVM_lto_optimize" , & * module. name ) ;
815
+ let config = cgcx. config ( module. kind ) ;
816
+
817
+ // Now we have one massive module inside of llmod. Time to run the
818
+ // LTO-specific optimization passes that LLVM provides.
819
+ //
820
+ // This code is based off the code found in llvm's LTO code generator:
821
+ // llvm/lib/LTO/LTOCodeGenerator.cpp
822
+ debug ! ( "running the pass manager" ) ;
823
+ let opt_stage = if thin { llvm:: OptStage :: ThinLTO } else { llvm:: OptStage :: FatLTO } ;
824
+ let opt_level = config. opt_level . unwrap_or ( config:: OptLevel :: No ) ;
825
+
826
+ // The PostAD behavior is the same that we would have if no autodiff was used.
827
+ // It will run the default optimization pipeline. If AD is enabled we select
828
+ // the DuringAD stage, which will disable vectorization and loop unrolling, and
829
+ // schedule two autodiff optimization + differentiation passes.
830
+ // We then run the llvm_optimize function a second time, to optimize the code which we generated
831
+ // in the enzyme differentiation pass.
832
+ let enable_ad = config. autodiff . contains ( & config:: AutoDiff :: Enable ) ;
833
+ let enable_gpu = true ; //config.offload.contains(&config::Offload::Enable);
834
+ let stage = if thin {
835
+ write:: AutodiffStage :: PreAD
836
+ } else {
837
+ if enable_ad { write:: AutodiffStage :: DuringAD } else { write:: AutodiffStage :: PostAD }
838
+ } ;
839
+
840
+ if enable_ad {
841
+ enable_autodiff_settings ( & config. autodiff ) ;
842
+ }
843
+
844
+ unsafe {
845
+ write:: llvm_optimize ( cgcx, dcx, module, None , config, opt_level, opt_stage, stage) ?;
846
+ }
847
+
848
+ if cfg ! ( llvm_enzyme) && enable_gpu && !thin {
849
+ // first we need to add all the fun to the host module
850
+ // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
851
+ // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
852
+ let cx =
853
+ SimpleCx :: new ( module. module_llvm . llmod ( ) , & module. module_llvm . llcx , cgcx. pointer_size ) ;
854
+ if cx. get_function ( "gen_tgt_offload" ) . is_some ( ) {
855
+
856
+ let offload_entry_ty = gen_globals ( & cx) ;
857
+
858
+ dbg ! ( "created struct" ) ;
859
+ for num in 0 ..9 {
860
+ if !cx. get_function ( & format ! ( "kernel_{num}" ) ) . is_some ( ) {
861
+ continue ;
862
+ }
863
+ // TODO: replace num by proper fn name
864
+ gen_define_handling ( & cx, offload_entry_ty, num) ;
825
865
}
866
+ gen_call_handling ( & cx) ;
826
867
} else {
827
868
dbg ! ( "no marker found" ) ;
828
869
}
0 commit comments