Skip to content

Commit 2f7d361

Browse files
authored
[SYCLomatic] Support migration for 2 cub API (LoadDirectWarpStriped, StoreDirectWarpStriped ) and improve migration for 2 API (BlockLoad.Load/Store templated with 3 extra algorithms) (#2690)
Introduce new help function. Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent 710b2c2 commit 2f7d361

File tree

9 files changed

+655
-48
lines changed

9 files changed

+655
-48
lines changed

clang/lib/DPCT/RulesLang/RewriterSYCLcompat.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ SYCLCOMPAT_UNSUPPORT("cub::LoadDirectBlocked")
7272
SYCLCOMPAT_UNSUPPORT("cub::LoadDirectStriped")
7373
SYCLCOMPAT_UNSUPPORT("cub::StoreDirectBlocked")
7474
SYCLCOMPAT_UNSUPPORT("cub::StoreDirectStriped")
75+
SYCLCOMPAT_UNSUPPORT("cub::LoadDirectWarpStriped")
76+
SYCLCOMPAT_UNSUPPORT("cub::StoreDirectWarpStriped")
7577
SYCLCOMPAT_UNSUPPORT("cub::ShuffleDown")
7678
SYCLCOMPAT_UNSUPPORT("cub::ShuffleUp")
7779
SYCLCOMPAT_UNSUPPORT("cuPointerGetAttributes")

clang/lib/DPCT/RulesLangLib/CUB/RewriterClassMethods.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -297,26 +297,30 @@ RewriterMap dpct::createClassMethodsRewriterMap() {
297297
HEADER_INSERT_FACTORY(
298298
HeaderType::HT_DPCT_GROUP_Utils,
299299
CASE_FACTORY_ENTRY(
300-
CASE(makeCheckAnd(CheckArgCount(2), CheckCUBEnumTemplateArg(3)),
300+
CASE(CheckArgCount(2),
301301
MEMBER_CALL_FACTORY_ENTRY("cub::BlockLoad.Load",
302302
MemberExprBase(), false, "load",
303303
NDITEM, ARG(0), ARG(1))),
304-
CASE(makeCheckAnd(CheckArgCount(3), CheckCUBEnumTemplateArg(3)),
304+
CASE(CheckArgCount(3),
305305
MEMBER_CALL_FACTORY_ENTRY("cub::BlockLoad.Load",
306306
MemberExprBase(), false, "load",
307307
NDITEM, ARG(0), ARG(1), ARG(2))),
308+
CASE(CheckArgCount(4),
309+
MEMBER_CALL_FACTORY_ENTRY(
310+
"cub::BlockLoad.Load", MemberExprBase(), false, "load",
311+
NDITEM, ARG(0), ARG(1), ARG(2), ARG(3))),
308312
OTHERWISE(UNSUPPORT_FACTORY_ENTRY("cub::BlockLoad.Load",
309313
Diagnostics::API_NOT_MIGRATED,
310314
printCallExprPretty()))))
311315
// cub::BlockStore.Store
312316
HEADER_INSERT_FACTORY(
313317
HeaderType::HT_DPCT_GROUP_Utils,
314318
CASE_FACTORY_ENTRY(
315-
CASE(makeCheckAnd(CheckArgCount(2), CheckCUBEnumTemplateArg(3)),
319+
CASE(CheckArgCount(2),
316320
MEMBER_CALL_FACTORY_ENTRY("cub::BlockStore.Store",
317321
MemberExprBase(), false, "store",
318322
NDITEM, ARG(0), ARG(1))),
319-
CASE(makeCheckAnd(CheckArgCount(3), CheckCUBEnumTemplateArg(3)),
323+
CASE(CheckArgCount(3),
320324
MEMBER_CALL_FACTORY_ENTRY("cub::BlockStore.Store",
321325
MemberExprBase(), false, "store",
322326
NDITEM, ARG(0), ARG(1), ARG(2))),

clang/lib/DPCT/RulesLangLib/CUB/RewriterUtilityFunctions.cpp

Lines changed: 158 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -162,41 +162,167 @@ RewriterMap dpct::createUtilityFunctionsRewriterMap() {
162162
MEMBER_CALL_FACTORY_ENTRY("cub::RowMajorTid", NDITEM, /*IsArrow=*/false,
163163
"get_local_linear_id")
164164
// cub::LoadDirectBlocked
165-
HEADER_INSERT_FACTORY(
166-
HeaderType::HT_DPCT_GROUP_Utils,
167-
CALL_FACTORY_ENTRY(
168-
"cub::LoadDirectBlocked",
169-
CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() +
170-
"group::load_direct_blocked",
171-
0, 1, 2),
172-
NDITEM, ARG(1), ARG(2))))
165+
CASE_FACTORY_ENTRY(
166+
CASE(CheckArgCount(3),
167+
HEADER_INSERT_FACTORY(
168+
HeaderType::HT_DPCT_GROUP_Utils,
169+
CALL_FACTORY_ENTRY("cub::LoadDirectBlocked",
170+
CALL(PRETTY_TEMPLATED_CALLEE(
171+
MapNames::getDpctNamespace() +
172+
"group::load_direct_blocked",
173+
0, 1, 2),
174+
NDITEM, ARG(1), ARG(2))))),
175+
CASE(CheckArgCount(4),
176+
HEADER_INSERT_FACTORY(
177+
HeaderType::HT_DPCT_GROUP_Utils,
178+
CALL_FACTORY_ENTRY("cub::LoadDirectBlocked",
179+
CALL(PRETTY_TEMPLATED_CALLEE(
180+
MapNames::getDpctNamespace() +
181+
"group::load_direct_blocked",
182+
0, 1, 2),
183+
NDITEM, ARG(1), ARG(2), ARG(3))))),
184+
CASE(CheckArgCount(5),
185+
HEADER_INSERT_FACTORY(
186+
HeaderType::HT_DPCT_GROUP_Utils,
187+
CALL_FACTORY_ENTRY("cub::LoadDirectBlocked",
188+
CALL(PRETTY_TEMPLATED_CALLEE(
189+
MapNames::getDpctNamespace() +
190+
"group::load_direct_blocked",
191+
0, 1, 2, 3),
192+
NDITEM, ARG(1), ARG(2), ARG(3),
193+
ARG(4))))))
194+
173195
// cub::LoadDirectStriped
174-
HEADER_INSERT_FACTORY(
175-
HeaderType::HT_DPCT_GROUP_Utils,
176-
CALL_FACTORY_ENTRY(
177-
"cub::LoadDirectStriped",
178-
CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() +
179-
"group::load_direct_striped",
180-
1, 2, 3),
181-
NDITEM, ARG(1), ARG(2))))
196+
CASE_FACTORY_ENTRY(
197+
CASE(CheckArgCount(3),
198+
HEADER_INSERT_FACTORY(
199+
HeaderType::HT_DPCT_GROUP_Utils,
200+
CALL_FACTORY_ENTRY("cub::LoadDirectStriped",
201+
CALL(PRETTY_TEMPLATED_CALLEE(
202+
MapNames::getDpctNamespace() +
203+
"group::load_direct_striped",
204+
1, 2, 3),
205+
NDITEM, ARG(1), ARG(2))))),
206+
CASE(CheckArgCount(4),
207+
HEADER_INSERT_FACTORY(
208+
HeaderType::HT_DPCT_GROUP_Utils,
209+
CALL_FACTORY_ENTRY("cub::LoadDirectStriped",
210+
CALL(PRETTY_TEMPLATED_CALLEE(
211+
MapNames::getDpctNamespace() +
212+
"group::load_direct_striped",
213+
1, 2, 3),
214+
NDITEM, ARG(1), ARG(2), ARG(3))))),
215+
CASE(CheckArgCount(5),
216+
HEADER_INSERT_FACTORY(
217+
HeaderType::HT_DPCT_GROUP_Utils,
218+
CALL_FACTORY_ENTRY("cub::LoadDirectStriped",
219+
CALL(PRETTY_TEMPLATED_CALLEE(
220+
MapNames::getDpctNamespace() +
221+
"group::load_direct_striped",
222+
1, 2, 3, 4),
223+
NDITEM, ARG(1), ARG(2), ARG(3),
224+
ARG(4))))))
225+
// cub::LoadDirectWarpStriped
226+
CASE_FACTORY_ENTRY(
227+
CASE(CheckArgCount(3),
228+
HEADER_INSERT_FACTORY(
229+
HeaderType::HT_DPCT_GROUP_Utils,
230+
CALL_FACTORY_ENTRY(
231+
"cub::LoadDirectWarpStriped",
232+
CALL(PRETTY_TEMPLATED_CALLEE(
233+
MapNames::getDpctNamespace() +
234+
"group::load_direct_sub_group_striped",
235+
0, 1, 2),
236+
NDITEM, ARG(1), ARG(2))))),
237+
CASE(CheckArgCount(4),
238+
HEADER_INSERT_FACTORY(
239+
HeaderType::HT_DPCT_GROUP_Utils,
240+
CALL_FACTORY_ENTRY(
241+
"cub::LoadDirectWarpStriped",
242+
CALL(PRETTY_TEMPLATED_CALLEE(
243+
MapNames::getDpctNamespace() +
244+
"group::load_direct_sub_group_striped",
245+
0, 1, 2),
246+
NDITEM, ARG(1), ARG(2), ARG(3))))),
247+
CASE(CheckArgCount(5),
248+
HEADER_INSERT_FACTORY(
249+
HeaderType::HT_DPCT_GROUP_Utils,
250+
CALL_FACTORY_ENTRY(
251+
"cub::LoadDirectWarpStriped",
252+
CALL(PRETTY_TEMPLATED_CALLEE(
253+
MapNames::getDpctNamespace() +
254+
"group::load_direct_sub_group_striped",
255+
0, 1, 2, 3),
256+
NDITEM, ARG(1), ARG(2), ARG(3), ARG(4))))))
257+
182258
// cub::StoreDirectBlocked
183-
HEADER_INSERT_FACTORY(
184-
HeaderType::HT_DPCT_GROUP_Utils,
185-
CALL_FACTORY_ENTRY(
186-
"cub::StoreDirectBlocked",
187-
CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() +
188-
"group::store_direct_blocked",
189-
0, 1, 2),
190-
NDITEM, ARG(1), ARG(2))))
259+
CASE_FACTORY_ENTRY(
260+
CASE(
261+
CheckArgCount(3),
262+
HEADER_INSERT_FACTORY(
263+
HeaderType::HT_DPCT_GROUP_Utils,
264+
CALL_FACTORY_ENTRY("cub::StoreDirectBlocked",
265+
CALL(PRETTY_TEMPLATED_CALLEE(
266+
MapNames::getDpctNamespace() +
267+
"group::store_direct_blocked",
268+
0, 1, 2),
269+
NDITEM, ARG(1), ARG(2))))),
270+
CASE(
271+
CheckArgCount(4),
272+
HEADER_INSERT_FACTORY(
273+
HeaderType::HT_DPCT_GROUP_Utils,
274+
CALL_FACTORY_ENTRY("cub::StoreDirectBlocked",
275+
CALL(PRETTY_TEMPLATED_CALLEE(
276+
MapNames::getDpctNamespace() +
277+
"group::store_direct_blocked",
278+
0, 1, 2),
279+
NDITEM, ARG(1), ARG(2), ARG(3))))))
280+
191281
// cub::StoreDirectStriped
192-
HEADER_INSERT_FACTORY(
193-
HeaderType::HT_DPCT_GROUP_Utils,
194-
CALL_FACTORY_ENTRY(
195-
"cub::StoreDirectStriped",
196-
CALL(PRETTY_TEMPLATED_CALLEE(MapNames::getDpctNamespace() +
197-
"group::store_direct_striped",
198-
1, 2, 3),
199-
NDITEM, ARG(1), ARG(2))))
282+
CASE_FACTORY_ENTRY(
283+
CASE(
284+
CheckArgCount(3),
285+
HEADER_INSERT_FACTORY(
286+
HeaderType::HT_DPCT_GROUP_Utils,
287+
CALL_FACTORY_ENTRY("cub::StoreDirectStriped",
288+
CALL(PRETTY_TEMPLATED_CALLEE(
289+
MapNames::getDpctNamespace() +
290+
"group::store_direct_striped",
291+
1, 2, 3),
292+
NDITEM, ARG(1), ARG(2))))),
293+
CASE(
294+
CheckArgCount(4),
295+
HEADER_INSERT_FACTORY(
296+
HeaderType::HT_DPCT_GROUP_Utils,
297+
CALL_FACTORY_ENTRY("cub::StoreDirectStriped",
298+
CALL(PRETTY_TEMPLATED_CALLEE(
299+
MapNames::getDpctNamespace() +
300+
"group::store_direct_striped",
301+
1, 2, 3),
302+
NDITEM, ARG(1), ARG(2), ARG(3))))))
303+
// cub::StoreDirectWarpStriped
304+
CASE_FACTORY_ENTRY(
305+
CASE(CheckArgCount(3),
306+
HEADER_INSERT_FACTORY(
307+
HeaderType::HT_DPCT_GROUP_Utils,
308+
CALL_FACTORY_ENTRY(
309+
"cub::StoreDirectWarpStriped",
310+
CALL(PRETTY_TEMPLATED_CALLEE(
311+
MapNames::getDpctNamespace() +
312+
"group::store_direct_sub_group_striped",
313+
0, 1, 2),
314+
NDITEM, ARG(1), ARG(2))))),
315+
CASE(CheckArgCount(4),
316+
HEADER_INSERT_FACTORY(
317+
HeaderType::HT_DPCT_GROUP_Utils,
318+
CALL_FACTORY_ENTRY(
319+
"cub::StoreDirectWarpStriped",
320+
CALL(PRETTY_TEMPLATED_CALLEE(
321+
MapNames::getDpctNamespace() +
322+
"group::store_direct_sub_group_striped",
323+
0, 1, 2),
324+
NDITEM, ARG(1), ARG(2), ARG(3))))))
325+
200326
// cub::ShuffleDown
201327
SUBGROUPSIZE_FACTORY(
202328
UINT_MAX,

clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,8 @@ void CubIntrinsicRule::registerMatcher(ast_matchers::MatchFinder &MF) {
332332
"SmVersionUncached", "RowMajorTid",
333333
"LoadDirectBlocked", "LoadDirectStriped",
334334
"StoreDirectBlocked", "StoreDirectStriped",
335-
"ShuffleDown", "ShuffleUp", "Debug"),
335+
"ShuffleDown", "ShuffleUp", "Debug",
336+
"LoadDirectWarpStriped", "StoreDirectWarpStriped"),
336337
hasAncestor(namespaceDecl(hasName("cub")))))))
337338
.bind("IntrinsicCall"),
338339
this);

