Skip to content

Commit 2d33dba

Browse files
authored
[SYCLomatic] Enable the migration of 4 CUB APIs, BlockShuffle::Offset/Rotate/Up/Down, with the newly introduced helper class group_shuffle (#2613)
Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent 3198195 commit 2d33dba

File tree

10 files changed

+350
-19
lines changed

10 files changed

+350
-19
lines changed

clang/lib/DPCT/AnalysisInfo.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3812,6 +3812,7 @@ void TempStorageVarInfo::addAccessorDecl(StmtList &AccessorList,
38123812
OS << '(' << LocalSize << ".size() * sizeof("
38133813
<< ValueType->getSourceString() << ')' << ')';
38143814
break;
3815+
case BlockShuffle:
38153816
case BlockRadixSort:
38163817
OS << MapNames::getClNamespace() << "local_accessor<uint8_t, 1> " << Name
38173818
<< "_acc(";
@@ -3831,6 +3832,7 @@ ParameterStream &TempStorageVarInfo::getFuncDecl(ParameterStream &PS) {
38313832
case BlockReduce:
38323833
PS << MapNames::getClNamespace() << "local_accessor<std::byte, 1> ";
38333834
break;
3835+
case BlockShuffle:
38343836
case BlockRadixSort:
38353837
PS << "uint8_t *";
38363838
break;

clang/lib/DPCT/AnalysisInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,6 +2237,7 @@ class TempStorageVarInfo {
22372237
enum APIKind {
22382238
BlockReduce,
22392239
BlockRadixSort,
2240+
BlockShuffle,
22402241
};
22412242

22422243
private:

clang/lib/DPCT/RuleInfra/APINamesTemplateType.inc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,31 @@ TYPE_REWRITE_ENTRY(
131131
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
132132
"group::exchange"),
133133
TEMPLATE_ARG(0), TEMPLATE_ARG(2)))))
134+
135+
// cub::BlockShuffle
136+
TYPE_REWRITE_ENTRY(
137+
"cub::BlockShuffle",
138+
TYPE_CONDITIONAL_FACTORY(
139+
UseSYCLCompat(),
140+
WARNING_FACTORY(Diagnostics::UNSUPPORT_SYCLCOMPAT, TYPESTR),
141+
HEADER_INSERTION_FACTORY(
142+
HeaderType::HT_DPCT_GROUP_Utils,
143+
TYPE_CONDITIONAL_FACTORY(
144+
CheckTemplateArgCount(2, false),
145+
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
146+
"group::group_shuffle"),
147+
TEMPLATE_ARG(0), TEMPLATE_ARG(1)),
148+
TYPE_CONDITIONAL_FACTORY(
149+
CheckTemplateArgCount(3, false),
150+
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
151+
"group::group_shuffle"),
152+
TEMPLATE_ARG(0), TEMPLATE_ARG(1),
153+
TEMPLATE_ARG(2)),
154+
TYPE_FACTORY(STR(MapNames::getLibraryHelperNamespace() +
155+
"group::group_shuffle"),
156+
TEMPLATE_ARG(0), TEMPLATE_ARG(1),
157+
TEMPLATE_ARG(2), TEMPLATE_ARG(3)))))))
158+
134159
// cub::BlockLoad
135160
TYPE_REWRITE_ENTRY(
136161
"cub::BlockLoad",

clang/lib/DPCT/RuleInfra/TypeLocRewriters.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,24 @@ makeUserDefinedTypeStrCreator(MetaRuleObject &R,
102102

103103
class CheckTemplateArgCount {
104104
unsigned Count;
105+
bool IsIncludeDefault;
105106

106107
public:
107-
CheckTemplateArgCount(unsigned I) : Count(I) {}
108+
CheckTemplateArgCount(unsigned I, bool D = true)
109+
: Count(I), IsIncludeDefault(D) {}
108110
bool operator()(const TypeLoc TL) {
109-
if(auto TSTL = TL.getAs<TemplateSpecializationTypeLoc>()){
110-
return TSTL.getNumArgs() == Count;
111+
if (auto TSTL = TL.getAs<TemplateSpecializationTypeLoc>()) {
112+
size_t Num = TSTL.getNumArgs();
113+
if (IsIncludeDefault) {
114+
return Num == Count;
115+
}
116+
size_t NoneDefaultNum = 0;
117+
for (int i = 0; i < Num; i++) {
118+
if (!TSTL.getArgLoc(i).getArgument().getIsDefaulted()) {
119+
NoneDefaultNum++;
120+
}
121+
}
122+
return NoneDefaultNum == Count;
111123
}
112124
return false;
113125
}

clang/lib/DPCT/RulesLang/RewriterSYCLcompat.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.BlockedToStriped")
9494
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.StripedToBlocked")
9595
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.ScatterToBlocked")
9696
SYCLCOMPAT_UNSUPPORT("cub::BlockExchange.ScatterToStriped")
97+
SYCLCOMPAT_UNSUPPORT("cub::BlockShuffle.Offset")
98+
SYCLCOMPAT_UNSUPPORT("cub::BlockShuffle.Rotate")
99+
SYCLCOMPAT_UNSUPPORT("cub::BlockShuffle.Up")
100+
SYCLCOMPAT_UNSUPPORT("cub::BlockShuffle.Down")
97101
SYCLCOMPAT_UNSUPPORT("cub::BlockLoad.Load")
98102
SYCLCOMPAT_UNSUPPORT("cub::BlockStore.Store")
99103
});

clang/lib/DPCT/RulesLangLib/CUB/RewriterClassMethods.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,66 @@ RewriterMap dpct::createClassMethodsRewriterMap() {
211211
"cub::BlockExchange.ScatterToStriped",
212212
MemberExprBase(), false, "scatter_to_striped",
213213
NDITEM, ARG(0), ARG(1)))
214+
// cub::BlockShuffle.Offset
215+
HEADER_INSERT_FACTORY(
216+
HeaderType::HT_DPCT_GROUP_Utils,
217+
CASE_FACTORY_ENTRY(
218+
CASE(CheckArgCount(2, std::equal_to<>(), false),
219+
MEMBER_CALL_FACTORY_ENTRY("cub::BlockShuffle.Offset",
220+
MemberExprBase(), false, "select",
221+
NDITEM, ARG(0), ARG(1))),
222+
CASE(CheckArgCount(3, std::equal_to<>(), false),
223+
MEMBER_CALL_FACTORY_ENTRY("cub::BlockShuffle.Offset",
224+
MemberExprBase(), false, "select",
225+
NDITEM, ARG(0), ARG(1), ARG(2))),
226+
OTHERWISE(UNSUPPORT_FACTORY_ENTRY("cub::BlockShuffle.Offset",
227+
Diagnostics::API_NOT_MIGRATED,
228+
printCallExprPretty()))))
229+
// cub::BlockShuffle.Rotate
230+
HEADER_INSERT_FACTORY(
231+
HeaderType::HT_DPCT_GROUP_Utils,
232+
CASE_FACTORY_ENTRY(
233+
CASE(CheckArgCount(2, std::equal_to<>(), false),
234+
MEMBER_CALL_FACTORY_ENTRY("cub::BlockShuffle.Rotate",
235+
MemberExprBase(), false, "select2",
236+
NDITEM, ARG(0), ARG(1))),
237+
CASE(CheckArgCount(3, std::equal_to<>(), false),
238+
MEMBER_CALL_FACTORY_ENTRY("cub::BlockShuffle.Rotate",
239+
MemberExprBase(), false, "select2",
240+
NDITEM, ARG(0), ARG(1), ARG(2))),
241+
OTHERWISE(UNSUPPORT_FACTORY_ENTRY("cub::BlockShuffle.Rotate",
242+
Diagnostics::API_NOT_MIGRATED,
243+
printCallExprPretty()))))
244+
// cub::BlockShuffle.Up
245+
HEADER_INSERT_FACTORY(
246+
HeaderType::HT_DPCT_GROUP_Utils,
247+
CASE_FACTORY_ENTRY(
248+
CASE(CheckArgCount(2),
249+
MEMBER_CALL_FACTORY_ENTRY(
250+
"cub::BlockShuffle.Up", MemberExprBase(), false,
251+
"shuffle_right", NDITEM, ARG(0), ARG(1))),
252+
CASE(CheckArgCount(3),
253+
MEMBER_CALL_FACTORY_ENTRY(
254+
"cub::BlockShuffle.Up", MemberExprBase(), false,
255+
"shuffle_right", NDITEM, ARG(0), ARG(1), ARG(2))),
256+
OTHERWISE(UNSUPPORT_FACTORY_ENTRY("cub::BlockShuffle.Up",
257+
Diagnostics::API_NOT_MIGRATED,
258+
printCallExprPretty()))))
259+
// cub::BlockShuffle.Down
260+
HEADER_INSERT_FACTORY(
261+
HeaderType::HT_DPCT_GROUP_Utils,
262+
CASE_FACTORY_ENTRY(
263+
CASE(CheckArgCount(2),
264+
MEMBER_CALL_FACTORY_ENTRY(
265+
"cub::BlockShuffle.Down", MemberExprBase(), false,
266+
"shuffle_left", NDITEM, ARG(0), ARG(1))),
267+
CASE(CheckArgCount(3),
268+
MEMBER_CALL_FACTORY_ENTRY(
269+
"cub::BlockShuffle.Down", MemberExprBase(), false,
270+
"shuffle_left", NDITEM, ARG(0), ARG(1), ARG(2))),
271+
OTHERWISE(UNSUPPORT_FACTORY_ENTRY("cub::BlockShuffle.Down",
272+
Diagnostics::API_NOT_MIGRATED,
273+
printCallExprPretty()))))
214274
// cub::BlockLoad.Load
215275
HEADER_INSERT_FACTORY(
216276
HeaderType::HT_DPCT_GROUP_Utils,

clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ void CubTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) {
100100
"cub::ArgIndexInputIterator", "cub::DiscardOutputIterator",
101101
"cub::DoubleBuffer", "cub::NullType", "cub::ArgMax", "cub::ArgMin",
102102
"cub::BlockRadixSort", "cub::BlockExchange", "cub::BlockLoad",
103-
"cub::BlockStore");
103+
"cub::BlockStore", "cub::BlockShuffle");
104104
};
105105

