Skip to content

Commit c97f93a

Browse files
authored
[SYCLomatic] Always add "template" before sycl::vec::convert for 4 APIs (#2817)
cuComplexDoubleToFloat cuComplexDoubleToFloat __float22half2_rn __half22float2 Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent d5dc0f1 commit c97f93a

File tree

9 files changed

+62
-119
lines changed

9 files changed

+62
-119
lines changed

clang/lib/DPCT/RuleInfra/CallExprRewriter.h

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,24 +1038,18 @@ template <class NameT> class TypeNamePrinter {
10381038
}
10391039
};
10401040

1041-
template <class BaseT, class MemberT, bool HasExplicitTemplateArg>
1041+
template <class BaseT, class MemberT, bool NeedDisambiguator>
10421042
class MemberExprPrinter {
10431043
BaseT Base;
10441044
bool IsArrow;
10451045
MemberT MemberName;
1046-
bool IsBaseDependentType = false;
10471046

10481047
public:
10491048
MemberExprPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName)
1050-
: Base(Base), IsArrow(IsArrow), MemberName(MemberName) {
1051-
if constexpr (std::is_same_v<BaseT, const Expr *>) {
1052-
IsBaseDependentType = Base->getType()->isDependentType();
1053-
}
1054-
}
1049+
: Base(Base), IsArrow(IsArrow), MemberName(MemberName) {}
10551050

10561051
template <class StreamT> void print(StreamT &Stream) const {
1057-
printBase(Stream, Base, IsArrow,
1058-
HasExplicitTemplateArg && IsBaseDependentType);
1052+
printBase(Stream, Base, IsArrow, NeedDisambiguator);
10591053
dpct::print(Stream, MemberName);
10601054
}
10611055
};
@@ -1074,19 +1068,17 @@ template <class BaseT, class MemberT> class StaticMemberExprPrinter {
10741068
}
10751069
};
10761070

1077-
template <class BaseT, class MemberT, bool HasExplicitTemplateArg,
1071+
template <class BaseT, class MemberT, bool NeedDisambiguator,
10781072
class... CallArgsT>
10791073
class MemberCallPrinter
10801074
: public CallExprPrinter<
1081-
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
1082-
CallArgsT...> {
1075+
MemberExprPrinter<BaseT, MemberT, NeedDisambiguator>, CallArgsT...> {
10831076
public:
10841077
MemberCallPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName,
10851078
CallArgsT &&...Args)
1086-
: CallExprPrinter<
1087-
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
1088-
CallArgsT...>(
1089-
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>(
1079+
: CallExprPrinter<MemberExprPrinter<BaseT, MemberT, NeedDisambiguator>,
1080+
CallArgsT...>(
1081+
MemberExprPrinter<BaseT, MemberT, NeedDisambiguator>(
10901082
std::move(Base), IsArrow, std::move(MemberName)),
10911083
std::forward<CallArgsT>(Args)...) {}
10921084
};
@@ -1451,25 +1443,25 @@ class MemberExprRewriter
14511443
C, Source, BaseCreator(C), IsArrow, MemberCreator(C)) {}
14521444
};
14531445