clang/lib/DPCT/RulesLangLib/MapNamesLangLib.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,14 @@ void MapNamesLangLib::setExplicitNamespaceMap(
2929
CUBEnumsMap = {
3030
{"BLOCK_STORE_DIRECT", MapNames::getDpctNamespace() + "group::group_store_algorithm::blocked"},
3131
{"BLOCK_STORE_STRIPED", MapNames::getDpctNamespace() + "group::group_store_algorithm::striped"},
32+
{"BLOCK_STORE_VECTORIZE", MapNames::getDpctNamespace() + "group::group_store_algorithm::blocked"},
33+
{"BLOCK_STORE_TRANSPOSE", MapNames::getDpctNamespace() + "group::group_store_algorithm::transpose"},
34+
{"BLOCK_STORE_WARP_TRANSPOSE", MapNames::getDpctNamespace() + "group::group_store_algorithm::sub_group_transpose"},
3235
{"BLOCK_LOAD_DIRECT", MapNames::getDpctNamespace() + "group::group_load_algorithm::blocked"},
33-
{"BLOCK_LOAD_STRIPED", MapNames::getDpctNamespace() + "group::group_load_algorithm::striped"}
36+
{"BLOCK_LOAD_STRIPED", MapNames::getDpctNamespace() + "group::group_load_algorithm::striped"},
37+
{"BLOCK_LOAD_VECTORIZE", MapNames::getDpctNamespace() + "group::group_load_algorithm::blocked"},
38+
{"BLOCK_LOAD_TRANSPOSE", MapNames::getDpctNamespace() + "group::group_load_algorithm::transpose"},
39+
{"BLOCK_LOAD_WARP_TRANSPOSE", MapNames::getDpctNamespace() + "group::group_load_algorithm::sub_group_transpose"},
3440
};
3541
// clang-format on
3642

clang/lib/DPCT/SrcAPI/APINames_CUB.inc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ ENTRY(cub::StoreDirectBlocked, cub::StoreDirectBlocked, true, NO_FLAG, P4, "Succ
234234
ENTRY(cub::StoreDirectBlockedVectorized, cub::StoreDirectBlockedVectorized, false, NO_FLAG, P4, "Comment")
235235
ENTRY(cub::LoadDirectStriped, cub::LoadDirectStriped, true, NO_FLAG, P4, "Successful")
236236
ENTRY(cub::StoreDirectStriped, cub::StoreDirectStriped, true, NO_FLAG, P4, "Successful")
237-
ENTRY(cub::LoadDirectWarpStriped, cub::LoadDirectWarpStriped, false, NO_FLAG, P4, "Comment")
238-
ENTRY(cub::StoreDirectWarpStriped, cub::StoreDirectWarpStriped, false, NO_FLAG, P4, "Comment")
237+
ENTRY(cub::LoadDirectWarpStriped, cub::LoadDirectWarpStriped, true, NO_FLAG, P4, "Comment")
238+
ENTRY(cub::StoreDirectWarpStriped, cub::StoreDirectWarpStriped, true, NO_FLAG, P4, "Comment")
239239

240240
// PTX intrinsics
241241
ENTRY(cub::SHR_ADD, cub::SHR_ADD, true, NO_FLAG, P4, "Successful")

0 commit comments

Comments
 (0)