Skip to content

Commit 0f0c5d1

Browse files
authored
[SYCL][libclc][CUDA] Add --ffast-math support (#5801)
This patch allows the `--ffast-math` compiler flag to substitute the regular `genfloatf` math built-ins with their `::native` versions. Moreover, this patch completes the support of natives built-ins for `libclc/ptx-nvidiacl` connecting them with the `__nv_fast` functions present in libdevice. If a fast function is not available in libdevice the corresponding `nvvm` intrinsic is used.
1 parent adda31f commit 0f0c5d1

File tree

14 files changed

+187
-35
lines changed

14 files changed

+187
-35
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@ BUILTIN(__nvvm_rcp_rm_ftz_f, "ff", "")
347347
BUILTIN(__nvvm_rcp_rm_f, "ff", "")
348348
BUILTIN(__nvvm_rcp_rp_ftz_f, "ff", "")
349349
BUILTIN(__nvvm_rcp_rp_f, "ff", "")
350+
BUILTIN(__nvvm_rcp_approx_f, "ff", "")
351+
BUILTIN(__nvvm_rcp_approx_ftz_f, "ff", "")
350352

351353
BUILTIN(__nvvm_rcp_rn_d, "dd", "")
352354
BUILTIN(__nvvm_rcp_rz_d, "dd", "")

clang/include/clang/Driver/Options.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1546,7 +1546,7 @@ def ffp_exception_behavior_EQ : Joined<["-"], "ffp-exception-behavior=">, Group<
15461546
MarshallingInfoEnum<LangOpts<"FPExceptionMode">, "FPE_Ignore">;
15471547
defm fast_math : BoolFOption<"fast-math",
15481548
LangOpts<"FastMath">, DefaultFalse,
1549-
PosFlag<SetTrue, [CC1Option], "Allow aggressive, lossy floating-point optimizations",
1549+
PosFlag<SetTrue, [CC1Option, CoreOption], "Allow aggressive, lossy floating-point optimizations",
15501550
[cl_fast_relaxed_math.KeyPath]>,
15511551
NegFlag<SetFalse>>;
15521552
def menable_unsafe_fp_math : Flag<["-"], "menable-unsafe-fp-math">, Flags<[CC1Option]>,

libclc/ptx-nvidiacl/libspirv/SOURCES

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ math/log2.cl
4343
math/logb.cl
4444
math/modf.cl
4545
math/native_cos.cl
46+
math/native_divide.cl
4647
math/native_exp.cl
4748
math/native_exp10.cl
4849
math/native_exp2.cl
4950
math/native_log.cl
5051
math/native_log10.cl
5152
math/native_log2.cl
5253
math/native_powr.cl
54+
math/native_recip.cl
5355
math/native_rsqrt.cl
5456
math/native_sin.cl
5557
math/native_sqrt.cl
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
#include "../../include/libdevice.h"
12+
#include <clcmacro.h>
13+
14+
#define __CLC_FUNCTION __spirv_ocl_native_divide
15+
#define __CLC_BUILTIN __nv_fast_fdivide
16+
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
17+
#define __FLOAT_ONLY
18+
#include <math/binary_builtin.inc>

libclc/ptx-nvidiacl/libspirv/math/native_exp2.cl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88

99
#include <spirv/spirv.h>
1010

11-
#include "../../include/libdevice.h"
1211
#include <clcmacro.h>
1312

14-
#define __CLC_FUNCTION __spirv_ocl_native_exp2
15-
#define __CLC_BUILTIN __nv_exp2
16-
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
13+
extern int __clc_nvvm_reflect_ftz();
14+
15+
_CLC_DEF _CLC_OVERLOAD float __spirv_ocl_native_exp2(float x) {
16+
return (__clc_nvvm_reflect_ftz()) ? __nvvm_ex2_approx_ftz_f(x)
17+
: __nvvm_ex2_approx_f(x);
18+
}
19+
20+
_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, float, __spirv_ocl_native_exp2,
21+
float)
1722

