Skip to content

Commit 4d7df40

Browse files
authored
[flang][cuda] Materialize constant src in memory (#116851)
When the src of the data transfer is a constant, it needs to be materialized in memory to be able to perform a data transfer. For example:

```
subroutine sub1()
  real, device :: a(10)
  integer :: i
  do i = 5, 10
    a(i) = -4.0
  end do
end
```
1 parent 0765136 commit 4d7df40

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -628,6 +628,12 @@ struct CUFDataTransferOpConversion
628628

629629
mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
630630
mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
631+
// Materialize the src if constant.
632+
if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
633+
mlir::Value temp = builder.createTemporary(loc, srcTy);
634+
builder.create<fir::StoreOp>(loc, src, temp);
635+
src = temp;
636+
}
631637
llvm::SmallVector<mlir::Value> args{
632638
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
633639
modeValue, sourceFile, sourceLine)};

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -513,4 +513,44 @@ func.func @_QPcallkernel(%arg0: !fir.box<!fir.array<?x?xcomplex<f32>>> {fir.bind
513513
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[ALLOCA]] : (!fir.ref<!fir.box<!fir.array<?x?xcomplex<f32>>>>) -> !fir.ref<!fir.box<none>>
514514
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
515515

516+
func.func @_QPsrc_cst() {
517+
%0 = fir.dummy_scope : !fir.dscope
518+
%1 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "d4", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Ed4"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
519+
%5:2 = hlfir.declare %1 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Ed4"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
520+
%6 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsub4Ei"}
521+
%7:2 = hlfir.declare %6 {uniq_name = "_QFsub4Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
522+
%c1 = arith.constant 1 : index
523+
%c10_i32 = arith.constant 10 : i32
524+
%c0_i32 = arith.constant 0 : i32
525+
%9 = fir.convert %5#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
526+
%c6_i32 = arith.constant 6 : i32
527+
%14 = fir.convert %c6_i32 : (i32) -> index
528+
%c10_i32_0 = arith.constant 10 : i32
529+
%15 = fir.convert %c10_i32_0 : (i32) -> index
530+
%c1_1 = arith.constant 1 : index
531+
%16 = fir.convert %14 : (index) -> i32
532+
%17:2 = fir.do_loop %arg1 = %14 to %15 step %c1_1 iter_args(%arg2 = %16) -> (index, i32) {
533+
fir.store %arg2 to %7#1 : !fir.ref<i32>
534+
%cst = arith.constant -4.000000e+00 : f32
535+
%22 = fir.load %5#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
536+
%23 = fir.load %7#0 : !fir.ref<i32>
537+
%24 = fir.convert %23 : (i32) -> i64
538+
%25 = hlfir.designate %22 (%24) : (!fir.box<!fir.heap<!fir.array<?xf32>>>, i64) -> !fir.ref<f32>
539+
cuf.data_transfer %cst to %25 {transfer_kind = #cuf.cuda_transfer<host_device>} : f32, !fir.ref<f32>
540+
%26 = arith.addi %arg1, %c1_1 : index
541+
%27 = fir.convert %c1_1 : (index) -> i32
542+
%28 = fir.load %7#1 : !fir.ref<i32>
543+
%29 = arith.addi %28, %27 : i32
544+
fir.result %26, %29 : index, i32
545+
}
546+
return
547+
}
548+
549+
// CHECK-LABEL: func.func @_QPsrc_cst()
550+
// CHECK: %[[ALLOCA:.*]] = fir.alloca f32
551+
// CHECK: %[[CST:.*]] = arith.constant -4.000000e+00 : f32
552+
// CHECK: fir.store %[[CST]] to %[[ALLOCA]] : !fir.ref<f32>
553+
// CHECK: %[[CONV:.*]] = fir.convert %[[ALLOCA]] : (!fir.ref<f32>) -> !fir.llvm_ptr<i8>
554+
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%{{.*}}, %[[CONV]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
555+
516556
} // end of module

0 commit comments

Comments
 (0)