Skip to content

Commit 2039135

Browse files
authored
[SYCLomatic] Fix the migration of __hmul2 and __hsub2 when --use-dpcpp-extensions=intel_device_math is specified (#2728)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent d40655b commit 2039135

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

clang/lib/DPCT/RulesLang/Math/RewriterHalf2ArithmeticFunctions.cpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -261,22 +261,30 @@ RewriterMap dpct::createHalf2ArithmeticFunctionsRewriterMap() {
261261
makeCallArgCreatorWithCall(2)),
262262
LITERAL("{0.f, 0.f}"), LITERAL("{1.f, 1.f}"))))))
263263
// __hmul2
264-
MATH_API_REWRITER_DEVICE(
265-
"__hmul2",
266-
MATH_API_DEVICE_NODES(
267-
EMPTY_FACTORY_ENTRY("__hmul2"),
268-
MATH_API_SPECIFIC_ELSE_EMU(
269-
CheckArgType(0, "__half2"),
264+
MATH_API_REWRITER_DEVICE_OVERLOAD(
265+
CheckArgType(0, "__half2"),
266+
MATH_API_REWRITERS_V2(
267+
"__hmul2",
268+
MATH_API_REWRITER_PAIR(
269+
math::Tag::math_libdevice,
270270
HEADER_INSERT_FACTORY(
271271
HeaderType::HT_SYCL_Math,
272272
CALL_FACTORY_ENTRY("__hmul2",
273273
CALL(MapNames::getClNamespace() +
274274
"ext::intel::math::hmul2",
275275
ARG(0), ARG(1))))),
276-
EMPTY_FACTORY_ENTRY("__hmul2"),
277-
BINARY_OP_FACTORY_ENTRY("__hmul2", BinaryOperatorKind::BO_Mul,
278-
makeCallArgCreatorWithCall(0),
279-
makeCallArgCreatorWithCall(1))))
276+
MATH_API_REWRITER_PAIR(
277+
math::Tag::emulation,
278+
BINARY_OP_FACTORY_ENTRY("__hmul2", BinaryOperatorKind::BO_Mul,
279+
makeCallArgCreatorWithCall(0),
280+
makeCallArgCreatorWithCall(1)))),
281+
MATH_API_REWRITERS_V2(
282+
"__hmul2",
283+
MATH_API_REWRITER_PAIR(
284+
math::Tag::emulation,
285+
BINARY_OP_FACTORY_ENTRY("__hmul2", BinaryOperatorKind::BO_Mul,
286+
makeCallArgCreatorWithCall(0),
287+
makeCallArgCreatorWithCall(1)))))
280288
// __hmul2_rn
281289
MATH_API_REWRITER_DEVICE(
282290
"__hmul2_rn",
@@ -343,22 +351,30 @@ RewriterMap dpct::createHalf2ArithmeticFunctionsRewriterMap() {
343351
UNARY_OP_FACTORY_ENTRY("__hneg2", UnaryOperatorKind::UO_Minus,
344352
makeCallArgCreatorWithCall(0))))
345353
// __hsub2
346-
MATH_API_REWRITER_DEVICE(
347-
"__hsub2",
348-
MATH_API_DEVICE_NODES(
349-
EMPTY_FACTORY_ENTRY("__hsub2"),
350-
MATH_API_SPECIFIC_ELSE_EMU(
351-
CheckArgType(0, "__half2"),
354+
MATH_API_REWRITER_DEVICE_OVERLOAD(
355+
CheckArgType(0, "__half2"),
356+
MATH_API_REWRITERS_V2(
357+
"__hsub2",
358+
MATH_API_REWRITER_PAIR(
359+
math::Tag::math_libdevice,
352360
HEADER_INSERT_FACTORY(
353361
HeaderType::HT_SYCL_Math,
354362
CALL_FACTORY_ENTRY("__hsub2",
355363
CALL(MapNames::getClNamespace() +
356364
"ext::intel::math::hsub2",
357365
ARG(0), ARG(1))))),
358-
EMPTY_FACTORY_ENTRY("__hsub2"),
359-
BINARY_OP_FACTORY_ENTRY("__hsub2", BinaryOperatorKind::BO_Sub,
360-
makeCallArgCreatorWithCall(0),
361-
makeCallArgCreatorWithCall(1))))
366+
MATH_API_REWRITER_PAIR(
367+
math::Tag::emulation,
368+
BINARY_OP_FACTORY_ENTRY("__hsub2", BinaryOperatorKind::BO_Sub,
369+
makeCallArgCreatorWithCall(0),
370+
makeCallArgCreatorWithCall(1)))),
371+
MATH_API_REWRITERS_V2(
372+
"__hsub2",
373+
MATH_API_REWRITER_PAIR(
374+
math::Tag::emulation,
375+
BINARY_OP_FACTORY_ENTRY("__hsub2", BinaryOperatorKind::BO_Sub,
376+
makeCallArgCreatorWithCall(0),
377+
makeCallArgCreatorWithCall(1)))))
362378
// __hsub2_rn
363379
MATH_API_REWRITER_DEVICE(
364380
"__hsub2_rn",

clang/test/dpct/math/bfloat16/bfloat16_ext.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ __global__ void kernelFuncBfloat162Arithmetic() {
1111
__nv_bfloat162 bf162, bf162_1, bf162_2;
1212
// CHECK: bf162 = bf162_1 / bf162_2;
1313
bf162 = __h2div(bf162_1, bf162_2);
14+
// CHECK: bf162 = bf162_1 * bf162_2;
15+
bf162 = __hmul2(bf162_1, bf162_2);
16+
// CHECK: bf162 = bf162_1 - bf162_2;
17+
bf162 = __hsub2(bf162_1, bf162_2);
1418
}
1519

1620
__global__ void kernelFuncBfloat16Comparison() {

0 commit comments

Comments
 (0)