Skip to content

Commit cbc5c11

Browse files
authored
[MLIR][OpenMP] Add Lowering support for implicitly linking to default declare mappers (#131006)
1 parent 83658dd commit cbc5c11

File tree

3 files changed

+77
-11
lines changed

3 files changed

+77
-11
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -929,18 +929,35 @@ void ClauseProcessor::processMapObjects(
929929
llvm::StringRef mapperIdNameRef) const {
930930
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
931931

932+
auto getDefaultMapperID = [&](const omp::Object &object,
933+
std::string &mapperIdName) {
934+
if (!mlir::isa<mlir::omp::DeclareMapperOp>(
935+
firOpBuilder.getRegion().getParentOp())) {
936+
const semantics::DerivedTypeSpec *typeSpec = nullptr;
937+
938+
if (object.sym()->owner().IsDerivedType())
939+
typeSpec = object.sym()->owner().derivedTypeSpec();
940+
else if (object.sym()->GetType() &&
941+
object.sym()->GetType()->category() ==
942+
semantics::DeclTypeSpec::TypeDerived)
943+
typeSpec = &object.sym()->GetType()->derivedTypeSpec();
944+
945+
if (typeSpec) {
946+
mapperIdName = typeSpec->name().ToString() + ".default";
947+
mapperIdName =
948+
converter.mangleName(mapperIdName, *typeSpec->GetScope());
949+
}
950+
}
951+
};
952+
932953
// Create the mapper symbol from its name, if specified.
933954
mlir::FlatSymbolRefAttr mapperId;
934-
if (!mapperIdNameRef.empty() && !objects.empty()) {
955+
if (!mapperIdNameRef.empty() && !objects.empty() &&
956+
mapperIdNameRef != "__implicit_mapper") {
935957
std::string mapperIdName = mapperIdNameRef.str();
936-
if (mapperIdName == "default") {
937-
const omp::Object &object = objects.front();
938-
auto &typeSpec = object.sym()->owner().IsDerivedType()
939-
? *object.sym()->owner().derivedTypeSpec()
940-
: object.sym()->GetType()->derivedTypeSpec();
941-
mapperIdName = typeSpec.name().ToString() + ".default";
942-
mapperIdName = converter.mangleName(mapperIdName, *typeSpec.GetScope());
943-
}
958+
const omp::Object &object = objects.front();
959+
if (mapperIdNameRef == "default")
960+
getDefaultMapperID(object, mapperIdName);
944961
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
945962
"mapper not found");
946963
mapperId =
@@ -978,6 +995,15 @@ void ClauseProcessor::processMapObjects(
978995
}
979996
}
980997

998+
if (mapperIdNameRef == "__implicit_mapper") {
999+
std::string mapperIdName;
1000+
getDefaultMapperID(object, mapperIdName);
1001+
mapperId = converter.getModuleOp().lookupSymbol(mapperIdName)
1002+
? mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
1003+
mapperIdName)
1004+
: mlir::FlatSymbolRefAttr();
1005+
}
1006+
9811007
// Explicit map captures are captured ByRef by default,
9821008
// optimisation passes may alter this to ByCopy or other capture
9831009
// types to optimise
@@ -1023,7 +1049,7 @@ bool ClauseProcessor::processMap(
10231049
const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t;
10241050
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
10251051
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1026-
std::string mapperIdName;
1052+
std::string mapperIdName = "__implicit_mapper";
10271053
// If the map type is specified, then process it else Tofrom is the
10281054
// default.
10291055
Map::MapType type = mapType.value_or(Map::MapType::Tofrom);

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2249,6 +2249,16 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
22492249
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
22502250
name << sym.name().ToString();
22512251

2252+
mlir::FlatSymbolRefAttr mapperId;
2253+
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
2254+
auto &typeSpec = sym.GetType()->derivedTypeSpec();
2255+
std::string mapperIdName = typeSpec.name().ToString() + ".default";
2256+
mapperIdName = converter.mangleName(mapperIdName, *typeSpec.GetScope());
2257+
if (converter.getModuleOp().lookupSymbol(mapperIdName))
2258+
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2259+
mapperIdName);
2260+
}
2261+
22522262
fir::factory::AddrAndBoundsInfo info =
22532263
Fortran::lower::getDataOperandBaseAddr(
22542264
converter, firOpBuilder, sym, converter.getCurrentLocation());
@@ -2307,7 +2317,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23072317
static_cast<
23082318
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
23092319
mapFlag),
2310-
captureKind, baseOp.getType());
2320+
captureKind, baseOp.getType(), /*partialMap=*/false, mapperId);
23112321

23122322
clauseOps.mapVars.push_back(mapOp);
23132323
mapSyms.push_back(&sym);

flang/test/Lower/OpenMP/declare-mapper.f90

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-1.f90 -o - | FileCheck %t/omp-declare-mapper-1.f90
55
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-2.f90 -o - | FileCheck %t/omp-declare-mapper-2.f90
66
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
7+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
78

89
!--- omp-declare-mapper-1.f90
910
subroutine declare_mapper_1
@@ -141,3 +142,32 @@ subroutine declare_mapper_3
141142
!$omp declare mapper (my_mapper : my_type :: var) map (var, var%values (1:var%num_vals))
142143
!$omp declare mapper (my_mapper2 : my_type2 :: v) map (mapper(my_mapper) : v%my_type_var) map (tofrom : v%arr)
143144
end subroutine declare_mapper_3
145+
146+
!--- omp-declare-mapper-4.f90
147+
subroutine declare_mapper_4
148+
type my_type
149+
integer :: num
150+
end type
151+
152+
!CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER:_QQFdeclare_mapper_4my_type.default]] : [[MY_TYPE:!fir\.type<_QFdeclare_mapper_4Tmy_type\{num:i32\}>]]
153+
!$omp declare mapper (my_type :: var) map (var%num)
154+
155+
type(my_type) :: a
156+
integer :: b
157+
!CHECK: %{{.*}} = omp.map.info var_ptr(%{{.*}}#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) mapper(@[[MY_TYPE_MAPPER]]) map_clauses(tofrom) capture(ByRef) -> !fir.ref<[[MY_TYPE]]> {name = "a"}
158+
!CHECK: %{{.*}} = omp.map.info var_ptr(%{{.*}}#1 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "b"}
159+
!$omp target map(a, b)
160+
a%num = 10
161+
b = 20
162+
!$omp end target
163+
164+
!CHECK: %{{.*}} = omp.map.info var_ptr(%{{.*}} : !fir.ref<i32>, i32) mapper(@[[MY_TYPE_MAPPER]]) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a%{{.*}}"}
165+
!$omp target map(a%num)
166+
a%num = 30
167+
!$omp end target
168+
169+
!CHECK: %{{.*}} = omp.map.info var_ptr(%{{.*}}#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) mapper(@[[MY_TYPE_MAPPER]]) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<[[MY_TYPE]]> {name = "a"}
170+
!$omp target
171+
a%num = 40
172+
!$omp end target
173+
end subroutine declare_mapper_4

0 commit comments

Comments
 (0)