1823
#ifdef cl_khr_fp16
1924
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
@@ -39,9 +44,3 @@ _CLC_UNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __clc_native_exp2,
3944
#undef __USE_HALF_EXP2_APPROX
4045

4146
#endif // cl_khr_fp16
42-
43-
// Undef halfs before uncluding unary builtins, as they are handled above.
44-
#ifdef cl_khr_fp16
45-
#undef cl_khr_fp16
46-
#endif // cl_khr_fp16
47-
#include <math/unary_builtin.inc>
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include <spirv/spirv.h>
10+
11+
#include <clcmacro.h>
12+
13+
extern int __clc_nvvm_reflect_ftz();
14+
15+
_CLC_DEF _CLC_OVERLOAD float __spirv_ocl_native_recip(float x) {
16+
return (__clc_nvvm_reflect_ftz()) ? __nvvm_rcp_approx_ftz_f(x)
17+
: __nvvm_rcp_approx_f(x);
18+
}
19+
20+
_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, float, __spirv_ocl_native_recip,
21+
float)

libclc/ptx-nvidiacl/libspirv/math/native_rsqrt.cl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
#include <spirv/spirv.h>
1010

11-
#include "../../include/libdevice.h"
1211
#include <clcmacro.h>
1312

14-
#define __CLC_FUNCTION __spirv_ocl_native_rsqrt
15-
#define __CLC_BUILTIN __nv_rsqrt
16-
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
17-
#include <math/unary_builtin.inc>
13+
extern int __clc_nvvm_reflect_ftz();
14+
15+
_CLC_DEF _CLC_OVERLOAD float __spirv_ocl_native_rsqrt(float x) {
16+
return (__clc_nvvm_reflect_ftz()) ? __nvvm_rsqrt_approx_ftz_f(x)
17+
: __nvvm_rsqrt_approx_f(x);
18+
}
19+
20+
_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, float, __spirv_ocl_native_rsqrt,
21+
float)

libclc/ptx-nvidiacl/libspirv/math/native_sin.cl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <clcmacro.h>
1313

1414
#define __CLC_FUNCTION __spirv_ocl_native_sin
15-
#define __CLC_BUILTIN __nv_sin
15+
#define __CLC_BUILTIN __nv_fast_sin
1616
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
17+
#define __FLOAT_ONLY
1718
#include <math/unary_builtin.inc>

libclc/ptx-nvidiacl/libspirv/math/native_sqrt.cl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
#include <spirv/spirv.h>
1010

11-
#include "../../include/libdevice.h"
1211
#include <clcmacro.h>
1312

14-
#define __CLC_FUNCTION __spirv_ocl_native_sqrt
15-
#define __CLC_BUILTIN __nv_sqrt
16-
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
17-
#include <math/unary_builtin.inc>
13+
extern int __clc_nvvm_reflect_ftz();
14+
15+
_CLC_DEF _CLC_OVERLOAD float __spirv_ocl_native_sqrt(float x) {
16+
return (__clc_nvvm_reflect_ftz()) ? __nvvm_sqrt_approx_ftz_f(x)
17+
: __nvvm_sqrt_approx_f(x);
18+
}
19+
20+
_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, float, __spirv_ocl_native_sqrt,
21+
float)

libclc/ptx-nvidiacl/libspirv/math/native_tan.cl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <clcmacro.h>
1313

1414
#define __CLC_FUNCTION __spirv_ocl_native_tan
15-
#define __CLC_BUILTIN __nv_tan
15+
#define __CLC_BUILTIN __nv_fast_tan
1616
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
17+
#define __FLOAT_ONLY
1718
#include <math/unary_builtin.inc>

