Skip to content

Commit bd4c37e

Browse files
authored
[SYCLomatic] Use rewriter to re-impl migration for API: fmin,fmax,fabs,fminf,fmaxf,fabsf (#2723)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent a558dae commit bd4c37e

11 files changed

+179
-42
lines changed

clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -148,10 +148,16 @@ template <class SubExprT> class CastIfNotSameExprPrinter {
148148
const Expr *InputArg = SubExpr->IgnoreUnlessSpelledInSource();
149149
clang::QualType ArgType = InputArg->getType().getCanonicalType();
150150
ArgType.removeLocalFastQualifiers(clang::Qualifiers::CVRMask);
151+
bool NeedParen = false;
151152
if (ArgType.getAsString() != TypeInfo) {
153+
NeedParen = needExtraParens(SubExpr);
152154
Stream << "(" << TypeInfo << ")";
153155
}
156+
if (NeedParen)
157+
Stream << "(";
154158
dpct::print(Stream, SubExpr);
159+
if (NeedParen)
160+
Stream << ")";
155161
}
156162
};
157163

clang/lib/DPCT/RulesLang/APINamesMath.inc

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -69,12 +69,12 @@ ENTRY_RENAMED_SINGLE("exp10f", MapNames::getClNamespace(false, true) + "exp10")
6969
ENTRY_RENAMED_SINGLE("exp2f", MapNames::getClNamespace(false, true) + "exp2")
7070
ENTRY_REWRITE("expf")
7171
ENTRY_RENAMED_SINGLE("expm1f", MapNames::getClNamespace(false, true) + "expm1")
72-
ENTRY_RENAMED_SINGLE("fabsf", MapNames::getClNamespace(false, true) + "fabs")
72+
ENTRY_REWRITE("fabsf")
7373
ENTRY_RENAMED_SINGLE("fdimf", MapNames::getClNamespace(false, true) + "fdim")
7474
ENTRY_RENAMED_SINGLE("floorf", MapNames::getClNamespace(false, true) + "floor")
7575
ENTRY_RENAMED_SINGLE("fmaf", MapNames::getClNamespace(false, true) + "fma")
76-
ENTRY_RENAMED_SINGLE("fmaxf", MapNames::getClNamespace(false, true) + "fmax")
77-
ENTRY_RENAMED_SINGLE("fminf", MapNames::getClNamespace(false, true) + "fmin")
76+
ENTRY_REWRITE("fmaxf")
77+
ENTRY_REWRITE("fminf")
7878
ENTRY_RENAMED_SINGLE("fmodf", MapNames::getClNamespace(false, true) + "fmod")
7979
ENTRY_RENAMED_SINGLE("hypotf", MapNames::getClNamespace(false, true) + "hypot")
8080
ENTRY_RENAMED_SINGLE("ilogbf", MapNames::getClNamespace(false, true) + "ilogb")
@@ -151,12 +151,12 @@ ENTRY_RENAMED_DOUBLE("exp10", MapNames::getClNamespace(false, true) + "exp10")
151151
ENTRY_RENAMED_DOUBLE("exp2", MapNames::getClNamespace(false, true) + "exp2")
152152
ENTRY_RENAMED_DOUBLE("exp", MapNames::getClNamespace(false, true) + "exp")
153153
ENTRY_RENAMED_DOUBLE("expm1", MapNames::getClNamespace(false, true) + "expm1")
154-
ENTRY_RENAMED_DOUBLE("fabs", MapNames::getClNamespace(false, true) + "fabs")
154+
ENTRY_REWRITE("fabs")
155155
ENTRY_RENAMED_DOUBLE("fdim", MapNames::getClNamespace(false, true) + "fdim")
156156
ENTRY_RENAMED_DOUBLE("floor", MapNames::getClNamespace(false, true) + "floor")
157157
ENTRY_RENAMED_DOUBLE("fma", MapNames::getClNamespace(false, true) + "fma")
158-
ENTRY_RENAMED_DOUBLE("fmax", MapNames::getClNamespace(false, true) + "fmax")
159-
ENTRY_RENAMED_DOUBLE("fmin", MapNames::getClNamespace(false, true) + "fmin")
158+
ENTRY_REWRITE("fmax")
159+
ENTRY_REWRITE("fmin")
160160
ENTRY_RENAMED_DOUBLE("fmod", MapNames::getClNamespace(false, true) + "fmod")
161161
ENTRY_RENAMED_DOUBLE("hypot", MapNames::getClNamespace(false, true) + "hypot")
162162
ENTRY_RENAMED_DOUBLE("ilogb", MapNames::getClNamespace(false, true) + "ilogb")
@@ -237,8 +237,6 @@ ENTRY_REWRITE("__funnelshift_l")
237237
ENTRY_REWRITE("__funnelshift_lc")
238238
ENTRY_REWRITE("__funnelshift_r")
239239
ENTRY_REWRITE("__funnelshift_rc")
240-
// Used to add header file "<cmath>" into the file that calls "fabs".
241-
ENTRY_RENAMED("fabs", "::fabs")
242240

