Skip to content

Commit 91dfbb3

Browse files
authored
[SYCLomatic] Add an empty rewriter in case no migration required for math API (#2846)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent a4db727 commit 91dfbb3

File tree

5 files changed

+37
-31
lines changed

5 files changed

+37
-31
lines changed

clang/lib/DPCT/RuleInfra/CallExprRewriter.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,14 @@ class NoRewriteFuncNameRewriter : public CallExprRewriter {
586586
std::optional<std::string> rewrite() override { return NewFuncName; }
587587
};
588588

589+
// No replacement generated
590+
class EmptyRewriter : public CallExprRewriter {
591+
public:
592+
EmptyRewriter(const CallExpr *, StringRef, StringRef)
593+
: CallExprRewriter(Call, SourceCalleeName) {}
594+
std::optional<std::string> rewrite() override { return std::nullopt; }
595+
};
596+
589597
struct ThrustFunctor {
590598
ThrustFunctor(const clang::Expr *E) : E(E) {}
591599
const clang::Expr *E;

clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.h

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ using WarpFunctionRewriterFactory =
3131
CallExprRewriterFactory<WarpFunctionRewriter, std::string>;
3232
using NoRewriteFuncNameRewriterFactory =
3333
CallExprRewriterFactory<NoRewriteFuncNameRewriter, std::string>;
34+
/// Factory type for EmptyRewriter, parameterized with one std::string
/// argument (same shape as NoRewriteFuncNameRewriterFactory above).
using EmptyRewriterFactory =
35+
CallExprRewriterFactory<EmptyRewriter, std::string>;
3436

3537
/// Base class for rewriting math function calls
3638
class MathCallExprRewriter : public FuncCallExprRewriter {
@@ -349,10 +351,9 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
349351
: Name(Name), MathAPIRewriters(MathAPIRewritersInput) {
350352
NoRewriteRewriter = std::make_pair(
351353
TrueFunctor,
352-
std::make_pair(Name,
353-
std::dynamic_pointer_cast<CallExprRewriterFactoryBase>(
354-
std::make_shared<NoRewriteFuncNameRewriterFactory>(
355-
Name, Name))));
354+
std::make_pair(
355+
Name, std::dynamic_pointer_cast<CallExprRewriterFactoryBase>(
356+
std::make_shared<EmptyRewriterFactory>(Name, Name))));
356357
}
357358
// a. Host API priority:
358359
// 1. host_perf
@@ -561,22 +562,19 @@ createMathAPIRewriterDevice(
561562
math::IsDefinedInCUDA(),
562563
std::move(createMathAPIRewriterDeviceImpl(Name, PerfPred, DevicePerf,
563564
DeviceNodes)),
564-
{Name,
565-
std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
565+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
566566
createConditionalFactory(
567567
math::IsUnresolvedLookupExpr,
568568
createConditionalFactory(
569569
math::IsDirectCallerPureDevice,
570570
std::move(createMathAPIRewriterDeviceImpl(
571571
Name, PerfPred, DevicePerf, DeviceNodes)),
572-
{Name,
573-
std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
572+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
574573
createConditionalFactory(
575574
math::IsDefinedInCUDA(),
576575
std::move(createMathAPIRewriterDeviceImpl(
577576
Name, PerfPred, DevicePerf, DeviceNodes)),
578-
{Name, std::make_shared<NoRewriteFuncNameRewriterFactory>(
579-
Name, Name)})));
577+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)})));
580578
}
581579

582580
inline std::pair<std::string, std::shared_ptr<CallExprRewriterFactoryBase>>
@@ -590,21 +588,17 @@ createMathAPIRewriterDevice(
590588
createConditionalFactory(
591589
math::IsDefinedInCUDA(),
592590
std::move(createMathAPIRewriterDeviceImpl(Name, DeviceNodes)),
593-
{Name,
594-
std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
591+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
595592
createConditionalFactory(
596593
math::IsUnresolvedLookupExpr,
597594
createConditionalFactory(
598595
math::IsDirectCallerPureDevice,
599596
std::move(createMathAPIRewriterDeviceImpl(Name, DeviceNodes)),
600-
{Name,
601-
std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)}),
597+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)}),
602598
createConditionalFactory(
603599
math::IsDefinedInCUDA(),
604-
std::move(
605-
createMathAPIRewriterDeviceImpl(Name, DeviceNodes)),
606-
{Name, std::make_shared<NoRewriteFuncNameRewriterFactory>(
607-
Name, Name)})));
600+
std::move(createMathAPIRewriterDeviceImpl(Name, DeviceNodes)),
601+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)})));
608602
}
609603

