@@ -633,292 +633,6 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
     llvm::set_rust_rules(true);
 }
 
-fn gen_globals<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
-    let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
-    let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
-    let tptr = cx.type_ptr();
-    let ti64 = cx.type_i64();
-    let ti32 = cx.type_i32();
-    let ti16 = cx.type_i16();
-    let ti8 = cx.type_i8();
-    let tarr = cx.type_array(ti32, 3);
-
-    // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
-    let unknown_txt = ";unknown;unknown;0;0;;";
-    let c_entry_name = CString::new(unknown_txt).unwrap();
-    let c_val = c_entry_name.as_bytes_with_nul();
-    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
-    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
-    llvm::set_alignment(at_zero, Align::ONE);
-
-    // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
-    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
-    let struct_elems: Vec<&llvm::Value> = vec![cx.get_const_i32(0), cx.get_const_i32(2), cx.get_const_i32(0), cx.get_const_i32(22), at_zero];
-    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
-    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
-    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
-    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
-    llvm::set_alignment(at_one, Align::EIGHT);
-
-    // Copied from LLVM:
-    // typedef struct {
-    //     uint64_t Reserved;
-    //     uint16_t Version;
-    //     uint16_t Kind;
-    //     uint32_t Flags;
-    //     void *Address;
-    //     char *SymbolName;
-    //     uint64_t Size;
-    //     uint64_t Data;
-    //     void *AuxAddr;
-    // } __tgt_offload_entry;
-    let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
-    let kernel_elements = vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
-
-    cx.set_struct_body(offload_entry_ty, &entry_elements, false);
-    cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
-    let global = cx.declare_global("my_struct_global", offload_entry_ty);
-    let global = cx.declare_global("my_struct_global2", kernel_arguments_ty);
-    // @my_struct_global = external global %struct.__tgt_offload_entry
-    // @my_struct_global2 = external global %struct.__tgt_kernel_arguments
-    dbg!(&offload_entry_ty);
-    dbg!(&kernel_arguments_ty);
-    // LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy};
-    // LLVMStructSetBody(structTy, elements, 9, 0);
-
-    // New, to test memtransfer:
-    // ; Function Attrs: nounwind
-    // declare void @__tgt_target_data_begin_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3
-    //
-    // ; Function Attrs: nounwind
-    // declare void @__tgt_target_data_update_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3
-    //
-    // ; Function Attrs: nounwind
-    // declare void @__tgt_target_data_end_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3
-    let mapper_begin = "__tgt_target_data_begin_mapper";
-    let mapper_update = String::from("__tgt_target_data_update_mapper");
-    let mapper_end = String::from("__tgt_target_data_end_mapper");
-    let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
-    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
-    let foo = crate::declare::declare_simple_fn(&cx, &mapper_begin, llvm::CallConv::CCallConv, llvm::UnnamedAddr::No, llvm::Visibility::Default, mapper_fn_ty);
-    let bar = crate::declare::declare_simple_fn(&cx, &mapper_update, llvm::CallConv::CCallConv, llvm::UnnamedAddr::No, llvm::Visibility::Default, mapper_fn_ty);
-    let baz = crate::declare::declare_simple_fn(&cx, &mapper_end, llvm::CallConv::CCallConv, llvm::UnnamedAddr::No, llvm::Visibility::Default, mapper_fn_ty);
-    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
-    attributes::apply_to_llfn(foo, Function, &[nounwind]);
-    attributes::apply_to_llfn(bar, Function, &[nounwind]);
-    attributes::apply_to_llfn(baz, Function, &[nounwind]);
-
-    (offload_entry_ty, at_one, foo, bar, baz, mapper_fn_ty)
-}
-
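For orientation, the nine `entry_elements` above mirror LLVM's `__tgt_offload_entry` record one to one. A minimal `#[repr(C)]` sketch of that layout, with field names taken from the C typedef quoted in the removed code (illustrative only, not a type this PR defines):

```rust
// Rust view of %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }.
// Field names follow the LLVM typedef copied above; this type is illustrative.
use core::ffi::c_void;

#[repr(C)]
pub struct TgtOffloadEntry {
    pub reserved: u64,         // i64: reserved, currently 0
    pub version: u16,          // i16: entry format version (1 below)
    pub kind: u16,             // i16: entry kind (1 below)
    pub flags: u32,            // i32: entry flags
    pub address: *mut c_void,  // ptr: kernel/global address (the region_id global)
    pub symbol_name: *mut i8,  // ptr: symbol name used for runtime lookup
    pub size: u64,             // i64: size in bytes, 0 for functions
    pub data: u64,             // i64: auxiliary data
    pub aux_addr: *mut c_void, // ptr: auxiliary address, null here
}
```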
-fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
-    let ti64 = cx.type_i64();
-    let size_ty = cx.type_array(ti64, vals.len() as u64);
-    let mut size_val = Vec::with_capacity(vals.len());
-    for &val in vals {
-        size_val.push(cx.get_const_i64(val));
-    }
-    let initializer = cx.const_array(ti64, &size_val);
-    add_unnamed_global(cx, name, initializer, PrivateLinkage)
-}
-
-fn add_unnamed_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value, l: Linkage) -> &'ll llvm::Value {
-    let llglobal = add_global(cx, name, initializer, l);
-    unsafe { llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global) };
-    llglobal
-}
-
-fn add_global<'ll>(cx: &SimpleCx<'ll>, name: &str, initializer: &'ll llvm::Value, l: Linkage) -> &'ll llvm::Value {
-    let c_name = CString::new(name).unwrap();
-    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
-    llvm::set_global_constant(llglobal, true);
-    llvm::set_linkage(llglobal, l);
-    llvm::set_initializer(llglobal, initializer);
-    llglobal
-}
-
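Taken together, these helpers turn a constant initializer into a private, unnamed-addr global such as `@.offload_sizes.1 = private unnamed_addr constant [2 x i64] [i64 1024, i64 1024]`. For readers without the rustc tree at hand, here is a standalone sketch of the same pattern against the raw LLVM-C API via the `llvm-sys` crate (the PR itself goes through rustc's internal bindings, so names like `SimpleCx` do not appear here):

```rust
// Standalone sketch, assuming llvm-sys; mirrors add_priv_unnamed_arr + add_unnamed_global.
use std::ffi::CString;
use llvm_sys::core::*;
use llvm_sys::prelude::*;
use llvm_sys::{LLVMLinkage, LLVMUnnamedAddr};

unsafe fn add_priv_unnamed_arr(llmod: LLVMModuleRef, name: &str, vals: &[u64]) -> LLVMValueRef {
    unsafe {
        let ctx = LLVMGetModuleContext(llmod);
        let ti64 = LLVMInt64TypeInContext(ctx);
        let mut elems: Vec<LLVMValueRef> =
            vals.iter().map(|&v| LLVMConstInt(ti64, v, 0)).collect();
        // e.g. [2 x i64] [i64 1024, i64 1024]
        let init = LLVMConstArray(ti64, elems.as_mut_ptr(), elems.len() as u32);
        let c_name = CString::new(name).unwrap();
        let g = LLVMAddGlobal(llmod, LLVMTypeOf(init), c_name.as_ptr());
        LLVMSetInitializer(g, init);
        LLVMSetGlobalConstant(g, 1);
        LLVMSetLinkage(g, LLVMLinkage::LLVMPrivateLinkage);
        LLVMSetUnnamedAddress(g, LLVMUnnamedAddr::LLVMGlobalUnnamedAddr);
        g
    }
}
```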
-fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, num: i64) -> &'ll llvm::Value {
-    let types = cx.func_params_types(cx.get_type_of_global(kernel));
-    // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
-    // reference) types.
-    let num_ptr_types = types
-        .iter()
-        .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
-        .count();
-
-    // We do not know their size anymore at this level, so hardcode a placeholder.
-    // A follow-up PR will track these from the frontend, where we still have Rust types.
-    // Then, we will be able to figure out that e.g. `&[f32; 1024]` will result in 4*1024 bytes.
-    // For now, 1024 bytes is the placeholder value.
-    let o_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
-    // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
-    // or both to and from the gpu (=3). Other values shouldn't affect us for now.
-    // A non-mutable reference or pointer will be 1, an array that is not read but fully
-    // overwritten will be 2. For now, everything is 3, until we have our frontend set up.
-    let o_types = add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
-    // Next: for each function, generate these three entries: a weak constant,
-    // the llvm.rodata entry name, and the omp_offloading_entries value.
-
-    // @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0
-    // @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
-    let name = format!(".kernel_{num}.region_id");
-    let initializer = cx.get_const_i8(0);
-    let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
-
-    let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
-    let c_val = c_entry_name.as_bytes_with_nul();
-    let foo = format!(".offloading.entry_name.{num}");
-
-    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
-    let llglobal = add_unnamed_global(&cx, &foo, initializer, InternalLinkage);
-    llvm::set_alignment(llglobal, Align::ONE);
-    let c_section_name = CString::new(".llvm.rodata.offloading").unwrap();
-    llvm::set_section(llglobal, &c_section_name);
-
-    // Not actively used yet, for calling real kernels:
-    let name = format!(".offloading.entry.kernel_{num}");
-    let ci64_0 = cx.get_const_i64(0);
-    let ci16_1 = cx.get_const_i16(1);
-    let elems: Vec<&llvm::Value> = vec![ci64_0, ci16_1, ci16_1, cx.get_const_i32(0), region_id, llglobal, ci64_0, ci64_0, cx.const_null(cx.type_ptr())];
-
-    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
-    let c_name = CString::new(name).unwrap();
-    let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
-    llvm::set_global_constant(llglobal, true);
-    llvm::set_linkage(llglobal, WeakAnyLinkage);
-    llvm::set_initializer(llglobal, initializer);
-    llvm::set_alignment(llglobal, Align::ONE);
-    let c_section_name = CString::new(".omp_offloading_entries").unwrap();
-    llvm::set_section(llglobal, &c_section_name);
-    // rustc:
-    // @.offloading.entry.kernel_3 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_3.region_id, ptr @.offloading.entry_name.3, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1
-    // clang:
-    // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
-
-    // Summary of the globals generated per kernel:
-    // 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0]
-    // 2. @.offload_maptypes
-    // 3. @.__omp_offloading_<hash>_fnc_name_<hash> = weak constant i8 0
-    // 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1
-    // 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1
-    o_types
-}
-
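The magic numbers 1/2/3 in `.offload_maptypes` are the low OpenMP map-type bits. A hedged sketch of the decision the comments describe, with constant names modeled on libomptarget's `OMP_TGT_MAPTYPE_*` values (only this subset is used here):

```rust
// Map-type bits as consumed by the libomptarget runtime (illustrative subset).
const OMP_TGT_MAPTYPE_TO: u64 = 0x01;   // copy host -> device before the region
const OMP_TGT_MAPTYPE_FROM: u64 = 0x02; // copy device -> host after the region

// Mirrors the comment above: shared reads are TO (1), write-only results are
// FROM (2), and the current conservative default is TO | FROM (3).
fn map_type_for(immutable_borrow: bool, fully_overwritten: bool) -> u64 {
    match (immutable_borrow, fully_overwritten) {
        (true, _) => OMP_TGT_MAPTYPE_TO,
        (_, true) => OMP_TGT_MAPTYPE_FROM,
        _ => OMP_TGT_MAPTYPE_TO | OMP_TGT_MAPTYPE_FROM,
    }
}
```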
-// For each kernel *call*, we now use some of our previously declared globals to move data to and
-// from the gpu. We don't have a proper frontend yet, so we assume that every call to a kernel
-// function from main is intended to run on the GPU. For now, we only handle the data transfer
-// part. If two consecutive kernels use the same memory, we still move it to the host and back to
-// the gpu. Since in our frontend users (by default) don't have to specify data transfers, this is
-// something we should optimize in the future! We also assume that everything should be copied
-// back and forth, but sometimes we can directly zero-allocate on the device and only move back,
-// or if something is immutable, we might only copy it to the device, but not back.
-//
-// Current steps:
-// 0. Alloca some variables for the following steps.
-// 1. Set the insert point before the kernel call.
-// 2. Generate all the GEPs and stores, to be used in 3).
-// 3. Generate __tgt_target_data_begin calls to move data to the GPU.
-//
-// unchanged: keep the kernel call. Later, move the kernel to the GPU.
-//
-// 4. Set the insert point after the kernel call.
-// 5. Generate all the GEPs and stores, to be used in 6).
-// 6. Generate __tgt_target_data_end calls to move data from the GPU.
-fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s_ident_t: &'ll llvm::Value, begin: &'ll llvm::Value, update: &'ll llvm::Value, end: &'ll llvm::Value, fn_ty: &'ll llvm::Type, o_types: &[&'ll llvm::Value]) {
-    let main_fn = cx.get_function("main");
-    if let Some(main_fn) = main_fn {
-        let kernel_name = "kernel_1";
-        let call = unsafe { llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len()) };
-        let kernel_call = if call.is_some() {
-            dbg!("found kernel call");
-            call.unwrap()
-        } else {
-            return;
-        };
-        let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
-        let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
-        let mut builder = SBuilder::build(cx, kernel_call_bb);
-
-        let types = cx.func_params_types(cx.get_type_of_global(called));
-        dbg!(&types);
-        let num_args = types.len() as u64;
-        let mut names: Vec<&llvm::Value> = Vec::with_capacity(num_args as usize);
-
-        // Step 0)
-        unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
-        let ty = cx.type_array(cx.type_ptr(), num_args);
-        // Baseptrs are just the input pointers to the kernel, stored in a local alloca.
-        let a1 = builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs");
-        // Ptrs are the result of a GEP into the baseptr, at least for our trivial types.
-        let a2 = builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs");
-        // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
-        let ty2 = cx.type_array(cx.type_i64(), num_args);
-        let a4 = builder.my_alloca2(ty2, Align::EIGHT, ".offload_sizes");
-        // Now we allocate once per function param, a copy to be passed to one of our maps.
-        let mut vals = vec![];
-        let mut geps = vec![];
-        let i32_0 = cx.get_const_i32(0);
-        for (index, in_ty) in types.iter().enumerate() {
-            // Get the function arg, store it into the alloca, and read it back.
-            let p = llvm::get_param(called, index as u32);
-            let name = llvm::get_value_name(p);
-            let name = str::from_utf8(name).unwrap();
-            let arg_name = CString::new(format!("{name}.addr")).unwrap();
-            let alloca = unsafe { llvm::LLVMBuildAlloca(builder.llbuilder, in_ty, arg_name.as_ptr()) };
-            builder.store(p, alloca, Align::EIGHT);
-            let val = builder.load(in_ty, alloca, Align::EIGHT);
-            let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
-            vals.push(val);
-            geps.push(gep);
-        }
-
-        // Step 1)
-        unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
-        for i in 0..num_args {
-            let idx = cx.get_const_i32(i);
-            let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
-            builder.store(vals[i as usize], gep1, Align::EIGHT);
-            let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
-            builder.store(geps[i as usize], gep2, Align::EIGHT);
-            let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
-            builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
-        }
-
-        // Step 2)
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
-        builder.call(fn_ty, begin, &args, None);
-
-        // Step 4)
-        unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
-
-        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
-        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
-        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
-
-        let nullptr = cx.const_null(cx.type_ptr());
-        let o_type = o_types[0];
-        let args = vec![s_ident_t, cx.get_const_i64(u64::MAX), cx.get_const_i32(num_args), gep1, gep2, gep3, o_type, nullptr, nullptr];
-        builder.call(fn_ty, end, &args, None);
-
-        // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
-        // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
-        // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
-        // What is @1? Random but fixed:
-        // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
-        // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
-    }
-}
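The nine-element `args` vectors above follow libomptarget's mapper signature, with `cx.get_const_i64(u64::MAX)` standing in for `i64 -1`, the default device. A hedged Rust rendering of the C prototype being called (parameter names modeled on omptarget.h; illustrative, not bindings this PR adds):

```rust
// Illustrative extern declaration matching the `declare void
// @__tgt_target_data_begin_mapper(...)` line quoted in gen_globals above;
// _update_mapper and _end_mapper share the same shape.
use core::ffi::c_void;

extern "C" {
    fn __tgt_target_data_begin_mapper(
        loc: *mut c_void,              // ptr @1, the ident_t source-location struct
        device_id: i64,                // -1 selects the default device
        arg_num: i32,                  // number of mapped arguments
        args_base: *mut *mut c_void,   // .offload_baseptrs
        args: *mut *mut c_void,        // .offload_ptrs
        arg_sizes: *mut i64,           // .offload_sizes, in bytes
        arg_types: *mut i64,           // .offload_maptypes (1=to, 2=from, 3=tofrom)
        arg_names: *mut *mut c_void,   // optional debug names, may be null
        arg_mappers: *mut *mut c_void, // custom mappers, may be null
    );
}
```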
 
 pub(crate) fn run_pass_manager(
     cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -945,7 +659,8 @@ pub(crate) fn run_pass_manager(
     // We then run the llvm_optimize function a second time, to optimize the code which we generated
     // in the enzyme differentiation pass.
     let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
-    let enable_gpu = true; //config.offload.contains(&config::Offload::Enable);
+    let enable_gpu = config.offload.contains(&config::Offload::Enable);
+    dbg!(&enable_gpu);
     let stage = if thin {
         write::AutodiffStage::PreAD
     } else {
@@ -961,30 +676,9 @@ pub(crate) fn run_pass_manager(
     }
 
     if cfg!(llvm_enzyme) && enable_gpu && !thin {
-        // First we need to add all the functions to the host module:
-        // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }
-        // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 }
         let cx =
             SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
-        if cx.get_function("gen_tgt_offload").is_some() {
-            let (offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals(&cx);
-
-            dbg!("created struct");
-            let mut o_types = vec![];
-            let mut kernels = vec![];
-            for num in 0..9 {
-                let kernel = cx.get_function(&format!("kernel_{num}"));
-                if let Some(kernel) = kernel {
-                    o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
-                    kernels.push(kernel);
-                }
-            }
-            dbg!("gen_call_handling");
-            gen_call_handling(&cx, &kernels, at_one, begin, update, end, fn_ty, &o_types);
-        } else {
-            dbg!("no marker found");
-        }
+        crate::builder::gpu_offload::handle_gpu_code(&cx);
     }
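The deleted driver logic survives in the new `gpu_offload` module. A rough sketch of what `handle_gpu_code` presumably wraps, reconstructed from the removed block (the actual implementation lives in `rustc_codegen_llvm/src/builder/gpu_offload.rs` and may differ):

```rust
// Hypothetical consolidation of the removed block into the new module.
pub(crate) fn handle_gpu_code<'ll>(cx: &'ll SimpleCx<'_>) {
    // `gen_tgt_offload` is the temporary marker function the old code keyed on.
    if cx.get_function("gen_tgt_offload").is_some() {
        let (offload_entry_ty, at_one, begin, update, end, fn_ty) = gen_globals(&cx);
        let mut o_types = vec![];
        let mut kernels = vec![];
        // The prototype only recognizes kernels named kernel_0 .. kernel_8.
        for num in 0..9 {
            if let Some(kernel) = cx.get_function(&format!("kernel_{num}")) {
                o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
                kernels.push(kernel);
            }
        }
        gen_call_handling(&cx, &kernels, at_one, begin, update, end, fn_ty, &o_types);
    }
}
```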
 
     if cfg!(llvm_enzyme) && enable_ad && !thin {