243241
// Renamed Version 2
244242
ENTRY_RENAMED("__mul64hi", MapNames::getClNamespace(false, true) + "mul_hi")

clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.h

Lines changed: 43 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -189,13 +189,43 @@ inline auto UseBFloat16 = [](const CallExpr *C) -> bool {
189189
return DpctGlobalInfo::useBFloat16();
190190
};
191191

// Predicate: is the function that directly contains call expression C a
// device-only function?
//
// Returns true when the enclosing function declaration either carries
// CUDADeviceAttr without CUDAHostAttr, or carries CUDAGlobalAttr.
// Returns false when no enclosing function declaration is found, or when the
// enclosing function does not match either of those attribute combinations.
192+
inline auto IsDirectCallerPureDevice = [](const CallExpr *C) -> bool {
193+
auto ContextFD = getImmediateOuterFuncDecl(C);
// While the current context is associated with a lambda expression, step
// outward to the function declaration that encloses that lambda.
// (presumably getImmediateOuterLambdaExpr yields null once the context is no
// longer a lambda body -- confirm against the helper's definition)
// NOTE(review): the helper is invoked before ContextFD is null-checked below;
// verify that it tolerates a null argument.
194+
while (auto LE = getImmediateOuterLambdaExpr(ContextFD)) {
195+
ContextFD = getImmediateOuterFuncDecl(LE);
196+
}
197+
if (!ContextFD)
198+
return false;
// Device-only: (device attribute and no host attribute) or a global-attributed
// function.
199+
if ((ContextFD->getAttr<CUDADeviceAttr>() &&
200+
!ContextFD->getAttr<CUDAHostAttr>()) ||
201+
ContextFD->getAttr<CUDAGlobalAttr>()) {
202+
return true;
203+
}
204+
return false;
205+
};
206+
// Predicate: is the function that directly contains call expression C a
// host-only function?
//
// Returns true when the enclosing function declaration carries neither
// CUDADeviceAttr nor CUDAGlobalAttr. Returns false when no enclosing function
// declaration is found, or when either of those attributes is present (so a
// function with both host and device attributes is not "pure host").
207+
inline auto IsDirectCallerPureHost = [](const CallExpr *C) -> bool {
208+
auto ContextFD = getImmediateOuterFuncDecl(C);
// While the current context is associated with a lambda expression, step
// outward to the function declaration that encloses that lambda.
// NOTE(review): the helper is invoked before ContextFD is null-checked below;
// verify that it tolerates a null argument.
209+
while (auto LE = getImmediateOuterLambdaExpr(ContextFD)) {
210+
ContextFD = getImmediateOuterFuncDecl(LE);
211+
}
212+
if (!ContextFD)
213+
return false;
// Host-only: no device attribute and no global attribute on the caller.
214+
if (!ContextFD->getAttr<CUDADeviceAttr>() &&
215+
!ContextFD->getAttr<CUDAGlobalAttr>()) {
216+
return true;
217+
}
218+
return false;
219+
};
220+
192221
inline auto IsPureHost = [](const CallExpr *C) -> bool {
193222
const FunctionDecl *FD = C->getDirectCallee();
194223
if (!FD)
195224
return false;
225+
if (!IsDirectCallerPureHost(C))
226+
return false;
196227
if (!(FD->hasAttr<CUDADeviceAttr>()))
197228
return true;
198-
199229
SourceLocation DeclLoc =
200230
dpct::DpctGlobalInfo::getSourceManager().getExpansionLoc(
201231
FD->getLocation());
@@ -209,22 +239,12 @@ inline auto IsPureHost = [](const CallExpr *C) -> bool {
209239
}
210240
return false;
211241
};
212-
inline auto IsPureDevice = makeCheckAnd(
213-
HasDirectCallee(),
214-
makeCheckAnd(IsDirectCalleeHasAttribute<CUDADeviceAttr>(),
215-
makeCheckNot(IsDirectCalleeHasAttribute<CUDAHostAttr>())));
216-
217-
inline auto IsDirectCallerPureDevice = [](const CallExpr *C) -> bool {
218-
auto ContextFD = getImmediateOuterFuncDecl(C);
219-
while (auto LE = getImmediateOuterLambdaExpr(ContextFD)) {
220-
ContextFD = getImmediateOuterFuncDecl(LE);
221-
}
222-
if (!ContextFD)
242+
inline auto IsPureDevice = [](const CallExpr *C) -> bool {
243+
if (!HasDirectCallee()(C))
223244
return false;
224-
if (ContextFD->getAttr<CUDADeviceAttr>() &&
225-
!ContextFD->getAttr<CUDAHostAttr>()) {
245+
if (IsDirectCalleeHasAttribute<CUDADeviceAttr>()(C) &&
246+
!IsDirectCalleeHasAttribute<CUDAHostAttr>()(C))
226247
return true;
227-
}
228248
return false;
229249
};
230250
inline auto IsUnresolvedLookupExpr = [](const CallExpr *C) -> bool {
@@ -344,8 +364,9 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
344364
// 4. math_libdevice
345365
// 5. device_std
346366
// c. Host and device
347-
// 1. emulation
348-
// 2. unsupported_warning
367+
// 1. host_device
368+
// 2. emulation
369+
// 3. unsupported_warning
349370
std::shared_ptr<CallExprRewriter> create(const CallExpr *C) const override {
350371
if (math::IsPureHost(C)) {
351372
// HOST
@@ -355,6 +376,8 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
355376
return HostPerfRewriter.value().second.second->create(C);
356377
if (HostNormalRewriter && HostNormalRewriter.value().first(C))
357378
return HostNormalRewriter.value().second.second->create(C);
379+
} else {
380+
return NoRewriteRewriter.value().second.second->create(C);
358381
}
359382
} else {
360383
// DEVICE
@@ -378,12 +401,12 @@ class MathRewriterFactory final : public CallExprRewriterFactoryBase {
378401
}
379402

380403
// Host and device
381-
if (EmulationRewriter && EmulationRewriter.value().first(C))
382-
return EmulationRewriter.value().second.second->create(C);
383-
384404
if (HostDeviceRewriter && HostDeviceRewriter.value().first(C))
385405
return HostDeviceRewriter.value().second.second->create(C);
386406

407+
if (EmulationRewriter && EmulationRewriter.value().first(C))
408+
return EmulationRewriter.value().second.second->create(C);
409+
387410
if (UnsupportedWarningRewriter &&
388411
UnsupportedWarningRewriter.value().first(C))
389412
return UnsupportedWarningRewriter.value().second.second->create(C);

clang/lib/DPCT/RulesLang/Math/RewriterDoublePrecisionMathematicalFunctions.cpp

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,67 @@ using namespace clang::dpct;
1212

1313
RewriterMap dpct::createDoublePrecisionMathematicalFunctionsRewriterMap() {
1414
return RewriterMap{
15+
// fabs
16+
MATH_API_REWRITER_DEVICE_OVERLOAD(
17+
CheckParamType(0, "float"),
18+
MATH_API_REWRITERS_V2(
19+
"fabs",
20+
MATH_API_REWRITER_PAIR(
21+
math::Tag::host_device,
22+
CALL_FACTORY_ENTRY(
23+
"fabs",
24+
CALL(MapNames::getClNamespace() + "fabs",
25+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(0)))))),
26+
MATH_API_REWRITERS_V2(
27+
"fabs",
28+
MATH_API_REWRITER_PAIR(
29+
math::Tag::host_device,
30+
CALL_FACTORY_ENTRY(
31+
"fabs",
32+
CALL(MapNames::getClNamespace() + "fabs",
33+
CAST_IF_NOT_SAME(makeLiteral("double"), ARG(0)))))))
34+
// fmax
35+
MATH_API_REWRITER_DEVICE_OVERLOAD(
36+
CheckParamType(0, "float"),
37+
MATH_API_REWRITERS_V2(
38+
"fmax",
39+
MATH_API_REWRITER_PAIR(
40+
math::Tag::host_device,
41+
CALL_FACTORY_ENTRY(
42+
"fmax",
43+
CALL(MapNames::getClNamespace() + "fmax",
44+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(0)),
45+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(1)))))),
46+
MATH_API_REWRITERS_V2(
47+
"fmax",
48+
MATH_API_REWRITER_PAIR(
49+
math::Tag::host_device,
50+
CALL_FACTORY_ENTRY(
51+
"fmax",
52+
CALL(MapNames::getClNamespace() + "fmax",
53+
CAST_IF_NOT_SAME(makeLiteral("double"), ARG(0)),
54+
CAST_IF_NOT_SAME(makeLiteral("double"), ARG(1)))))))
55+
// fmin
56+
MATH_API_REWRITER_DEVICE_OVERLOAD(
57+
CheckParamType(0, "float"),
58+
MATH_API_REWRITERS_V2(
59+
"fmin",
60+
MATH_API_REWRITER_PAIR(
61+
math::Tag::host_device,
62+
CALL_FACTORY_ENTRY(
63+
"fmin",
64+
CALL(MapNames::getClNamespace() + "fmin",
65+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(0)),
66+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(1)))))),
67+
MATH_API_REWRITERS_V2(
68+
"fmin",
69+
MATH_API_REWRITER_PAIR(
70+
math::Tag::host_device,
71+
CALL_FACTORY_ENTRY(
72+
"fmin",
73+
CALL(MapNames::getClNamespace() + "fmin",
74+
CAST_IF_NOT_SAME(makeLiteral("double"), ARG(0)),
75+
CAST_IF_NOT_SAME(makeLiteral("double"), ARG(1)))))))
1576
// cospi
1677
CALL_FACTORY_ENTRY(
1778
"cospi",

clang/lib/DPCT/RulesLang/Math/RewriterSTDFunctions.cpp

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,5 +34,26 @@ RewriterMap dpct::createSTDFunctionsRewriterMap() {
3434
CALL_FACTORY_ENTRY(
3535
"std::abs",
3636
CALL(MapNames::getClNamespace(false, true) + "fabs",
37-
ARG(0))))))};
37+
ARG(0))))))
38+
// std::fabs
39+
MATH_API_REWRITERS_V2(
40+
"std::fabs",
41+
MATH_API_REWRITER_PAIR(
42+
math::Tag::host_normal,
43+
CALL_FACTORY_ENTRY("std::fabs", CALL("std::fabs", ARG(0)))),
44+
MATH_API_REWRITER_PAIR(
45+
math::Tag::device_normal,
46+
CALL_FACTORY_ENTRY(
47+
"std::fabs",
48+
CALL(MapNames::getClNamespace(false, true) + "fabs",
49+
ARG(0)))),
50+
MATH_API_REWRITER_PAIR(
51+
math::Tag::host_device,
52+
CALL_FACTORY_ENTRY(
53+
"std::fabs",
54+
CALL(MapNames::getClNamespace(false, true) + "fabs",
55+
ARG(0)))),
56+
MATH_API_REWRITER_PAIR(
57+
math::Tag::device_std,
58+
CALL_FACTORY_ENTRY("std::fabs", CALL("std::fabs", ARG(0)))))};
3859
}