610604
template <class T>
@@ -620,13 +614,11 @@ createMathAPIRewriterExperimentalBfloat16(
620614
if (math::useExtBFloat16Math() && Rewriter1.second)
621615
return createConditionalFactory(
622616
math::IsDefinedInCUDA(), std::move(Rewriter1),
623-
{Name,
624-
std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)});
617+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)});
625618
if (Rewriter2.second)
626619
return createConditionalFactory(
627620
math::IsDefinedInCUDA(), std::move(Rewriter2),
628-
{Name,
629-
std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)});
621+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)});
630622
}
631623
// report unsupport
632624
return std::pair<std::string, std::shared_ptr<CallExprRewriterFactoryBase>>(
@@ -643,7 +635,7 @@ createMathAPIRewriterHost(
643635
T) {
644636
return createConditionalFactory(
645637
math::IsDefinedInCUDA(), std::move(HostNormal),
646-
{Name, std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)});
638+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)});
647639
}
648640

649641
template <class T>
@@ -660,7 +652,7 @@ createMathAPIRewriterHost(
660652
math::IsDefinedInCUDA(),
661653
createConditionalFactory(makeCheckAnd(math::IsPerf, PerfPred),
662654
std::move(HostPerf), std::move(HostNormal)),
663-
{Name, std::make_shared<NoRewriteFuncNameRewriterFactory>(Name, Name)});
655+
{Name, std::make_shared<EmptyRewriterFactory>(Name, Name)});
664656
}
665657

666658
template <bool IsDouble> std::string getPiString() {

clang/test/dpct/complex.cu

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717
// CHECK-NEXT: #define COMPLEX_D_DIV(a, b) dpct::cdiv<double>(a, b)
1818
// CHECK-NEXT: #define COMPLEX_D_FMA(a, b, c) dpct::cmul<double>(a, b) + c
1919
// CHECK-NEXT: #define COMPLEX_D_ABS(a) dpct::cabs<double>(a)
20-
// CHECK-NEXT: /*
21-
// CHECK-NEXT: DPCT1064:{{[0-9]+}}: Migrated fabs call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
22-
// CHECK-NEXT: */
2320
// CHECK-NEXT: #define COMPLEX_D_ABS1(a) (sycl::fabs((a).x()) + sycl::fabs((a).y()))
2421
// CHECK-NEXT: #define COMPLEX_D_CONJ(a) dpct::conj<double>(a)
2522
#define COMPLEX_D_MAKE(r,i) make_cuDoubleComplex(r, i)
@@ -47,9 +44,6 @@
4744
// CHECK-NEXT: #define COMPLEX_F_DIV(a, b) dpct::cdiv<float>(a, b)
4845
// CHECK-NEXT: #define COMPLEX_F_FMA(a, b, c) dpct::cmul<float>(a, b) + c
4946
// CHECK-NEXT: #define COMPLEX_F_ABS(a) dpct::cabs<float>(a)
50-
// CHECK-NEXT: /*
51-
// CHECK-NEXT: DPCT1064:{{[0-9]+}}: Migrated fabsf call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
52-
// CHECK-NEXT: */
5347
// CHECK-NEXT: #define COMPLEX_F_ABS1(a) (sycl::fabs((a).x()) + sycl::fabs((a).y()))
5448
// CHECK-NEXT: #define COMPLEX_F_CONJ(a) dpct::conj<float>(a)
5549
#define COMPLEX_F_MAKE(r,i) make_cuFloatComplex(r, i)

clang/test/dpct/cuda_arch_test/math_in_cuda_arch.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
// CHECK: float f(float x) {
66
// CHECK-EMPTY:
7-
// CHECK-NEXT: return expf(x);
7+
// CHECK-NEXT: return ::expf(x);
88
// CHECK-EMPTY:
99
// CHECK-NEXT: }
1010
// CHECK-NEXT: float f_host_ct0(float x) {

clang/test/dpct/math/cuda-math-intrinsics.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3584,3 +3584,15 @@ template <typename T> class AAA_st {
35843584
return 0;
35853585
}
35863586
};
3587+
3588+
// Test input: a user-defined function template named `saturate` with a
// primary definition and an explicit specialization for int. Each prints a
// distinct message.
template <typename t1> void saturate() { printf("default saturate\n"); }
3589+
template <> void saturate<int>() { printf("int saturate\n"); }
3590+
3591+
// CHECK: void foo11() {
3592+
// CHECK-NEXT: saturate<int>();
3593+
// CHECK-NEXT: saturate<float>();
3594+
// CHECK-NEXT: }
3595+
// Test input: calls the int specialization and the primary template of the
// user-defined saturate. The CHECK lines preceding this function expect both
// calls to appear unchanged in the migrated output.
void foo11() {
3596+
saturate<int>();
3597+
saturate<float>();
3598+
}

0 commit comments

Comments
 (0)