Skip to content

Commit 74b5144

Browse files
committed
[OpenMP][Flang] Emit default declare mappers implicitly for derived types
This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types. This supports nested derived types as well.
1 parent 1c305f7 commit 74b5144

File tree

4 files changed

+174
-20
lines changed

4 files changed

+174
-20
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,24 +1167,28 @@ void ClauseProcessor::processMapObjects(
11671167

11681168
auto getDefaultMapperID = [&](const omp::Object &object,
11691169
std::string &mapperIdName) {
1170-
if (!mlir::isa<mlir::omp::DeclareMapperOp>(
1171-
firOpBuilder.getRegion().getParentOp())) {
1172-
const semantics::DerivedTypeSpec *typeSpec = nullptr;
1173-
1174-
if (object.sym()->owner().IsDerivedType())
1175-
typeSpec = object.sym()->owner().derivedTypeSpec();
1176-
else if (object.sym()->GetType() &&
1177-
object.sym()->GetType()->category() ==
1178-
semantics::DeclTypeSpec::TypeDerived)
1179-
typeSpec = &object.sym()->GetType()->derivedTypeSpec();
1180-
1181-
if (typeSpec) {
1182-
mapperIdName =
1183-
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
1184-
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
1185-
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1186-
}
1170+
const semantics::DerivedTypeSpec *typeSpec = nullptr;
1171+
1172+
if (object.sym()->GetType() && object.sym()->GetType()->category() ==
1173+
semantics::DeclTypeSpec::TypeDerived)
1174+
typeSpec = &object.sym()->GetType()->derivedTypeSpec();
1175+
else if (object.sym()->owner().IsDerivedType())
1176+
typeSpec = object.sym()->owner().derivedTypeSpec();
1177+
1178+
if (typeSpec) {
1179+
mapperIdName =
1180+
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
1181+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
1182+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
11871183
}
1184+
1185+
// Make sure we don't return a mapper to self
1186+
llvm::StringRef parentOpName;
1187+
if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
1188+
firOpBuilder.getRegion().getParentOp()))
1189+
parentOpName = declMapOp.getSymName();
1190+
if (mapperIdName == parentOpName)
1191+
mapperIdName = "";
11881192
};
11891193

11901194
// Create the mapper symbol from its name, if specified.

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2391,6 +2391,124 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23912391
queue, item, clauseOps);
23922392
}
23932393

2394+
static mlir::FlatSymbolRefAttr
2395+
genImplicitDefaultDeclareMapper(lower::AbstractConverter &converter,
2396+
mlir::Location loc, fir::RecordType recordType,
2397+
llvm::StringRef mapperNameStr) {
2398+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2399+
lower::StatementContext stmtCtx;
2400+
2401+
// Save current insertion point before moving to the module scope to create
2402+
// the DeclareMapperOp
2403+
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
2404+
2405+
firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
2406+
auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
2407+
loc, mapperNameStr, recordType);
2408+
auto &region = declMapperOp.getRegion();
2409+
firOpBuilder.createBlock(&region);
2410+
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
2411+
2412+
auto declareOp =
2413+
firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/"");
2414+
2415+
const auto genBoundsOps = [&](mlir::Value mapVal,
2416+
llvm::SmallVectorImpl<mlir::Value> &bounds) {
2417+
fir::ExtendedValue extVal =
2418+
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
2419+
hlfir::Entity{mapVal},
2420+
/*contiguousHint=*/true)
2421+
.first;
2422+
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
2423+
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
2424+
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
2425+
mlir::omp::MapBoundsType>(
2426+
firOpBuilder, info, extVal,
2427+
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
2428+
};
2429+
2430+
// Return a reference to the contents of a derived type with one field.
2431+
// Also return the field type.
2432+
const auto getFieldRef =
2433+
[&](mlir::Value rec,
2434+
unsigned index) -> std::tuple<mlir::Value, mlir::Type> {
2435+
auto recType = mlir::dyn_cast<fir::RecordType>(
2436+
fir::unwrapPassByRefType(rec.getType()));
2437+
auto [fieldName, fieldTy] = recType.getTypeList()[index];
2438+
mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
2439+
loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
2440+
fir::getTypeParams(rec));
2441+
return {firOpBuilder.create<fir::CoordinateOp>(
2442+
loc, firOpBuilder.getRefType(fieldTy), rec, field),
2443+
fieldTy};
2444+
};
2445+
2446+
mlir::omp::DeclareMapperInfoOperands clauseOps;
2447+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
2448+
llvm::SmallVector<mlir::Value> memberMapOps;
2449+
2450+
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2451+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
2452+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
2453+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
2454+
mlir::omp::VariableCaptureKind captureKind =
2455+
mlir::omp::VariableCaptureKind::ByRef;
2456+
int64_t index = 0;
2457+
2458+
// Populate the declareMapper region with the map information.
2459+
for (const auto &[memberName, memberType] :
2460+
mlir::dyn_cast<fir::RecordType>(recordType).getTypeList()) {
2461+
auto [ref, type] = getFieldRef(declareOp.getBase(), index);
2462+
mlir::FlatSymbolRefAttr mapperId;
2463+
if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
2464+
std::string mapperIdName =
2465+
recType.getName().str() + ".omp.default.mapper";
2466+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
2467+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2468+
else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName))
2469+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2470+
2471+
if (converter.getModuleOp().lookupSymbol(mapperIdName))
2472+
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2473+
mapperIdName);
2474+
else
2475+
mapperId = genImplicitDefaultDeclareMapper(converter, loc, recType,
2476+
mapperIdName);
2477+
}
2478+
2479+
llvm::SmallVector<mlir::Value> bounds;
2480+
genBoundsOps(ref, bounds);
2481+
mlir::Value mapOp = createMapInfoOp(
2482+
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, "", bounds,
2483+
/*members=*/{},
2484+
/*membersIndex=*/mlir::ArrayAttr{},
2485+
static_cast<
2486+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2487+
mapFlag),
2488+
captureKind, ref.getType(), /*partialMap=*/false, mapperId);
2489+
memberMapOps.emplace_back(mapOp);
2490+
memberPlacementIndices.emplace_back(llvm::SmallVector<int64_t>{index++});
2491+
}
2492+
2493+
llvm::SmallVector<mlir::Value> bounds;
2494+
genBoundsOps(declareOp.getOriginalBase(), bounds);
2495+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
2496+
firOpBuilder, loc, declareOp.getOriginalBase(),
2497+
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
2498+
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices),
2499+
static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2500+
mapFlag),
2501+
captureKind, declareOp.getType(0),
2502+
/*partialMap=*/true);
2503+
2504+
clauseOps.mapVars.emplace_back(mapOp);
2505+
2506+
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
2507+
// declMapperOp->dumpPretty();
2508+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2509+
mapperNameStr);
2510+
}
2511+
23942512
static mlir::omp::TargetOp
23952513
genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23962514
lower::StatementContext &stmtCtx,
@@ -2467,15 +2585,26 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24672585
name << sym.name().ToString();
24682586