libclc/ptx-nvidiacl/libspirv/reflect.ll

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,10 @@ define i32 @__clc_nvvm_reflect_arch() alwaysinline {
66
%reflect = call i32 @__nvvm_reflect(i8* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([12 x i8], [12 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
77
ret i32 %reflect
88
}
9+
10+
@str_ftz = private addrspace(1) constant [11 x i8] c"__CUDA_FTZ\00"
11+
12+
define i32 @__clc_nvvm_reflect_ftz() alwaysinline {
13+
%reflect = call i32 @__nvvm_reflect(i8* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([11 x i8], [11 x i8] addrspace(1)* @str_ftz, i32 0, i32 0) to i8*))
14+
ret i32 %reflect
15+
}

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,11 @@ let TargetPrefix = "nvvm" in {
935935
def int_nvvm_rcp_rp_f : GCCBuiltin<"__nvvm_rcp_rp_f">,
936936
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
937937

938+
def int_nvvm_rcp_approx_f : GCCBuiltin<"__nvvm_rcp_approx_f">,
939+
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
940+
def int_nvvm_rcp_approx_ftz_f : GCCBuiltin<"__nvvm_rcp_approx_ftz_f">,
941+
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
942+
938943
def int_nvvm_rcp_rn_d : GCCBuiltin<"__nvvm_rcp_rn_d">,
939944
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
940945
def int_nvvm_rcp_rz_d : GCCBuiltin<"__nvvm_rcp_rz_d">,

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,11 @@ def INT_NVVM_RCP_RP_FTZ_F : F_MATH_1<"rcp.rp.ftz.f32 \t$dst, $src0;",
10361036
def INT_NVVM_RCP_RP_F : F_MATH_1<"rcp.rp.f32 \t$dst, $src0;",
10371037
Float32Regs, Float32Regs, int_nvvm_rcp_rp_f>;
10381038

1039+
def INT_NVVM_RCP_APPROX_F : F_MATH_1<"rcp.approx.f32 \t$dst, $src0;",
1040+
Float32Regs, Float32Regs, int_nvvm_rcp_approx_f>;
1041+
def INT_NVVM_RCP_APPROX_FTZ_F : F_MATH_1<"rcp.approx.ftz.f32 \t$dst, $src0;",
1042+
Float32Regs, Float32Regs, int_nvvm_rcp_approx_ftz_f>;
1043+
10391044
def INT_NVVM_RCP_RN_D : F_MATH_1<"rcp.rn.f64 \t$dst, $src0;", Float64Regs,
10401045
Float64Regs, int_nvvm_rcp_rn_d>;
10411046
def INT_NVVM_RCP_RZ_D : F_MATH_1<"rcp.rz.f64 \t$dst, $src0;", Float64Regs,

sycl/include/CL/sycl/builtins.hpp

Lines changed: 95 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ namespace sycl {
2525
namespace __sycl_std = __host_std;
2626
#endif
2727

28+
#ifdef __FAST_MATH__
29+
#define __FAST_MATH_GENFLOAT(T) \
30+
(detail::is_genfloatd<T>::value || detail::is_genfloath<T>::value)
31+
#else
32+
#define __FAST_MATH_GENFLOAT(T) (detail::is_genfloat<T>::value)
33+
#endif
34+
2835
/* ----------------- 4.13.3 Math functions. ---------------------------------*/
2936
// genfloat acos (genfloat x)
3037
template <typename T>
@@ -114,7 +121,7 @@ detail::enable_if_t<detail::is_genfloat<T>::value, T> copysign(T x,
114121

115122
// genfloat cos (genfloat x)
116123
template <typename T>
117-
detail::enable_if_t<detail::is_genfloat<T>::value, T> cos(T x) __NOEXC {
124+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> cos(T x) __NOEXC {
118125
return __sycl_std::__invoke_cos<T>(x);
119126
}
120127

@@ -144,19 +151,19 @@ detail::enable_if_t<detail::is_genfloat<T>::value, T> erf(T x) __NOEXC {
144151

145152
// genfloat exp (genfloat x )
146153
template <typename T>
147-
detail::enable_if_t<detail::is_genfloat<T>::value, T> exp(T x) __NOEXC {
154+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> exp(T x) __NOEXC {
148155
return __sycl_std::__invoke_exp<T>(x);
149156
}
150157

151158
// genfloat exp2 (genfloat x)
152159
template <typename T>
153-
detail::enable_if_t<detail::is_genfloat<T>::value, T> exp2(T x) __NOEXC {
160+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> exp2(T x) __NOEXC {
154161
return __sycl_std::__invoke_exp2<T>(x);
155162
}
156163

157164
// genfloat exp10 (genfloat x)
158165
template <typename T>
159-
detail::enable_if_t<detail::is_genfloat<T>::value, T> exp10(T x) __NOEXC {
166+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> exp10(T x) __NOEXC {
160167
return __sycl_std::__invoke_exp10<T>(x);
161168
}
162169

@@ -296,19 +303,19 @@ lgamma_r(T x, T2 signp) __NOEXC {
296303

297304
// genfloat log (genfloat x)
298305
template <typename T>
299-
detail::enable_if_t<detail::is_genfloat<T>::value, T> log(T x) __NOEXC {
306+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> log(T x) __NOEXC {
300307
return __sycl_std::__invoke_log<T>(x);
301308
}
302309

303310
// genfloat log2 (genfloat x)
304311
template <typename T>
305-
detail::enable_if_t<detail::is_genfloat<T>::value, T> log2(T x) __NOEXC {
312+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> log2(T x) __NOEXC {
306313
return __sycl_std::__invoke_log2<T>(x);
307314
}
308315

309316
// genfloat log10 (genfloat x)
310317
template <typename T>
311-
detail::enable_if_t<detail::is_genfloat<T>::value, T> log10(T x) __NOEXC {
318+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> log10(T x) __NOEXC {
312319
return __sycl_std::__invoke_log10<T>(x);
313320
}
314321

@@ -383,7 +390,7 @@ pown(T x, T2 y) __NOEXC {
383390

384391
// genfloat powr (genfloat x, genfloat y)
385392
template <typename T>
386-
detail::enable_if_t<detail::is_genfloat<T>::value, T> powr(T x, T y) __NOEXC {
393+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> powr(T x, T y) __NOEXC {
387394
return __sycl_std::__invoke_powr<T>(x, y);
388395
}
389396

@@ -426,13 +433,13 @@ detail::enable_if_t<detail::is_genfloat<T>::value, T> round(T x) __NOEXC {
426433

427434
// genfloat rsqrt (genfloat x)
428435
template <typename T>
429-
detail::enable_if_t<detail::is_genfloat<T>::value, T> rsqrt(T x) __NOEXC {
436+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> rsqrt(T x) __NOEXC {
430437
return __sycl_std::__invoke_rsqrt<T>(x);
431438
}
432439

433440
// genfloat sin (genfloat x)
434441
template <typename T>
435-
detail::enable_if_t<detail::is_genfloat<T>::value, T> sin(T x) __NOEXC {
442+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> sin(T x) __NOEXC {
436443
return __sycl_std::__invoke_sin<T>(x);
437444
}
438445

@@ -459,13 +466,13 @@ detail::enable_if_t<detail::is_genfloat<T>::value, T> sinpi(T x) __NOEXC {
459466

460467
// genfloat sqrt (genfloat x)
461468
template <typename T>
462-
detail::enable_if_t<detail::is_genfloat<T>::value, T> sqrt(T x) __NOEXC {
469+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> sqrt(T x) __NOEXC {
463470
return __sycl_std::__invoke_sqrt<T>(x);
464471
}
465472

466473
// genfloat tan (genfloat x)
467474
template <typename T>
468-
detail::enable_if_t<detail::is_genfloat<T>::value, T> tan(T x) __NOEXC {
475+
detail::enable_if_t<__FAST_MATH_GENFLOAT(T), T> tan(T x) __NOEXC {
469476
return __sycl_std::__invoke_tan<T>(x);
470477
}
471478

@@ -1561,6 +1568,82 @@ detail::enable_if_t<detail::is_genfloatf<T>::value, T> tan(T x) __NOEXC {
15611568
}
15621569

15631570
} // namespace half_precision
1571+
1572+
#ifdef __FAST_MATH__
1573+
/* ----------------- -ffast-math functions. ---------------------------------*/
1574+
// genfloatf cos (genfloatf x)
1575+
template <typename T>
1576+
detail::enable_if_t<detail::is_genfloat<T>::value, T> cos(T x) __NOEXC {
1577+
return native::cos(x);
1578+
}
1579+
1580+
// genfloatf exp (genfloatf x)
1581+
template <typename T>
1582+
detail::enable_if_t<detail::is_genfloat<T>::value, T> exp(T x) __NOEXC {
1583+
return native::exp(x);
1584+
}
1585+
1586+
// genfloatf exp2 (genfloatf x)
1587+
template <typename T>
1588+
detail::enable_if_t<detail::is_genfloat<T>::value, T> exp2(T x) __NOEXC {
1589+
return native::exp2(x);
1590+
}
1591+
1592+
// genfloatf exp10 (genfloatf x)
1593+
template <typename T>
1594+
detail::enable_if_t<detail::is_genfloat<T>::value, T> exp10(T x) __NOEXC {
1595+
return native::exp10(x);
1596+
}
1597+
1598+
// genfloatf log(genfloatf x)
1599+
template <typename T>
1600+
detail::enable_if_t<detail::is_genfloat<T>::value, T> log(T x) __NOEXC {
1601+
return native::log(x);
1602+
}
1603+
1604+
// genfloatf log2 (genfloatf x)
1605+
template <typename T>
1606+
detail::enable_if_t<detail::is_genfloat<T>::value, T> log2(T x) __NOEXC {
1607+
return native::log2(x);
1608+
}
1609+
1610+
// genfloatf log10 (genfloatf x)
1611+
template <typename T>
1612+
detail::enable_if_t<detail::is_genfloat<T>::value, T> log10(T x) __NOEXC {
1613+
return native::log10(x);
1614+
}
1615+
1616+
// genfloatf powr (genfloatf x)
1617+
template <typename T>
1618+
detail::enable_if_t<detail::is_genfloat<T>::value, T> powr(T x, T y) __NOEXC {
1619+
return native::powr(x, y);
1620+
}
1621+
1622+
// genfloatf rsqrt (genfloatf x)
1623+
template <typename T>
1624+
detail::enable_if_t<detail::is_genfloat<T>::value, T> rsqrt(T x) __NOEXC {
1625+
return native::rsqrt(x);
1626+
}
1627+
1628+
// genfloatf sin (genfloatf x)
1629+
template <typename T>
1630+
detail::enable_if_t<detail::is_genfloat<T>::value, T> sin(T x) __NOEXC {
1631+
return native::sin(x);
1632+
}
1633+
1634+
// genfloatf sqrt (genfloatf x)
1635+
template <typename T>
1636+
detail::enable_if_t<detail::is_genfloat<T>::value, T> sqrt(T x) __NOEXC {
1637+
return native::sqrt(x);
1638+
}
1639+
1640+
// genfloatf tan (genfloatf x)
1641+
template <typename T>
1642+
detail::enable_if_t<detail::is_genfloat<T>::value, T> tan(T x) __NOEXC {
1643+
return native::tan(x);
1644+
}
1645+
1646+
#endif // __FAST_MATH__
15641647
} // namespace sycl
15651648
} // __SYCL_INLINE_NAMESPACE(cl)
15661649

0 commit comments

Comments
 (0)