Skip to content

Commit fbf735a

Browse files
authored
[SYCL][Devicelib] Implement cmath rintf wrapper with __spirv_ocl_rint (#18857)
This PR is to support the use of std::rint in device code. Currently it is resolved to rintf symbol. With this PR, the rintf symbol is resolved by libdevice.
1 parent 900f0d6 commit fbf735a

File tree

3 files changed

+22
-8
lines changed

3 files changed

+22
-8
lines changed

libdevice/cmath_wrapper.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,17 @@ float nearbyintf(float x) { return __nv_nearbyintf(x); }
198198
extern "C" SYCL_EXTERNAL float __nv_rintf(float);
199199
DEVICE_EXTERN_C_INLINE
200200
float rintf(float x) { return __nv_rintf(x); }
201-
#endif // __NVPTX__
202-
203-
#ifdef __AMDGCN__
201+
#elif defined(__AMDGCN__)
204202
extern "C" SYCL_EXTERNAL float __ocml_nearbyint_f32(float);
205203
DEVICE_EXTERN_C_INLINE
206204
float nearbyintf(float x) { return __ocml_nearbyint_f32(x); }
207205

208206
extern "C" SYCL_EXTERNAL float __ocml_rint_f32(float);
209207
DEVICE_EXTERN_C_INLINE
210208
float rintf(float x) { return __ocml_rint_f32(x); }
211-
#endif // __AMDGCN__
209+
#else
210+
DEVICE_EXTERN_C_INLINE
211+
float rintf(float x) { return __spirv_ocl_rint(x); }
212+
#endif
212213

213214
#endif // __SPIR__ || __SPIRV__ || __NVPTX__ || __AMDGCN__

libdevice/cmath_wrapper_fp64.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,18 @@ double nearbyint(double x) { return __nv_nearbyint(x); }
188188
extern "C" SYCL_EXTERNAL double __nv_rint(double);
189189
DEVICE_EXTERN_C_INLINE
190190
double rint(double x) { return __nv_rint(x); }
191-
#endif // __NVPTX__
192-
193-
#ifdef __AMDGCN__
191+
#elif defined(__AMDGCN__)
194192
extern "C" SYCL_EXTERNAL double __ocml_nearbyint_f64(double);
195193
DEVICE_EXTERN_C_INLINE
196194
double nearbyint(double x) { return __ocml_nearbyint_f64(x); }
197195

198196
extern "C" SYCL_EXTERNAL double __ocml_rint_f64(double);
199197
DEVICE_EXTERN_C_INLINE
200198
double rint(double x) { return __ocml_rint_f64(x); }
201-
#endif // __AMDGCN__
199+
#else
200+
DEVICE_EXTERN_C_INLINE
201+
double rint(double x) { return __spirv_ocl_rint(x); }
202+
#endif
202203

203204
#if defined(_MSC_VER)
204205
#include <math.h>

libdevice/test/check_cmath.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
REQUIRES: libsycldevice
2+
3+
Check functions in fp32 libdevice spirv file.
4+
5+
RUN: llvm-spirv --spirv-target-env=SPV-IR -r %libsycldevice_spv_dir/libsycl-cmath.spv -o %t.bc
6+
RUN: llvm-dis %t.bc -o %t.ll
7+
RUN: FileCheck %s --input-file %t.ll
8+
9+
CHECK: target triple ={{.*}}spir64
10+
11+
CHECK-LABEL: define spir_func float @rintf(
12+
CHECK: call spir_func float @_Z16__spirv_ocl_rintf(

0 commit comments

Comments
 (0)