Skip to content

Commit 73deda0

Browse files
authored
[Flang][OpenMP] Fix allocating arrays with size intrinsic (#225)
Attempt to prevent the following example from causing an assert or ICE:

    subroutine test(a)
      implicit none
      integer :: i
      real(kind=real64), dimension(:) :: a
      real(kind=real64), dimension(size(a, 1)) :: b

      !$omp target map(tofrom: b)
      do i = 1, 10
        b(i) = i
      end do
      !$omp end target
    end subroutine

Here we utilise a Fortran intrinsic (size) to calculate the size of allocatable arrays and then map the array to the device.

This borrows some of Kareem Ergawy's current work to disentangle bounds generation from the semantic/PFT information.

Co-author: Kareem Ergawy <kareem.ergawy@amd.com>
1 parent 25c22b7 commit 73deda0

File tree

7 files changed

+103
-46
lines changed

7 files changed

+103
-46
lines changed

flang/lib/Lower/DirectivesCommon.h

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -610,31 +610,21 @@ void createEmptyRegionBlocks(
610610
}
611611

612612
inline AddrAndBoundsInfo
613-
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
614-
fir::FirOpBuilder &builder,
615-
Fortran::lower::SymbolRef sym, mlir::Location loc) {
616-
mlir::Value symAddr = converter.getSymbolAddress(sym);
613+
getDataOperandBaseAddr(fir::FirOpBuilder &builder,
614+
mlir::Value symAddr,
615+
bool isOptional, mlir::Location loc) {
617616
mlir::Value rawInput = symAddr;
618617
if (auto declareOp =
619618
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
620619
symAddr = declareOp.getResults()[0];
621620
rawInput = declareOp.getResults()[1];
622621
}
623622

624-
// TODO: Might need revisiting to handle for non-shared clauses
625-
if (!symAddr) {
626-
if (const auto *details =
627-
sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
628-
symAddr = converter.getSymbolAddress(details->symbol());
629-
rawInput = symAddr;
630-
}
631-
}
632-
633623
if (!symAddr)
634624
llvm::report_fatal_error("could not retrieve symbol address");
635625

636626
mlir::Value isPresent;
637-
if (Fortran::semantics::IsOptional(sym))
627+
if (isOptional)
638628
isPresent =
639629
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
640630

@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
648638
// all address/dimension retrievals. For Fortran optional though, leave
649639
// the load generation for later so it can be done in the appropriate
650640
// if branches.
651-
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
652-
!Fortran::semantics::IsOptional(sym)) {
641+
if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
653642
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
654643
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
655644
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
659648
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
660649
}
661650

651+
inline AddrAndBoundsInfo
652+
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
653+
fir::FirOpBuilder &builder,
654+
Fortran::lower::SymbolRef sym, mlir::Location loc) {
655+
return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
656+
Fortran::semantics::IsOptional(sym), loc);
657+
}
658+
662659
template <typename BoundsOp, typename BoundsType>
663660
llvm::SmallVector<mlir::Value>
664661
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1227,6 +1224,26 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
12271224