1454-
template <class BaseT, bool HasExplicitTemplateArg, class... ArgsT>
1446+
template <class BaseT, bool NeedDisambiguator, class... ArgsT>
14551447
class MemberCallExprRewriter
1456-
: public PrinterRewriter<MemberCallPrinter<
1457-
BaseT, StringRef, HasExplicitTemplateArg, ArgsT...>> {
1448+
: public PrinterRewriter<
1449+
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, ArgsT...>> {
14581450
public:
14591451
MemberCallExprRewriter(
14601452
const CallExpr *C, StringRef Source,
14611453
const std::function<BaseT(const CallExpr *)> &BaseCreator, bool IsArrow,
14621454
StringRef Member,
14631455
const std::function<ArgsT(const CallExpr *)> &...ArgsCreator)
1464-
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef,
1465-
HasExplicitTemplateArg, ArgsT...>>(
1456+
: PrinterRewriter<
1457+
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, ArgsT...>>(
14661458
C, Source, BaseCreator(C), IsArrow, Member, ArgsCreator(C)...) {}
14671459
MemberCallExprRewriter(
14681460
const CallExpr *C, StringRef Source, const BaseT &BaseCreator,
14691461
bool IsArrow, StringRef Member,
14701462
const std::function<ArgsT(const CallExpr *)> &...ArgsCreator)
1471-
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef,
1472-
HasExplicitTemplateArg, ArgsT...>>(
1463+
: PrinterRewriter<
1464+
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, ArgsT...>>(
14731465
C, Source, BaseCreator, IsArrow, Member, ArgsCreator(C)...) {}
14741466
};
14751467

clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -439,32 +439,32 @@ inline std::function<std::string(const CallExpr *)> makeDeviceStr() {
439439
};
440440
}
441441

442-
template <class BaseT, bool HasExplicitTemplateArg, class... CallArgsT>
442+
template <class BaseT, bool NeedDisambiguator, class... CallArgsT>
443443
using MemberCallPrinterCreator = PrinterCreator<
444-
MemberCallPrinter<BaseT, StringRef, HasExplicitTemplateArg, CallArgsT...>,
444+
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, CallArgsT...>,
445445
std::function<BaseT(const CallExpr *)>, bool, std::string,
446446
std::function<CallArgsT(const CallExpr *)>...>;
447447

448-
template <bool HasExplicitTemplateArg, class BaseT, class... CallArgsT>
449-
inline std::function<MemberCallPrinter<BaseT, StringRef, HasExplicitTemplateArg,
448+
template <bool NeedDisambiguator, class BaseT, class... CallArgsT>
449+
inline std::function<MemberCallPrinter<BaseT, StringRef, NeedDisambiguator,
450450
CallArgsT...>(const CallExpr *)>
451451
makeMemberCallCreator(std::function<BaseT(const CallExpr *)> BaseFunc,
452452
bool IsArrow, std::string Member,
453453
std::function<CallArgsT(const CallExpr *)>... Args) {
454-
return MemberCallPrinterCreator<BaseT, HasExplicitTemplateArg, CallArgsT...>(
454+
return MemberCallPrinterCreator<BaseT, NeedDisambiguator, CallArgsT...>(
455455
BaseFunc, IsArrow, Member, Args...);
456456
}
457457

458-
template <bool HasExplicitTemplateArg, class BaseT, class MemberT>
458+
template <bool NeedDisambiguator, class BaseT, class MemberT>
459459
inline std::function<
460-
MemberCallPrinter<BaseT, MemberT, HasExplicitTemplateArg>(const CallExpr *)>
460+
MemberCallPrinter<BaseT, MemberT, NeedDisambiguator>(const CallExpr *)>
461461
makeMemberCallCreator(std::function<BaseT(const CallExpr *)> BaseFunc,
462462
bool IsArrow,
463463
std::function<MemberT(const CallExpr *)> Member) {
464-
return PrinterCreator<
465-
MemberCallPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
466-
std::function<BaseT(const CallExpr *)>, bool,
467-
std::function<MemberT(const CallExpr *)>>(BaseFunc, IsArrow, Member);
464+
return PrinterCreator<MemberCallPrinter<BaseT, MemberT, NeedDisambiguator>,
465+
std::function<BaseT(const CallExpr *)>, bool,
466+
std::function<MemberT(const CallExpr *)>>(
467+
BaseFunc, IsArrow, Member);
468468
}
469469

470470
template <class... StmtT>
@@ -1344,15 +1344,15 @@ createTemplatedCallExprRewriterFactory(
13441344
/// \p BaseCreator use to get base expr from original call expr.
13451345
/// \p IsArrow the member operator is arrow or dot as default.
13461346
/// \p ArgsCreator use to get call args from original call expr.
1347-
template <bool HasExplicitTemplateArg, class BaseT, class... ArgsT>
1347+
template <bool NeedDisambiguator, class BaseT, class... ArgsT>
13481348
inline std::shared_ptr<CallExprRewriterFactoryBase>
13491349
createMemberCallExprRewriterFactory(
13501350
const std::string &SourceName,
13511351
std::function<BaseT(const CallExpr *)> BaseCreator, bool IsArrow,
13521352
std::string MemberName,
13531353
std::function<ArgsT(const CallExpr *)>... ArgsCreator) {
13541354
return std::make_shared<CallExprRewriterFactory<
1355-
MemberCallExprRewriter<BaseT, HasExplicitTemplateArg, ArgsT...>,
1355+
MemberCallExprRewriter<BaseT, NeedDisambiguator, ArgsT...>,
13561356
std::function<BaseT(const CallExpr *)>, bool, std::string,
13571357
std::function<ArgsT(const CallExpr *)>...>>(
13581358
SourceName,
@@ -1361,16 +1361,16 @@ createMemberCallExprRewriterFactory(
13611361
std::forward<std::function<ArgsT(const CallExpr *)>>(ArgsCreator)...);
13621362
}
13631363

1364-
template <bool HasExplicitTemplateArg, class BaseT, class... ArgsT>
1364+
template <bool NeedDisambiguator, class BaseT, class... ArgsT>
13651365
inline std::shared_ptr<std::enable_if_t<
13661366
!std::is_invocable_v<BaseT, const CallExpr *>, CallExprRewriterFactoryBase>>
13671367
createMemberCallExprRewriterFactory(
13681368
const std::string &SourceName, BaseT BaseCreator, bool IsArrow,
13691369
std::string MemberName,
13701370
std::function<ArgsT(const CallExpr *)>... ArgsCreator) {
13711371
return std::make_shared<CallExprRewriterFactory<
1372-
MemberCallExprRewriter<BaseT, HasExplicitTemplateArg, ArgsT...>, BaseT,
1373-
bool, std::string, std::function<ArgsT(const CallExpr *)>...>>(
1372+
MemberCallExprRewriter<BaseT, NeedDisambiguator, ArgsT...>, BaseT, bool,
1373+
std::string, std::function<ArgsT(const CallExpr *)>...>>(
13741374
SourceName, BaseCreator, IsArrow, MemberName,
13751375
std::forward<std::function<ArgsT(const CallExpr *)>>(ArgsCreator)...);
13761376
}
@@ -2252,7 +2252,7 @@ const std::string MipmapNeedBindlessImage =
22522252
#define MEMBER_CALL_FACTORY_ENTRY(FuncName, ...) \
22532253
std::make_pair(FuncName, createMemberCallExprRewriterFactory<false>( \
22542254
FuncName, __VA_ARGS__)),
2255-
#define MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(FuncName, ...) \
2255+
#define MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY(FuncName, ...) \
22562256
std::make_pair(FuncName, createMemberCallExprRewriterFactory<true>( \
22572257
FuncName, __VA_ARGS__)),
22582258
#define ARRAYSUBSCRIPT_EXPR_FACTORY_ENTRY(FuncName, ...) \

clang/lib/DPCT/RulesLang/APINamesComplex.inc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,7 @@ BINARY_OP_FACTORY_ENTRY("cuCfmaf", BinaryOperatorKind::BO_Add,
9595
makeCallArgCreatorWithCall(1)),
9696
makeCallArgCreatorWithCall(2))
9797

98-
MEMBER_CALL_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0), false, "convert<float>")
99-
MEMBER_CALL_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0), false, "convert<double>")
98+
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0),
99+
false, "convert<float>")
100+
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0),
101+
false, "convert<double>")

clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp

Lines changed: 17 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -288,75 +288,24 @@ std::optional<std::string> MathTypeCastRewriter::rewrite() {
288288
const StringRef &FuncName = SourceCalleeName;
289289
std::string ReplStr;
290290
llvm::raw_string_ostream OS(ReplStr);
291-
292291
auto MigratedArg0 = getMigratedArgWithExtraParens(0);
293-
if (FuncName == "__float22half2_rn") {
294-
OS << MigratedArg0
295-
<< ".convert<" + MapNames::getClNamespace() + "half, " +
296-
MapNames::getClNamespace() + "rounding_mode::rte>()";
297-
} else if (FuncName == "__float2half2_rn") {
298-
OS << MapNames::getClNamespace() + "float2{" << MigratedArg0 << ","
299-
<< MigratedArg0
300-
<< "}.convert<" + MapNames::getClNamespace() + "half, " +
301-
MapNames::getClNamespace() + "rounding_mode::rte>()";
302-
} else if (FuncName == "__floats2half2_rn") {
303-
auto MigratedArg1 = getMigratedArg(1);
304-
OS << MapNames::getClNamespace() + "float2{" << MigratedArg0 << ","
305-
<< MigratedArg1
306-
<< "}.convert<" + MapNames::getClNamespace() + "half, " +
307-
MapNames::getClNamespace() + "rounding_mode::rte>()";
308-
} else if (FuncName == "__half22float2") {
309-
OS << MigratedArg0
310-
<< ".convert<float, " + MapNames::getClNamespace() +
311-
"rounding_mode::automatic>()";
312-
} else if (FuncName == "__half2half2") {
313-
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << ","
314-
<< MigratedArg0 << "}";
315-
} else if (FuncName == "__halves2half2") {
316-
auto MigratedArg1 = getMigratedArg(1);
317-
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << ","
318-
<< MigratedArg1 << "}";
319-
} else if (FuncName == "__high2half") {
320-
OS << MigratedArg0 << "[0]";
321-
} else if (FuncName == "__high2half2") {
322-
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[0], "
323-
<< MigratedArg0 << "[0]}";
324-
} else if (FuncName == "__highs2half2") {
325-
auto MigratedArg1 = getMigratedArgWithExtraParens(1);
326-
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[0], "
327-
<< MigratedArg1 << "[0]}";
328-
} else if (FuncName == "__low2half") {
329-
OS << MigratedArg0 << "[1]";
330-
} else if (FuncName == "__low2half2") {
331-
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], "
332-
<< MigratedArg0 << "[1]}";
333-
} else if (FuncName == "__lowhigh2highlow") {
334-
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], "
335-
<< MigratedArg0 << "[0]}";
336-
} else if (FuncName == "__lows2half2") {
337-
auto MigratedArg1 = getMigratedArgWithExtraParens(1);
338-
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], "
339-
<< MigratedArg1 << "[1]}";
340-
} else {
341-
//__half2short_rd and __half2float
342-
static SSMap TypeMap{{"ll", "long long"},
343-
{"ull", "unsigned long long"},
344-
{"ushort", "unsigned short"},
345-
{"uint", "unsigned int"},
346-
{"half", MapNames::getClNamespace() + "half"}};
347-
std::string RoundingMode;
348-
if (FuncName[FuncName.size() - 3] == '_')
349-
RoundingMode = FuncName.substr(FuncName.size() - 2).str();
350-
auto FN = FuncName.substr(2, FuncName.find('_', 2) - 2).str();
351-
auto Types = split(FN, '2');
352-
assert(Types.size() == 2);
353-
MapNames::replaceName(TypeMap, Types[0]);
354-
MapNames::replaceName(TypeMap, Types[1]);
355-
OS << MapNames::getClNamespace() + "vec<" << Types[0] << ", 1>{"
356-
<< MigratedArg0 << "}.convert<" << Types[1]
357-
<< ", " + MapNames::getClNamespace() + "rounding_mode::"
358-
<< RoundingModeMap[RoundingMode] << ">()[0]";
359-
}
292+
static SSMap TypeMap{{"ll", "long long"},
293+
{"ull", "unsigned long long"},
294+
{"ushort", "unsigned short"},
295+
{"uint", "unsigned int"},
296+
{"half", MapNames::getClNamespace() + "half"}};
297+
std::string RoundingMode;
298+
if (FuncName[FuncName.size() - 3] == '_')
299+
RoundingMode = FuncName.substr(FuncName.size() - 2).str();
300+
auto FN = FuncName.substr(2, FuncName.find('_', 2) - 2).str();
301+
auto Types = split(FN, '2');
302+
assert(Types.size() == 2);
303+
MapNames::replaceName(TypeMap, Types[0]);
304+
MapNames::replaceName(TypeMap, Types[1]);
305+
OS << MapNames::getClNamespace() + "vec<" << Types[0] << ", 1>{"
306+
<< MigratedArg0 << "}.convert<" << Types[1]
307+
<< ", " + MapNames::getClNamespace() + "rounding_mode::"
308+
<< RoundingModeMap[RoundingMode] << ">()[0]";
360309
OS.flush();
361310
return ReplStr;
362311
}

