@@ -28,7 +28,7 @@ use llvm::Linkage::*;
28
28
use crate :: back:: write:: {
29
29
self , CodegenDiagnosticsStage , DiagnosticHandlers , bitcode_section_name, save_temp_bitcode,
30
30
} ;
31
- use crate :: builder:: SBuilder ;
31
+ use crate :: builder:: { SBuilder , UNNAMED } ;
32
32
use crate :: errors:: {
33
33
DynamicLinkingWithLTO , LlvmError , LtoBitcodeFromRlib , LtoDisallowed , LtoDylib , LtoProcMacro ,
34
34
} ;
@@ -806,6 +806,27 @@ fn gen_define_handling<'ll>(cx: &'ll SimpleCx<'_>, kernel: &'ll llvm::Value, off
806
806
o_types
807
807
}
808
808
809
+
810
+ // For each kernel *call*, we now use some of our previous declared globals to move data to and from
811
+ // the gpu. We don't have a proper frontend yet, so we assume that every call to a kernel function
812
+ // from main is intended to run on the GPU. For now, we only handle the data transfer part of it.
813
+ // If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
814
+ // Since in our frontend users (by default) don't have to specify data transfer, this is something
815
+ // we should optimize in the future! We also assume that everything should be copied back and forth,
816
+ // but sometimes we can directly zero-allocate on the device and only move back, or if something is
817
+ // immutable, we might only copy it to the device, but not back.
818
+ //
819
+ // Current steps:
820
+ // 0. Alloca some variables for the following steps
821
+ // 1. set insert point before kernel call.
822
+ // 2. generate all the GEPS and stores, to be used in 3)
823
+ // 3. generate __tgt_target_data_begin calls to move data to the GPU
824
+ //
825
+ // unchanged: keep kernel call. Later move the kernel to the GPU
826
+ //
827
+ // 4. set insert point after kernel call.
828
+ // 5. generate all the GEPS and stores, to be used in 6)
829
+ // 6. generate __tgt_target_data_end calls to move data from the GPU
809
830
fn gen_call_handling < ' ll > ( cx : & ' ll SimpleCx < ' _ > , kernels : & [ & ' ll llvm:: Value ] , s_ident_t : & ' ll llvm:: Value , begin : & ' ll llvm:: Value , update : & ' ll llvm:: Value , end : & ' ll llvm:: Value , fn_ty : & ' ll llvm:: Type , o_types : & [ & ' ll llvm:: Value ] ) {
810
831
811
832
let main_fn = cx. get_function ( "main" ) ;
@@ -819,30 +840,39 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
819
840
return ;
820
841
} ;
821
842
let kernel_call_bb = unsafe { llvm:: LLVMGetInstructionParent ( kernel_call) } ;
843
+ let called = unsafe { llvm:: LLVMGetCalledValue ( kernel_call) } ;
822
844
let mut builder = SBuilder :: build ( cx, kernel_call_bb) ;
823
845
824
- let types = cx. func_params_types ( cx. get_type_of_global ( kernels [ 0 ] ) ) ;
846
+ let types = cx. func_params_types ( cx. get_type_of_global ( called ) ) ;
825
847
dbg ! ( & types) ;
826
848
let num_args = types. len ( ) as u64 ;
849
+ let mut names: Vec < & llvm:: Value > = Vec :: with_capacity ( num_args) ;
827
850
828
- // First we generate a few variables used for the data mappers below.
851
+ // Step 0)
829
852
unsafe { llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , main_fn) } ;
830
853
let ty = cx. type_array ( cx. type_ptr ( ) , num_args) ;
831
-
832
854
// Baseptr are just the input pointer to the kernel, stored in a local alloca
833
855
let a1 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_baseptrs" ) ;
834
-
835
856
// Ptrs are the result of a gep into the baseptr, at least for our trivial types.
836
857
let a2 = builder. my_alloca2 ( ty, Align :: EIGHT , ".offload_ptrs" ) ;
837
-
838
858
// These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
839
859
let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
840
860
let a4 = builder. my_alloca2 ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
861
+ // Now we allocate once per function param, a copy to be passed to one of our maps.
862
+ for ( index, in_ty) in types. iter ( ) . enumerate ( ) {
863
+ // Todo:
864
+ let p = llvm:: get_param ( called, index as u32 ) ;
865
+ let name = llvm:: get_value_name ( p) ;
866
+ let arg_name = format ! ( "{name}.addr" ) ;
867
+ let alloca = unsafe { llvm:: LLVMBuildAlloca ( builder. llbuilder , in_ty, arg_name) } ;
868
+ // get function arg, store it into the alloca, and read it.
869
+ }
841
870
842
- // Now we generate the __tgt_target_data calls
871
+
872
+ // Step 1)
843
873
unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
844
- dbg ! ( "positioned builder, ready" ) ;
845
874
875
+ // Step 2)
846
876
let i32_0 = cx. get_const_i32 ( 0 ) ;
847
877
let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
848
878
let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
@@ -853,8 +883,8 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
853
883
let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
854
884
builder. call ( fn_ty, begin, & args, None ) ;
855
885
886
+ // Step 4)
856
887
unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
857
- dbg ! ( "re-positioned builder, ready" ) ;
858
888
859
889
let gep1 = builder. inbounds_gep ( ty, a1, & [ i32_0, i32_0] ) ;
860
890
let gep2 = builder. inbounds_gep ( ty, a2, & [ i32_0, i32_0] ) ;
@@ -865,15 +895,6 @@ fn gen_call_handling<'ll>(cx: &'ll SimpleCx<'_>, kernels: &[&'ll llvm::Value], s
865
895
let args = vec ! [ s_ident_t, cx. get_const_i64( u64 :: MAX ) , cx. get_const_i32( num_args) , gep1, gep2, gep3, o_type, nullptr, nullptr] ;
866
896
builder. call ( fn_ty, end, & args, None ) ;
867
897
868
- // 1. set insert point before kernel call.
869
- // 2. generate all the GEPS and stores.
870
- // 3. generate __tgt_target_data calls.
871
- //
872
- // unchanged: keep kernel call.
873
- //
874
- // 4. generate all the GEPS and stores.
875
- // 5. generate __tgt_target_data calls
876
-
877
898
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
878
899
// call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
879
900
// call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
0 commit comments