24692587
mlir::FlatSymbolRefAttr mapperId;
2470-
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
2588+
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived &&
2589+
defaultMaps.empty()) {
24712590
auto &typeSpec = sym.GetType()->derivedTypeSpec();
24722591
std::string mapperIdName =
24732592
typeSpec.name().ToString() + llvm::omp::OmpDefaultMapperName;
24742593
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
24752594
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2595+
else
2596+
mapperIdName =
2597+
converter.mangleName(mapperIdName, *typeSpec.GetScope());
2598+
24762599
if (converter.getModuleOp().lookupSymbol(mapperIdName))
24772600
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
24782601
mapperIdName);
2602+
else
2603+
mapperId = genImplicitDefaultDeclareMapper(
2604+
converter, loc,
2605+
mlir::cast<fir::RecordType>(
2606+
converter.genType(sym.GetType()->derivedTypeSpec())),
2607+
mapperIdName);
24792608
}
24802609

24812610
fir::factory::AddrAndBoundsInfo info =
@@ -3442,6 +3571,7 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
34423571
ClauseProcessor cp(converter, semaCtx, clauses);
34433572
cp.processMap(loc, stmtCtx, clauseOps);
34443573
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
3574+
// declMapperOp->dumpPretty();
34453575
}
34463576

34473577
static void

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ class MapInfoFinalizationPass
433433
getDescriptorMapType(mapType, target)),
434434
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
435435
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
436-
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
436+
op.getMapperIdAttr(), op.getNameAttr(),
437437
/*partial_map=*/builder.getBoolAttr(false));
438438
op.replaceAllUsesWith(newDescParentMapOp.getResult());
439439
op->erase();

flang/test/Lower/OpenMP/derived-type-map.f90

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
22

3+
!CHECK: omp.declare_mapper @[[MAPPER1:_QQFmaptype_derived_implicit_allocatablescalar_and_array.omp.default.mapper]] : !fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {
4+
!CHECK: omp.declare_mapper @[[MAPPER2:_QQFmaptype_derived_implicitscalar_and_array.omp.default.mapper]] : !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {
35

46
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicitEscalar_arr"}
57
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_implicitEscalar_arr"} : (!fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) -> (!fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>)
6-
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}
8+
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(implicit, tofrom) capture(ByRef) mapper(@[[MAPPER2]]) -> !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}
79
!CHECK: omp.target map_entries(%[[MAP]] -> %[[ARG0:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) {
810
subroutine mapType_derived_implicit
911
type :: scalar_and_array
@@ -18,6 +20,24 @@ subroutine mapType_derived_implicit
1820
!$omp end target
1921
end subroutine mapType_derived_implicit
2022

23+
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.heap<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"}
24+
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"} : {{.*}}
25+
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : {{.*}}) map_clauses(implicit, to) capture(ByRef) mapper(@[[MAPPER1]])
26+
!CHECK: omp.target map_entries(%[[MAP]] -> %[[ARG0:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>>>, !fir.llvm_ptr<!fir.ref<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>>) {
27+
subroutine mapType_derived_implicit_allocatable
28+
type :: scalar_and_array
29+
real(4) :: real
30+
integer(4) :: array(10)
31+
integer(4) :: int
32+
end type scalar_and_array
33+
type(scalar_and_array), allocatable :: scalar_arr
34+
35+
allocate (scalar_arr)
36+
!$omp target
37+
scalar_arr%int = 1
38+
!$omp end target
39+
end subroutine mapType_derived_implicit_allocatable
40+
2141
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_explicitEscalar_arr"}
2242
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_explicitEscalar_arr"} : (!fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) -> (!fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>)
2343
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}

0 commit comments

Comments
 (0)