106106
MF.addMatcher(
@@ -158,15 +158,16 @@ void CubDeviceLevelRule::runRule(
158158
void CubMemberCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
159159
MF.addMatcher(
160160
cxxMemberCallExpr(
161-
allOf(on(hasType(hasCanonicalType(qualType(hasDeclaration(namedDecl(
162-
hasAnyName("cub::ArgIndexInputIterator",
163-
"cub::BlockRadixSort", "cub::BlockExchange",
164-
"cub::BlockLoad", "cub::BlockStore"))))))),
161+
allOf(on(hasType(hasCanonicalType(
162+
qualType(hasDeclaration(namedDecl(hasAnyName(
163+
"cub::ArgIndexInputIterator", "cub::BlockRadixSort",
164+
"cub::BlockExchange", "cub::BlockLoad",
165+
"cub::BlockStore", "cub::BlockShuffle"))))))),
165166
callee(cxxMethodDecl(hasAnyName(
166167
"normalize", "Sort", "SortDescending", "BlockedToStriped",
167168
"StripedToBlocked", "ScatterToBlocked", "ScatterToStriped",
168169
"SortBlockedToStriped", "SortDescendingBlockedToStriped",
169-
"Load", "Store")))))
170+
"Load", "Store", "Offset", "Rotate", "Up", "Down")))))
170171
.bind("memberCall"),
171172
this);
172173

