Skip to content

Commit c0bde07

Browse files
abhinavgabagithub-actions[bot]
authored andcommitted
Automerge: [NFC][Clang][OpenMP] Refactor mapinfo generation for captured vars (#146891)
The refactored code would allow creating multiple member-of maps for the same captured var, which would be useful for changes like llvm/llvm-project#145454.
2 parents 356a552 + 02f60fd commit c0bde07

File tree

1 file changed

+86
-42
lines changed

1 file changed

+86
-42
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 86 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6801,6 +6801,11 @@ class MappableExprsHandler {
68016801
llvm::OpenMPIRBuilder::MapNonContiguousArrayTy;
68026802
using MapExprsArrayTy = SmallVector<MappingExprInfo, 4>;
68036803
using MapValueDeclsArrayTy = SmallVector<const ValueDecl *, 4>;
6804+
using MapData =
6805+
std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
6806+
OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>,
6807+
bool /*IsImplicit*/, const ValueDecl *, const Expr *>;
6808+
using MapDataArrayTy = SmallVector<MapData, 4>;
68046809

68056810
/// This structure contains combined information generated for mappable
68066811
/// clauses, including base pointers, pointers, sizes, map types, user-defined
@@ -8496,6 +8501,7 @@ class MappableExprsHandler {
84968501
const StructRangeInfoTy &PartialStruct, bool IsMapThis,
84978502
llvm::OpenMPIRBuilder &OMPBuilder,
84988503
const ValueDecl *VD = nullptr,
8504+
unsigned OffsetForMemberOfFlag = 0,
84998505
bool NotTargetParams = true) const {
85008506
if (CurTypes.size() == 1 &&
85018507
((CurTypes.back() & OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
@@ -8583,8 +8589,8 @@ class MappableExprsHandler {
85838589
// All other current entries will be MEMBER_OF the combined entry
85848590
// (except for PTR_AND_OBJ entries which do not have a placeholder value
85858591
// 0xFFFF in the MEMBER_OF field).
8586-
OpenMPOffloadMappingFlags MemberOfFlag =
8587-
OMPBuilder.getMemberOfFlag(CombinedInfo.BasePointers.size() - 1);
8592+
OpenMPOffloadMappingFlags MemberOfFlag = OMPBuilder.getMemberOfFlag(
8593+
OffsetForMemberOfFlag + CombinedInfo.BasePointers.size() - 1);
85888594
for (auto &M : CurTypes)
85898595
OMPBuilder.setCorrectMemberOfFlag(M, MemberOfFlag);
85908596
}
@@ -8727,11 +8733,13 @@ class MappableExprsHandler {
87278733
}
87288734
}
87298735

8730-
/// Generate the base pointers, section pointers, sizes, map types, and
8731-
/// mappers associated to a given capture (all included in \a CombinedInfo).
8732-
void generateInfoForCapture(const CapturedStmt::Capture *Cap,
8733-
llvm::Value *Arg, MapCombinedInfoTy &CombinedInfo,
8734-
StructRangeInfoTy &PartialStruct) const {
8736+
/// For a capture that has an associated clause, generate the base pointers,
8737+
/// section pointers, sizes, map types, and mappers (all included in
8738+
/// \a CurCaptureVarInfo).
8739+
void generateInfoForCaptureFromClauseInfo(
8740+
const CapturedStmt::Capture *Cap, llvm::Value *Arg,
8741+
MapCombinedInfoTy &CurCaptureVarInfo, llvm::OpenMPIRBuilder &OMPBuilder,
8742+
unsigned OffsetForMemberOfFlag) const {
87358743
assert(!Cap->capturesVariableArrayType() &&
87368744
"Not expecting to generate map info for a variable array type!");
87378745

@@ -8749,26 +8757,22 @@ class MappableExprsHandler {
87498757
// pass the pointer by value. If it is a reference to a declaration, we just
87508758
// pass its value.
87518759
if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) {
8752-
CombinedInfo.Exprs.push_back(VD);
8753-
CombinedInfo.BasePointers.emplace_back(Arg);
8754-
CombinedInfo.DevicePtrDecls.emplace_back(VD);
8755-
CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
8756-
CombinedInfo.Pointers.push_back(Arg);
8757-
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
8760+
CurCaptureVarInfo.Exprs.push_back(VD);
8761+
CurCaptureVarInfo.BasePointers.emplace_back(Arg);
8762+
CurCaptureVarInfo.DevicePtrDecls.emplace_back(VD);
8763+
CurCaptureVarInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
8764+
CurCaptureVarInfo.Pointers.push_back(Arg);
8765+
CurCaptureVarInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
87588766
CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
87598767
/*isSigned=*/true));
8760-
CombinedInfo.Types.push_back(
8768+
CurCaptureVarInfo.Types.push_back(
87618769
OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
87628770
OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM);
8763-
CombinedInfo.Mappers.push_back(nullptr);
8771+
CurCaptureVarInfo.Mappers.push_back(nullptr);
87648772
return;
87658773
}
87668774

8767-
using MapData =
8768-
std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
8769-
OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>, bool,
8770-
const ValueDecl *, const Expr *>;
8771-
SmallVector<MapData, 4> DeclComponentLists;
8775+
MapDataArrayTy DeclComponentLists;
87728776
// For member fields list in is_device_ptr, store it in
87738777
// DeclComponentLists for generating components info.
87748778
static const OpenMPMapModifierKind Unknown = OMPC_MAP_MODIFIER_unknown;
@@ -8826,6 +8830,51 @@ class MappableExprsHandler {
88268830
return (HasPresent && !HasPresentR) || (HasAllocs && !HasAllocsR);
88278831
});
88288832

8833+
auto GenerateInfoForComponentLists =
8834+
[&](ArrayRef<MapData> DeclComponentLists,
8835+
bool IsEligibleForTargetParamFlag) {
8836+
MapCombinedInfoTy CurInfoForComponentLists;
8837+
StructRangeInfoTy PartialStruct;
8838+
8839+
if (DeclComponentLists.empty())
8840+
return;
8841+
8842+
generateInfoForCaptureFromComponentLists(
8843+
VD, DeclComponentLists, CurInfoForComponentLists, PartialStruct,
8844+
IsEligibleForTargetParamFlag,
8845+
/*AreBothBasePtrAndPteeMapped=*/HasMapBasePtr && HasMapArraySec);
8846+
8847+
// If there is an entry in PartialStruct it means we have a
8848+
// struct with individual members mapped. Emit an extra combined
8849+
// entry.
8850+
if (PartialStruct.Base.isValid()) {
8851+
CurCaptureVarInfo.append(PartialStruct.PreliminaryMapData);
8852+
emitCombinedEntry(
8853+
CurCaptureVarInfo, CurInfoForComponentLists.Types,
8854+
PartialStruct, Cap->capturesThis(), OMPBuilder, nullptr,
8855+
OffsetForMemberOfFlag,
8856+
/*NotTargetParams*/ !IsEligibleForTargetParamFlag);
8857+
}
8858+
8859+
// Return if we didn't add any entries.
8860+
if (CurInfoForComponentLists.BasePointers.empty())
8861+
return;
8862+
8863+
CurCaptureVarInfo.append(CurInfoForComponentLists);
8864+
};
8865+
8866+
GenerateInfoForComponentLists(DeclComponentLists,
8867+
/*IsEligibleForTargetParamFlag=*/true);
8868+
}
8869+
8870+
/// Generate the base pointers, section pointers, sizes, map types, and
8871+
/// mappers associated to \a DeclComponentLists for a given capture
8872+
/// \a VD (all included in \a CurComponentListInfo).
8873+
void generateInfoForCaptureFromComponentLists(
8874+
const ValueDecl *VD, ArrayRef<MapData> DeclComponentLists,
8875+
MapCombinedInfoTy &CurComponentListInfo, StructRangeInfoTy &PartialStruct,
8876+
bool IsListEligibleForTargetParamFlag,
8877+
bool AreBothBasePtrAndPteeMapped = false) const {
88298878
// Find overlapping elements (including the offset from the base element).
88308879
llvm::SmallDenseMap<
88318880
const MapData *,
@@ -8949,7 +8998,7 @@ class MappableExprsHandler {
89498998

89508999
// Associated with a capture, because the mapping flags depend on it.
89519000
// Go through all of the elements with the overlapped elements.
8952-
bool IsFirstComponentList = true;
9001+
bool AddTargetParamFlag = IsListEligibleForTargetParamFlag;
89539002
MapCombinedInfoTy StructBaseCombinedInfo;
89549003
for (const auto &Pair : OverlappedData) {
89559004
const MapData &L = *Pair.getFirst();
@@ -8964,11 +9013,11 @@ class MappableExprsHandler {
89649013
ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
89659014
OverlappedComponents = Pair.getSecond();
89669015
generateInfoForComponentList(
8967-
MapType, MapModifiers, {}, Components, CombinedInfo,
8968-
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
8969-
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
9016+
MapType, MapModifiers, {}, Components, CurComponentListInfo,
9017+
StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag, IsImplicit,
9018+
/*GenerateAllInfoForClauses*/ false, Mapper,
89709019
/*ForDeviceAddr=*/false, VD, VarRef, OverlappedComponents);
8971-
IsFirstComponentList = false;
9020+
AddTargetParamFlag = false;
89729021
}
89739022
// Go through other elements without overlapped elements.
89749023
for (const MapData &L : DeclComponentLists) {
@@ -8983,12 +9032,12 @@ class MappableExprsHandler {
89839032
auto It = OverlappedData.find(&L);
89849033
if (It == OverlappedData.end())
89859034
generateInfoForComponentList(
8986-
MapType, MapModifiers, {}, Components, CombinedInfo,
8987-
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
9035+
MapType, MapModifiers, {}, Components, CurComponentListInfo,
9036+
StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag,
89889037
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
89899038
/*ForDeviceAddr=*/false, VD, VarRef,
8990-
/*OverlappedElements*/ {}, HasMapBasePtr && HasMapArraySec);
8991-
IsFirstComponentList = false;
9039+
/*OverlappedElements*/ {}, AreBothBasePtrAndPteeMapped);
9040+
AddTargetParamFlag = false;
89929041
}
89939042
}
89949043

@@ -9467,7 +9516,6 @@ static void genMapInfoForCaptures(
94679516
CE = CS.capture_end();
94689517
CI != CE; ++CI, ++RI, ++CV) {
94699518
MappableExprsHandler::MapCombinedInfoTy CurInfo;
9470-
MappableExprsHandler::StructRangeInfoTy PartialStruct;
94719519

94729520
// VLA sizes are passed to the outlined region by copy and do not have map
94739521
// information associated.
@@ -9488,37 +9536,33 @@ static void genMapInfoForCaptures(
94889536
} else {
94899537
// If we have any information in the map clause, we use it, otherwise we
94909538
// just do a default mapping.
9491-
MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
9539+
MEHandler.generateInfoForCaptureFromClauseInfo(
9540+
CI, *CV, CurInfo, OMPBuilder,
9541+
/*OffsetForMemberOfFlag=*/CombinedInfo.BasePointers.size());
9542+
94929543
if (!CI->capturesThis())
94939544
MappedVarSet.insert(CI->getCapturedVar());
94949545
else
94959546
MappedVarSet.insert(nullptr);
9496-
if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())
9547+
9548+
if (CurInfo.BasePointers.empty())
94979549
MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);
9550+
94989551
// Generate correct mapping for variables captured by reference in
94999552
// lambdas.
95009553
if (CI->capturesVariable())
95019554
MEHandler.generateInfoForLambdaCaptures(CI->getCapturedVar(), *CV,
95029555
CurInfo, LambdaPointers);
95039556
}
95049557
// We expect to have at least an element of information for this capture.
9505-
assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) &&
9558+
assert(!CurInfo.BasePointers.empty() &&
95069559
"Non-existing map pointer for capture!");
95079560
assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
95089561
CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
95099562
CurInfo.BasePointers.size() == CurInfo.Types.size() &&
95109563
CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
95119564
"Inconsistent map information sizes!");
95129565

9513-
// If there is an entry in PartialStruct it means we have a struct with
9514-
// individual members mapped. Emit an extra combined entry.
9515-
if (PartialStruct.Base.isValid()) {
9516-
CombinedInfo.append(PartialStruct.PreliminaryMapData);
9517-
MEHandler.emitCombinedEntry(CombinedInfo, CurInfo.Types, PartialStruct,
9518-
CI->capturesThis(), OMPBuilder, nullptr,
9519-
/*NotTargetParams*/ false);
9520-
}
9521-
95229566
// We need to append the results of this capture to what we already have.
95239567
CombinedInfo.append(CurInfo);
95249568
}

0 commit comments

Comments
 (0)