12281225
return info;
12291226
}
1227+
1228+
template <typename BoundsOp, typename BoundsType>
1229+
llvm::SmallVector<mlir::Value>
1230+
genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
1231+
fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
1232+
mlir::Location loc) {
1233+
llvm::SmallVector<mlir::Value> bounds;
1234+
1235+
mlir::Value baseOp = info.rawInput;
1236+
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
1237+
bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
1238+
dataExv, info);
1239+
if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
1240+
bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
1241+
builder, loc, dataExv, dataExvIsAssumedSize);
1242+
}
1243+
1244+
return bounds;
1245+
}
1246+
12301247
} // namespace lower
12311248
} // namespace Fortran
12321249

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,33 +1314,69 @@ static void genBodyOfTargetOp(
13141314
for (mlir::Value val : valuesDefinedAbove) {
13151315
mlir::Operation *valOp = val.getDefiningOp();
13161316
assert(valOp != nullptr);
1317-
if (mlir::isMemoryEffectFree(valOp)) {
1317+
1318+
// NOTE: We skip BoxDimsOp's as the lesser of two evils is to map the
1319+
// indices separately, as the alternative is to eventually map the Box,
1320+
// which comes with a fairly large overhead comparatively. We could be
1321+
// more robust about this and check using a BackWardsSlice to see if we
1322+
// run the risk of mapping a box.
1323+
if (mlir::isMemoryEffectFree(valOp) &&
1324+
!mlir::isa<fir::BoxDimsOp>(valOp)) {
13181325
mlir::Operation *clonedOp = valOp->clone();
13191326
entryBlock->push_front(clonedOp);
1320-
assert(clonedOp->getNumResults() == 1);
1321-
val.replaceUsesWithIf(clonedOp->getResult(0),
1322-
[entryBlock](mlir::OpOperand &use) {
1323-
return use.getOwner()->getBlock() == entryBlock;
1324-
});
1327+
1328+
auto replace = [entryBlock](mlir::OpOperand &use) {
1329+
return use.getOwner()->getBlock() == entryBlock;
1330+
};
1331+
1332+
valOp->getResults().replaceUsesWithIf(clonedOp->getResults(), replace);
1333+
valOp->replaceUsesWithIf(clonedOp, replace);
13251334
} else {
13261335
auto savedIP = firOpBuilder.getInsertionPoint();
13271336
firOpBuilder.setInsertionPointAfter(valOp);
13281337
auto copyVal =
13291338
firOpBuilder.createTemporary(val.getLoc(), val.getType());
13301339
firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);
13311340

1332-
llvm::SmallVector<mlir::Value> bounds;
1341+
lower::AddrAndBoundsInfo info = lower::getDataOperandBaseAddr(
1342+
firOpBuilder, val, /*isOptional=*/false, val.getLoc());
1343+
llvm::SmallVector<mlir::Value> bounds =
1344+
Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
1345+
mlir::omp::MapBoundsType>(
1346+
firOpBuilder, info,
1347+
hlfir::translateToExtendedValue(val.getLoc(), firOpBuilder,
1348+
hlfir::Entity{val})
1349+
.first,
1350+
/*dataExvIsAssumedSize=*/false, val.getLoc());
1351+
13331352
std::stringstream name;
13341353
firOpBuilder.setInsertionPoint(targetOp);
1354+
1355+
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
1356+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1357+
mlir::omp::VariableCaptureKind captureKind =
1358+
mlir::omp::VariableCaptureKind::ByRef;
1359+
1360+
mlir::Type eleType = copyVal.getType();
1361+
if (auto refType =
1362+
mlir::dyn_cast<fir::ReferenceType>(copyVal.getType()))
1363+
eleType = refType.getElementType();
1364+
1365+
if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
1366+
captureKind = mlir::omp::VariableCaptureKind::ByCopy;
1367+
} else if (!fir::isa_builtin_cptr_type(eleType)) {
1368+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1369+
}
1370+
13351371
mlir::Value mapOp = createMapInfoOp(
13361372
firOpBuilder, copyVal.getLoc(), copyVal,
13371373
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds,
13381374
/*members=*/llvm::SmallVector<mlir::Value>{},
13391375
/*membersIndex=*/mlir::ArrayAttr{},
13401376
static_cast<
13411377
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1342-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
1343-
mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
1378+
mapFlag),
1379+
captureKind, copyVal.getType());
13441380

13451381
// Get the index of the first non-map argument before modifying mapVars,
13461382
// then append an element to mapVars and an associated entry block
@@ -2110,7 +2146,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
21102146
dsp.processStep1();
21112147
if (enableDelayedPrivatizationStaging)
21122148
dsp.processStep2(&clauseOps);
2113-
21142149
// 5.8.1 Implicit Data-Mapping Attribute Rules
21152150
// The following code follows the implicit data-mapping rules to map all the
21162151
// symbols used inside the region that do not have explicit data-environment
@@ -2217,7 +2252,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
22172252

22182253
llvm::SmallVector<mlir::Value> mapBaseValues;
22192254
extractMappedBaseValues(clauseOps.mapVars, mapBaseValues);
2220-
22212255
EntryBlockArgs args;
22222256
args.hostEvalVars = clauseOps.hostEvalVars;
22232257
// TODO: Add in_reduction syms and vars.

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,19 @@ class MapInfoFinalizationPass
158158
mlir::Value baseAddrAddr = builder.create<fir::BoxOffsetOp>(
159159
loc, descriptor, fir::BoxFieldAttr::base_addr);
160160

161+
mlir::Type underlyingVarType =
162+
llvm::cast<mlir::omp::PointerLikeType>(
163+
fir::unwrapRefType(baseAddrAddr.getType()))
164+
.getElementType();
165+
if (auto seqType = llvm::dyn_cast<fir::SequenceType>(underlyingVarType))
166+
if (seqType.hasDynamicExtents())
167+
underlyingVarType = seqType.getEleTy();
168+
161169
// Member of the descriptor pointing at the allocated data
162170
return builder.create<mlir::omp::MapInfoOp>(
163171
loc, baseAddrAddr.getType(), descriptor,
164-
mlir::TypeAttr::get(llvm::cast<mlir::omp::PointerLikeType>(
165-
fir::unwrapRefType(baseAddrAddr.getType()))
166-
.getElementType()),
167-
baseAddrAddr, /*members=*/mlir::SmallVector<mlir::Value>{},
172+
mlir::TypeAttr::get(underlyingVarType), baseAddrAddr,
173+
/*members=*/mlir::SmallVector<mlir::Value>{},
168174
/*membersIndex=*/mlir::ArrayAttr{}, bounds,
169175
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
170176
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(

flang/test/Lower/OpenMP/allocatable-array-bounds.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
!HOST: %[[BOX_3:.*]]:3 = fir.box_dims %[[LOAD_3]], %[[CONSTANT_3]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
2424
!HOST: %[[BOUNDS_1:.*]] = omp.map.bounds lower_bound(%[[LB_1]] : index) upper_bound(%[[UB_1]] : index) extent(%[[BOX_3]]#1 : index) stride(%[[BOX_2]]#2 : index) start_idx(%[[BOX_1]]#0 : index) {stride_in_bytes = true}
2525
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[DECLARE_1]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
26-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_1]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_1]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
26+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_1]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_1]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
2727
!HOST: %[[MAP_INFO_1:.*]] = omp.map.info var_ptr(%[[DECLARE_1]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(always, to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "sp_read(2:5)"}
2828

2929
!HOST: %[[LOAD_3:.*]] = fir.load %[[DECLARE_2]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -41,7 +41,7 @@
4141
!HOST: %[[BOX_5:.*]]:3 = fir.box_dims %[[LOAD_5]], %[[CONSTANT_5]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
4242
!HOST: %[[BOUNDS_2:.*]] = omp.map.bounds lower_bound(%[[LB_2]] : index) upper_bound(%[[UB_2]] : index) extent(%[[BOX_5]]#1 : index) stride(%[[BOX_4]]#2 : index) start_idx(%[[BOX_3]]#0 : index) {stride_in_bytes = true}
4343
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[DECLARE_2]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
44-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_2]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_2]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
44+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE_2]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS_2]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
4545
!HOST: %[[MAP_INFO_2:.*]] = omp.map.info var_ptr(%[[DECLARE_2]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(always, to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "sp_write(2:5)"}
4646

4747
subroutine read_write_section()
@@ -80,7 +80,7 @@ module assumed_allocatable_array_routines
8080
!HOST: %[[BOX_3:.*]]:3 = fir.box_dims %[[LOAD_3]], %[[CONSTANT_3]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>, index) -> (index, index, index)
8181
!HOST: %[[BOUNDS:.*]] = omp.map.bounds lower_bound(%[[LB]] : index) upper_bound(%[[UB]] : index) extent(%[[BOX_3]]#1 : index) stride(%[[BOX_2]]#2 : index) start_idx(%[[BOX_1]]#0 : index) {stride_in_bytes = true}
8282
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %[[DECLARE]]#1 base_addr : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
83-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
83+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
8484
!HOST: %[[MAP_INFO:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.box<!fir.heap<!fir.array<?xi32>>>) map_clauses(always, to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {name = "arr_read_write(2:5)"}
8585
subroutine assumed_shape_array(arr_read_write)
8686
integer, allocatable, intent(inout) :: arr_read_write(:)

flang/test/Lower/OpenMP/array-bounds.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ module assumed_array_routines
5151
!HOST: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0_DECL]]#1, %[[C0_1]] : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
5252
!HOST: %[[BOUNDS:.*]] = omp.map.bounds lower_bound(%[[C3]] : index) upper_bound(%[[C4]] : index) extent(%[[DIMS1]]#1 : index) stride(%[[DIMS0]]#2 : index) start_idx(%[[C0]] : index) {stride_in_bytes = true}
5353
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %0 base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
54-
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
54+
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, i32) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
5555
!HOST: %[[MAP:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(always, to) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
5656
!HOST: omp.target map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}}, %[[MAP_INFO_MEMBER]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
5757
subroutine assumed_shape_array(arr_read_write)

0 commit comments

Comments
 (0)