clang/lib/DPCT/RulesLang/Math/RewriterSinglePrecisionMathematicalFunctions.cpp

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,34 @@ using namespace clang::dpct;
1212

1313
RewriterMap dpct::createSinglePrecisionMathematicalFunctionsRewriterMap() {
1414
return RewriterMap{
15+
// fabsf
16+
MATH_API_REWRITERS_V2(
17+
"fabsf", MATH_API_REWRITER_PAIR(
18+
math::Tag::host_device,
19+
CALL_FACTORY_ENTRY(
20+
"fabsf", CALL(MapNames::getClNamespace() + "fabs",
21+
CAST_IF_NOT_SAME(makeLiteral("float"),
22+
ARG(0))))))
23+
// fmaxf
24+
MATH_API_REWRITERS_V2(
25+
"fmaxf",
26+
MATH_API_REWRITER_PAIR(
27+
math::Tag::host_device,
28+
CALL_FACTORY_ENTRY(
29+
"fmaxf",
30+
CALL(MapNames::getClNamespace() + "fmax",
31+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(0)),
32+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(1))))))
33+
// fminf
34+
MATH_API_REWRITERS_V2(
35+
"fminf",
36+
MATH_API_REWRITER_PAIR(
37+
math::Tag::host_device,
38+
CALL_FACTORY_ENTRY(
39+
"fminf",
40+
CALL(MapNames::getClNamespace() + "fmin",
41+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(0)),
42+
CAST_IF_NOT_SAME(makeLiteral("float"), ARG(1))))))
1543
// cyl_bessel_i0f
1644
MATH_API_REWRITER_DEVICE(
1745
"cyl_bessel_i0f",

clang/test/dpct/complex.cu

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,9 @@
1717
// CHECK-NEXT: #define COMPLEX_D_DIV(a, b) dpct::cdiv<double>(a, b)
1818
// CHECK-NEXT: #define COMPLEX_D_FMA(a, b, c) dpct::cmul<double>(a, b) + c
1919
// CHECK-NEXT: #define COMPLEX_D_ABS(a) dpct::cabs<double>(a)
20+
// CHECK-NEXT: /*
21+
// CHECK-NEXT: DPCT1064:{{[0-9]+}}: Migrated fabs call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
22+
// CHECK-NEXT: */
2023
// CHECK-NEXT: #define COMPLEX_D_ABS1(a) (sycl::fabs((a).x()) + sycl::fabs((a).y()))
2124
// CHECK-NEXT: #define COMPLEX_D_CONJ(a) dpct::conj<double>(a)
2225
#define COMPLEX_D_MAKE(r,i) make_cuDoubleComplex(r, i)
@@ -44,6 +47,9 @@
4447
// CHECK-NEXT: #define COMPLEX_F_DIV(a, b) dpct::cdiv<float>(a, b)
4548
// CHECK-NEXT: #define COMPLEX_F_FMA(a, b, c) dpct::cmul<float>(a, b) + c
4649
// CHECK-NEXT: #define COMPLEX_F_ABS(a) dpct::cabs<float>(a)
50+
// CHECK-NEXT: /*
51+
// CHECK-NEXT: DPCT1064:{{[0-9]+}}: Migrated fabsf call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
52+
// CHECK-NEXT: */
4753
// CHECK-NEXT: #define COMPLEX_F_ABS1(a) (sycl::fabs((a).x()) + sycl::fabs((a).y()))
4854
// CHECK-NEXT: #define COMPLEX_F_CONJ(a) dpct::conj<float>(a)
4955
#define COMPLEX_F_MAKE(r,i) make_cuFloatComplex(r, i)

clang/test/dpct/device_call_in_math_call.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,6 @@ __device__ float g(int width = 32)
88
return 1.0;
99
}
1010
__global__ void f() {
11-
// CHECK: sycl::fmax((float)(1.0), g(item_ct1, 16));
11+
// CHECK: sycl::fmax((float)1.0, g(item_ct1, 16));
1212
fmaxf(1.0, g(16));
1313
}

