@@ -6801,6 +6801,11 @@ class MappableExprsHandler {
6801
6801
llvm::OpenMPIRBuilder::MapNonContiguousArrayTy;
6802
6802
using MapExprsArrayTy = SmallVector<MappingExprInfo, 4>;
6803
6803
using MapValueDeclsArrayTy = SmallVector<const ValueDecl *, 4>;
6804
+ using MapData =
6805
+ std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
6806
+ OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>,
6807
+ bool /*IsImplicit*/, const ValueDecl *, const Expr *>;
6808
+ using MapDataArrayTy = SmallVector<MapData, 4>;
6804
6809
6805
6810
/// This structure contains combined information generated for mappable
6806
6811
/// clauses, including base pointers, pointers, sizes, map types, user-defined
@@ -8496,6 +8501,7 @@ class MappableExprsHandler {
8496
8501
const StructRangeInfoTy &PartialStruct, bool IsMapThis,
8497
8502
llvm::OpenMPIRBuilder &OMPBuilder,
8498
8503
const ValueDecl *VD = nullptr,
8504
+ unsigned OffsetForMemberOfFlag = 0,
8499
8505
bool NotTargetParams = true) const {
8500
8506
if (CurTypes.size() == 1 &&
8501
8507
((CurTypes.back() & OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
@@ -8583,8 +8589,8 @@ class MappableExprsHandler {
8583
8589
// All other current entries will be MEMBER_OF the combined entry
8584
8590
// (except for PTR_AND_OBJ entries which do not have a placeholder value
8585
8591
// 0xFFFF in the MEMBER_OF field).
8586
- OpenMPOffloadMappingFlags MemberOfFlag =
8587
- OMPBuilder.getMemberOfFlag( CombinedInfo.BasePointers.size() - 1);
8592
+ OpenMPOffloadMappingFlags MemberOfFlag = OMPBuilder.getMemberOfFlag(
8593
+ OffsetForMemberOfFlag + CombinedInfo.BasePointers.size() - 1);
8588
8594
for (auto &M : CurTypes)
8589
8595
OMPBuilder.setCorrectMemberOfFlag(M, MemberOfFlag);
8590
8596
}
@@ -8727,11 +8733,13 @@ class MappableExprsHandler {
8727
8733
}
8728
8734
}
8729
8735
8730
- /// Generate the base pointers, section pointers, sizes, map types, and
8731
- /// mappers associated to a given capture (all included in \a CombinedInfo).
8732
- void generateInfoForCapture(const CapturedStmt::Capture *Cap,
8733
- llvm::Value *Arg, MapCombinedInfoTy &CombinedInfo,
8734
- StructRangeInfoTy &PartialStruct) const {
8736
+ /// For a capture that has an associated clause, generate the base pointers,
8737
+ /// section pointers, sizes, map types, and mappers (all included in
8738
+ /// \a CurCaptureVarInfo).
8739
+ void generateInfoForCaptureFromClauseInfo(
8740
+ const CapturedStmt::Capture *Cap, llvm::Value *Arg,
8741
+ MapCombinedInfoTy &CurCaptureVarInfo, llvm::OpenMPIRBuilder &OMPBuilder,
8742
+ unsigned OffsetForMemberOfFlag) const {
8735
8743
assert(!Cap->capturesVariableArrayType() &&
8736
8744
"Not expecting to generate map info for a variable array type!");
8737
8745
@@ -8749,26 +8757,22 @@ class MappableExprsHandler {
8749
8757
// pass the pointer by value. If it is a reference to a declaration, we just
8750
8758
// pass its value.
8751
8759
if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) {
8752
- CombinedInfo .Exprs.push_back(VD);
8753
- CombinedInfo .BasePointers.emplace_back(Arg);
8754
- CombinedInfo .DevicePtrDecls.emplace_back(VD);
8755
- CombinedInfo .DevicePointers.emplace_back(DeviceInfoTy::Pointer);
8756
- CombinedInfo .Pointers.push_back(Arg);
8757
- CombinedInfo .Sizes.push_back(CGF.Builder.CreateIntCast(
8760
+ CurCaptureVarInfo .Exprs.push_back(VD);
8761
+ CurCaptureVarInfo .BasePointers.emplace_back(Arg);
8762
+ CurCaptureVarInfo .DevicePtrDecls.emplace_back(VD);
8763
+ CurCaptureVarInfo .DevicePointers.emplace_back(DeviceInfoTy::Pointer);
8764
+ CurCaptureVarInfo .Pointers.push_back(Arg);
8765
+ CurCaptureVarInfo .Sizes.push_back(CGF.Builder.CreateIntCast(
8758
8766
CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
8759
8767
/*isSigned=*/true));
8760
- CombinedInfo .Types.push_back(
8768
+ CurCaptureVarInfo .Types.push_back(
8761
8769
OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
8762
8770
OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM);
8763
- CombinedInfo .Mappers.push_back(nullptr);
8771
+ CurCaptureVarInfo .Mappers.push_back(nullptr);
8764
8772
return;
8765
8773
}
8766
8774
8767
- using MapData =
8768
- std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
8769
- OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>, bool,
8770
- const ValueDecl *, const Expr *>;
8771
- SmallVector<MapData, 4> DeclComponentLists;
8775
+ MapDataArrayTy DeclComponentLists;
8772
8776
// For member fields list in is_device_ptr, store it in
8773
8777
// DeclComponentLists for generating components info.
8774
8778
static const OpenMPMapModifierKind Unknown = OMPC_MAP_MODIFIER_unknown;
@@ -8826,6 +8830,51 @@ class MappableExprsHandler {
8826
8830
return (HasPresent && !HasPresentR) || (HasAllocs && !HasAllocsR);
8827
8831
});
8828
8832
8833
+ auto GenerateInfoForComponentLists =
8834
+ [&](ArrayRef<MapData> DeclComponentLists,
8835
+ bool IsEligibleForTargetParamFlag) {
8836
+ MapCombinedInfoTy CurInfoForComponentLists;
8837
+ StructRangeInfoTy PartialStruct;
8838
+
8839
+ if (DeclComponentLists.empty())
8840
+ return;
8841
+
8842
+ generateInfoForCaptureFromComponentLists(
8843
+ VD, DeclComponentLists, CurInfoForComponentLists, PartialStruct,
8844
+ IsEligibleForTargetParamFlag,
8845
+ /*AreBothBasePtrAndPteeMapped=*/HasMapBasePtr && HasMapArraySec);
8846
+
8847
+ // If there is an entry in PartialStruct it means we have a
8848
+ // struct with individual members mapped. Emit an extra combined
8849
+ // entry.
8850
+ if (PartialStruct.Base.isValid()) {
8851
+ CurCaptureVarInfo.append(PartialStruct.PreliminaryMapData);
8852
+ emitCombinedEntry(
8853
+ CurCaptureVarInfo, CurInfoForComponentLists.Types,
8854
+ PartialStruct, Cap->capturesThis(), OMPBuilder, nullptr,
8855
+ OffsetForMemberOfFlag,
8856
+ /*NotTargetParams*/ !IsEligibleForTargetParamFlag);
8857
+ }
8858
+
8859
+ // Return if we didn't add any entries.
8860
+ if (CurInfoForComponentLists.BasePointers.empty())
8861
+ return;
8862
+
8863
+ CurCaptureVarInfo.append(CurInfoForComponentLists);
8864
+ };
8865
+
8866
+ GenerateInfoForComponentLists(DeclComponentLists,
8867
+ /*IsEligibleForTargetParamFlag=*/true);
8868
+ }
8869
+
8870
+ /// Generate the base pointers, section pointers, sizes, map types, and
8871
+ /// mappers associated to \a DeclComponentLists for a given capture
8872
+ /// \a VD (all included in \a CurComponentListInfo).
8873
+ void generateInfoForCaptureFromComponentLists(
8874
+ const ValueDecl *VD, ArrayRef<MapData> DeclComponentLists,
8875
+ MapCombinedInfoTy &CurComponentListInfo, StructRangeInfoTy &PartialStruct,
8876
+ bool IsListEligibleForTargetParamFlag,
8877
+ bool AreBothBasePtrAndPteeMapped = false) const {
8829
8878
// Find overlapping elements (including the offset from the base element).
8830
8879
llvm::SmallDenseMap<
8831
8880
const MapData *,
@@ -8949,7 +8998,7 @@ class MappableExprsHandler {
8949
8998
8950
8999
// Associated with a capture, because the mapping flags depend on it.
8951
9000
// Go through all of the elements with the overlapped elements.
8952
- bool IsFirstComponentList = true ;
9001
+ bool AddTargetParamFlag = IsListEligibleForTargetParamFlag ;
8953
9002
MapCombinedInfoTy StructBaseCombinedInfo;
8954
9003
for (const auto &Pair : OverlappedData) {
8955
9004
const MapData &L = *Pair.getFirst();
@@ -8964,11 +9013,11 @@ class MappableExprsHandler {
8964
9013
ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
8965
9014
OverlappedComponents = Pair.getSecond();
8966
9015
generateInfoForComponentList(
8967
- MapType, MapModifiers, {}, Components, CombinedInfo ,
8968
- StructBaseCombinedInfo, PartialStruct, IsFirstComponentList ,
8969
- IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
9016
+ MapType, MapModifiers, {}, Components, CurComponentListInfo ,
9017
+ StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag, IsImplicit ,
9018
+ /*GenerateAllInfoForClauses*/ false, Mapper,
8970
9019
/*ForDeviceAddr=*/false, VD, VarRef, OverlappedComponents);
8971
- IsFirstComponentList = false;
9020
+ AddTargetParamFlag = false;
8972
9021
}
8973
9022
// Go through other elements without overlapped elements.
8974
9023
for (const MapData &L : DeclComponentLists) {
@@ -8983,12 +9032,12 @@ class MappableExprsHandler {
8983
9032
auto It = OverlappedData.find(&L);
8984
9033
if (It == OverlappedData.end())
8985
9034
generateInfoForComponentList(
8986
- MapType, MapModifiers, {}, Components, CombinedInfo ,
8987
- StructBaseCombinedInfo, PartialStruct, IsFirstComponentList ,
9035
+ MapType, MapModifiers, {}, Components, CurComponentListInfo ,
9036
+ StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag ,
8988
9037
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
8989
9038
/*ForDeviceAddr=*/false, VD, VarRef,
8990
- /*OverlappedElements*/ {}, HasMapBasePtr && HasMapArraySec );
8991
- IsFirstComponentList = false;
9039
+ /*OverlappedElements*/ {}, AreBothBasePtrAndPteeMapped );
9040
+ AddTargetParamFlag = false;
8992
9041
}
8993
9042
}
8994
9043
@@ -9467,7 +9516,6 @@ static void genMapInfoForCaptures(
9467
9516
CE = CS.capture_end();
9468
9517
CI != CE; ++CI, ++RI, ++CV) {
9469
9518
MappableExprsHandler::MapCombinedInfoTy CurInfo;
9470
- MappableExprsHandler::StructRangeInfoTy PartialStruct;
9471
9519
9472
9520
// VLA sizes are passed to the outlined region by copy and do not have map
9473
9521
// information associated.
@@ -9488,37 +9536,33 @@ static void genMapInfoForCaptures(
9488
9536
} else {
9489
9537
// If we have any information in the map clause, we use it, otherwise we
9490
9538
// just do a default mapping.
9491
- MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
9539
+ MEHandler.generateInfoForCaptureFromClauseInfo(
9540
+ CI, *CV, CurInfo, OMPBuilder,
9541
+ /*OffsetForMemberOfFlag=*/CombinedInfo.BasePointers.size());
9542
+
9492
9543
if (!CI->capturesThis())
9493
9544
MappedVarSet.insert(CI->getCapturedVar());
9494
9545
else
9495
9546
MappedVarSet.insert(nullptr);
9496
- if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())
9547
+
9548
+ if (CurInfo.BasePointers.empty())
9497
9549
MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);
9550
+
9498
9551
// Generate correct mapping for variables captured by reference in
9499
9552
// lambdas.
9500
9553
if (CI->capturesVariable())
9501
9554
MEHandler.generateInfoForLambdaCaptures(CI->getCapturedVar(), *CV,
9502
9555
CurInfo, LambdaPointers);
9503
9556
}
9504
9557
// We expect to have at least an element of information for this capture.
9505
- assert(( !CurInfo.BasePointers.empty() || PartialStruct.Base.isValid() ) &&
9558
+ assert(!CurInfo.BasePointers.empty() &&
9506
9559
"Non-existing map pointer for capture!");
9507
9560
assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
9508
9561
CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
9509
9562
CurInfo.BasePointers.size() == CurInfo.Types.size() &&
9510
9563
CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
9511
9564
"Inconsistent map information sizes!");
9512
9565
9513
- // If there is an entry in PartialStruct it means we have a struct with
9514
- // individual members mapped. Emit an extra combined entry.
9515
- if (PartialStruct.Base.isValid()) {
9516
- CombinedInfo.append(PartialStruct.PreliminaryMapData);
9517
- MEHandler.emitCombinedEntry(CombinedInfo, CurInfo.Types, PartialStruct,
9518
- CI->capturesThis(), OMPBuilder, nullptr,
9519
- /*NotTargetParams*/ false);
9520
- }
9521
-
9522
9566
// We need to append the results of this capture to what we already have.
9523
9567
CombinedInfo.append(CurInfo);
9524
9568
}
0 commit comments