clang/lib/DPCT/RulesLang/Math/RewriterHalfPrecisionConversionAndDataMovement.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
3030
CALL(MapNames::getClNamespace() +
3131
"ext::intel::math::float2half_rn",
3232
MEMBER_CALL(ARG(0), false, "y"))))),
33-
MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(
33+
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY(
3434
"__float22half2_rn", ARG(0), false,
3535
"convert<" + MapNames::getClNamespace() + "half, " +
3636
MapNames::getClNamespace() + "rounding_mode::rte>"))
@@ -168,7 +168,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
168168
CALL(MapNames::getClNamespace() +
169169
"ext::intel::math::half2float",
170170
MEMBER_CALL(ARG(0), false, "y"))))),
171-
MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(
171+
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY(
172172
"__half22float2", ARG(0), false,
173173
"convert<float, " + MapNames::getClNamespace() +
174174
"rounding_mode::automatic>"))

clang/test/dpct/complex.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ int main() {
290290
auto a24 = COMPLEX_F_FMA(f1, f2, f3);
291291
r = r && check(a24, expect, index);
292292

293-
// CHECK: f1 = d1.convert<float>();
293+
// CHECK: f1 = d1.template convert<float>();
294294
f1 = cuComplexDoubleToFloat(d1);
295-
// CHECK: d1 = f1.convert<double>();
295+
// CHECK: d1 = f1.template convert<double>();
296296
d1 = cuComplexFloatToDouble(f1);
297297

298298
int *result = nullptr;

clang/test/dpct/math/cuda-math-need-paren.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using namespace std;
88

99
void __global__ kernel() {
1010
half2 h2;
11-
// CHECK: (h2 + h2).convert<float, sycl::rounding_mode::automatic>();
11+
// CHECK: (h2 + h2).template convert<float, sycl::rounding_mode::automatic>();
1212
__half22float2(__hadd2(h2, h2));
1313
}
1414

clang/test/dpct/math/half/half.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ __global__ void kernelFuncHalfConversion() {
1515
unsigned u;
1616
unsigned long long ull;
1717
unsigned short us;
18-
// CHECK: h2 = f2.convert<sycl::half, sycl::rounding_mode::rte>();
18+
// CHECK: h2 = f2.template convert<sycl::half, sycl::rounding_mode::rte>();
1919
h2 = __float22half2_rn(f2);
2020
// CHECK: h = sycl::vec<float, 1>(f).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
2121
h = __float2half(f);
@@ -31,7 +31,7 @@ __global__ void kernelFuncHalfConversion() {
3131
h = __float2half_rz(f);
3232
// CHECK: h2 = sycl::float2(f, f).convert<sycl::half, sycl::rounding_mode::rte>();
3333
h2 = __floats2half2_rn(f, f);
34-
// CHECK: f2 = h2.convert<float, sycl::rounding_mode::automatic>();
34+
// CHECK: f2 = h2.template convert<float, sycl::rounding_mode::automatic>();
3535
f2 = __half22float2(h2);
3636
// CHECK: f = sycl::vec<sycl::half, 1>(h).convert<float, sycl::rounding_mode::automatic>()[0];
3737
f = __half2float(h);

clang/test/dpct/query_api_mapping/NoLib/test.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@
8989
// CUCOMPLEXDOUBLETOFLOAT: CUDA API:
9090
// CUCOMPLEXDOUBLETOFLOAT-NEXT: cuComplexDoubleToFloat(c /*cuDoubleComplex*/);
9191
// CUCOMPLEXDOUBLETOFLOAT-NEXT: Is migrated to:
92-
// CUCOMPLEXDOUBLETOFLOAT-NEXT: c.convert<float>();
92+
// CUCOMPLEXDOUBLETOFLOAT-NEXT: c.template convert<float>();
9393

9494
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cuComplexFloatToDouble | FileCheck %s -check-prefix=CUCOMPLEXFLOATTODOUBLE
9595
// CUCOMPLEXFLOATTODOUBLE: CUDA API:
9696
// CUCOMPLEXFLOATTODOUBLE-NEXT: cuComplexFloatToDouble(c /*cuFloatComplex*/);
9797
// CUCOMPLEXFLOATTODOUBLE-NEXT: Is migrated to:
98-
// CUCOMPLEXFLOATTODOUBLE-NEXT: c.convert<double>();
98+
// CUCOMPLEXFLOATTODOUBLE-NEXT: c.template convert<double>();
9999

100100
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cuConj | FileCheck %s -check-prefix=CUCONJ
101101
// CUCONJ: CUDA API:

0 commit comments

Comments
 (0)