clang/test/dpct/macro_test.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -746,7 +746,7 @@ int foo14(){
746746
ALL2(const, ALL3(int2), *) lll;
747747
}
748748

749-
//CHECK: #define FABS(a) (sycl::fabs((float)((a).x())) + sycl::fabs((float)((a).y())))
749+
//CHECK: #define FABS(a) (sycl::fabs((a).x()) + sycl::fabs((a).y()))
750750
//CHECK-NEXT: static inline double foo16(const sycl::float2 &x) { return FABS(x); }
751751
#define FABS(a) (fabs((a).x) + fabs((a).y))
752752
__host__ __device__ static inline double foo16(const float2 &x) { return FABS(x); }

clang/test/dpct/math/cuda-math-intrinsics.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2449,7 +2449,7 @@ __device__ float foo2(float f, float g) {
24492449
}
24502450

24512451
// CHECK: int foo3(int i, int j) {
2452-
// CHECK-NEXT: return std::max(i, j) + std::min(i, j);
2452+
// CHECK-NEXT: return sycl::max(i, j) + sycl::min(i, j);
24532453
// CHECK-NEXT: }
24542454
__device__ int __host__ foo3(int i, int j) {
24552455
return max(i, j) + min(i, j);
@@ -2619,7 +2619,7 @@ __device__ void do_migration3() {
26192619
}
26202620
__host__ __device__ void do_migration4() {
26212621
int i, j;
2622-
// CHECK: std::max(i, j);
2622+
// CHECK: sycl::max(i, j);
26232623
max(i, j);
26242624
}
26252625
namespace t {

0 commit comments

Comments
 (0)