@@ -253,13 +254,17 @@ void CubMemberCallRule::runRule(
253254
Name == "BlockedToStriped" || Name == "StripedToBlocked" ||
254255
Name == "StripedToBlocked" || Name == "ScatterToBlocked" ||
255256
Name == "ScatterToStriped";
256-
if (isBlockRadixSort || isBlockExchange || Name == "Load" ||
257-
Name == "Store") {
257+
bool isBlockShuffle =
258+
Name == "Offset" || Name == "Rotate" || Name == "Up" || Name == "Down";
259+
if (isBlockRadixSort || isBlockExchange || isBlockShuffle ||
260+
Name == "Load" || Name == "Store") {
258261
std::string HelpFuncName;
259262
if (isBlockRadixSort)
260263
HelpFuncName = "group_radix_sort";
261264
else if (isBlockExchange)
262265
HelpFuncName = "exchange";
266+
else if (isBlockShuffle)
267+
HelpFuncName = "group_shuffle";
263268
else if (Name == "Load")
264269
HelpFuncName = "group_load";
265270
else if (Name == "Store")
@@ -273,20 +278,36 @@ void CubMemberCallRule::runRule(
273278
auto *ClassSpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(
274279
CanTy->getAs<RecordType>()->getDecl());
275280
const auto &ValueTyArg = ClassSpecDecl->getTemplateArgs()[0];
276-
const auto &ItemsPreThreadArg = ClassSpecDecl->getTemplateArgs()[2];
281+
277282
ValueTyArg.getAsType().getAsString();
278283
std::string Fn;
279284
llvm::raw_string_ostream OS(Fn);
280285
OS << MapNames::getDpctNamespace() << "group::" << HelpFuncName << "<"
281-
<< ValueTyArg.getAsType().getAsString() << ", "
282-
<< ItemsPreThreadArg.getAsIntegral() << ">::get_local_memory_size";
286+
<< ValueTyArg.getAsType().getAsString();
287+
if (isBlockShuffle) {
288+
if (!ClassSpecDecl->getTemplateArgs()[1].getIsDefaulted()) {
289+
OS << ", " << ClassSpecDecl->getTemplateArgs()[1].getAsIntegral();
290+
}
291+
if (!ClassSpecDecl->getTemplateArgs()[2].getIsDefaulted()) {
292+
OS << ", " << ClassSpecDecl->getTemplateArgs()[2].getAsIntegral();
293+
}
294+
if (!ClassSpecDecl->getTemplateArgs()[3].getIsDefaulted()) {
295+
OS << ", " << ClassSpecDecl->getTemplateArgs()[3].getAsIntegral();
296+
}
297+
} else {
298+
const auto &ItemsPreThreadArg = ClassSpecDecl->getTemplateArgs()[2];
299+
OS << ", " << ItemsPreThreadArg.getAsIntegral();
300+
}
301+
OS << ">::get_local_memory_size";
283302
if (auto FuncInfo = DeviceFunctionDecl::LinkRedecls(FD)) {
284303
auto LocInfo = DpctGlobalInfo::getLocInfo(TempStorage);
285304
ExprAnalysis EA;
286305
EA.analyze(DataTypeLoc);
287306
FuncInfo->getVarMap().addCUBTempStorage(
288307
std::make_shared<TempStorageVarInfo>(
289-
LocInfo.second, TempStorageVarInfo::BlockRadixSort,
308+
LocInfo.second,
309+
isBlockShuffle ? TempStorageVarInfo::BlockShuffle
310+
: TempStorageVarInfo::BlockRadixSort,
290311
TempStorage->getName(), Fn,
291312
EA.getTemplateDependentStringInfo()));
292313
}

clang/lib/DPCT/SrcAPI/APINames_CUB.inc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ ENTRY_MEMBER_FUNCTION(cub::BlockScan, cub::BlockScan, ExclusiveSum, ExclusiveSum
119119
ENTRY_MEMBER_FUNCTION(cub::BlockScan, cub::BlockScan, ExclusiveScan, ExclusiveScan, true, NO_FLAG, P4, "Successful")
120120
ENTRY_MEMBER_FUNCTION(cub::BlockScan, cub::BlockScan, InclusiveSum, InclusiveSum, true, NO_FLAG, P4, "Successful")
121121
ENTRY_MEMBER_FUNCTION(cub::BlockScan, cub::BlockScan, InclusiveScan, InclusiveScan, true, NO_FLAG, P4, "Successful")
122-
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Offset, Offset, false, NO_FLAG, P4, "Comment")
123-
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Rotate, Rotate, false, NO_FLAG, P4, "Comment")
124-
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Up, Up, false, NO_FLAG, P4, "Comment")
125-
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Down, Down, false, NO_FLAG, P4, "Comment")
122+
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Offset, Offset, true, NO_FLAG, P4, "Comment")
123+
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Rotate, Rotate, true, NO_FLAG, P4, "Comment")
124+
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Up, Up, true, NO_FLAG, P4, "Comment")
125+
ENTRY_MEMBER_FUNCTION(cub::BlockShuffle, cub::BlockShuffle, Down, Down, true, NO_FLAG, P4, "Comment")
126126
ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagHeads, FlagHeads, false, NO_FLAG, P4, "Comment")
127127
ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagTails, FlagTails, false, NO_FLAG, P4, "Comment")
128128
ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagHeadsAndTails, FlagHeadsAndTails, false, NO_FLAG, P4, "Comment")

0 commit comments

Comments
 (0)