Skip to content

Commit c3e08c8

Browse files
authored
[HLSL] Add new double overloads for math builtins (#132979)
Add double overloads which cast the double to a float and call the float builtin. Make these double overloads conditional on HLSL version 202x or earlier. Add tests. Closes #128228
1 parent 1cc07a0 commit c3e08c8

32 files changed

+762
-2
lines changed

clang/lib/Frontend/InitPreprocessor.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,10 @@ static void InitializeStandardPredefinedMacros(const TargetInfo &TI,
394394
// HLSL Version
395395
Builder.defineMacro("__HLSL_VERSION",
396396
Twine((unsigned)LangOpts.getHLSLVersion()));
397+
Builder.defineMacro("__HLSL_202x",
398+
Twine((unsigned)LangOptions::HLSLLangStd::HLSL_202x));
399+
Builder.defineMacro("__HLSL_202y",
400+
Twine((unsigned)LangOptions::HLSLLangStd::HLSL_202y));
397401

398402
if (LangOpts.NativeHalfType)
399403
Builder.defineMacro("__HLSL_ENABLE_16_BIT", "1");

clang/lib/Headers/hlsl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222

2323
// HLSL standard library function declarations/definitions.
2424
#include "hlsl/hlsl_alias_intrinsics.h"
25+
#if __HLSL_VERSION <= __HLSL_202x
2526
#include "hlsl/hlsl_compat_overloads.h"
27+
#endif
2628
#include "hlsl/hlsl_intrinsics.h"
2729

2830
#if defined(__clang__)

clang/lib/Headers/hlsl/hlsl_compat_overloads.h

Lines changed: 213 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,70 @@ namespace hlsl {
1616
// unsigned integer and floating point. Keeping this ordering consistent will
1717
// help keep this file manageable as it grows.
1818

19+
#define _DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(fn) \
20+
constexpr float fn(double V) { return fn((float)V); } \
21+
constexpr float2 fn(double2 V) { return fn((float2)V); } \
22+
constexpr float3 fn(double3 V) { return fn((float3)V); } \
23+
constexpr float4 fn(double4 V) { return fn((float4)V); }
24+
25+
#define _DXC_COMPAT_BINARY_DOUBLE_OVERLOADS(fn) \
26+
constexpr float fn(double V1, double V2) { \
27+
return fn((float)V1, (float)V2); \
28+
} \
29+
constexpr float2 fn(double2 V1, double2 V2) { \
30+
return fn((float2)V1, (float2)V2); \
31+
} \
32+
constexpr float3 fn(double3 V1, double3 V2) { \
33+
return fn((float3)V1, (float3)V2); \
34+
} \
35+
constexpr float4 fn(double4 V1, double4 V2) { \
36+
return fn((float4)V1, (float4)V2); \
37+
}
38+
39+
#define _DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(fn) \
40+
constexpr float fn(double V1, double V2, double V3) { \
41+
return fn((float)V1, (float)V2, (float)V3); \
42+
} \
43+
constexpr float2 fn(double2 V1, double2 V2, double2 V3) { \
44+
return fn((float2)V1, (float2)V2, (float2)V3); \
45+
} \
46+
constexpr float3 fn(double3 V1, double3 V2, double3 V3) { \
47+
return fn((float3)V1, (float3)V2, (float3)V3); \
48+
} \
49+
constexpr float4 fn(double4 V1, double4 V2, double4 V3) { \
50+
return fn((float4)V1, (float4)V2, (float4)V3); \
51+
}
52+
53+
//===----------------------------------------------------------------------===//
54+
// acos builtins overloads
55+
//===----------------------------------------------------------------------===//
56+
57+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(acos)
58+
59+
//===----------------------------------------------------------------------===//
60+
// asin builtins overloads
61+
//===----------------------------------------------------------------------===//
62+
63+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(asin)
64+
65+
//===----------------------------------------------------------------------===//
66+
// atan builtins overloads
67+
//===----------------------------------------------------------------------===//
68+
69+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(atan)
70+
71+
//===----------------------------------------------------------------------===//
72+
// atan2 builtins overloads
73+
//===----------------------------------------------------------------------===//
74+
75+
_DXC_COMPAT_BINARY_DOUBLE_OVERLOADS(atan2)
76+
77+
//===----------------------------------------------------------------------===//
78+
// ceil builtins overloads
79+
//===----------------------------------------------------------------------===//
80+
81+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(ceil)
82+
1983
//===----------------------------------------------------------------------===//
2084
// clamp builtins overloads
2185
//===----------------------------------------------------------------------===//
@@ -39,7 +103,82 @@ clamp(vector<T, N> p0, T p1, T p2) {
39103
}
40104

41105
//===----------------------------------------------------------------------===//
42-
// max builtin overloads
106+
// cos builtins overloads
107+
//===----------------------------------------------------------------------===//
108+
109+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(cos)
110+
111+
//===----------------------------------------------------------------------===//
112+
// cosh builtins overloads
113+
//===----------------------------------------------------------------------===//
114+
115+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(cosh)
116+
117+
//===----------------------------------------------------------------------===//
118+
// degrees builtins overloads
119+
//===----------------------------------------------------------------------===//
120+
121+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(degrees)
122+
123+
//===----------------------------------------------------------------------===//
124+
// exp builtins overloads
125+
//===----------------------------------------------------------------------===//
126+
127+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(exp)
128+
129+
//===----------------------------------------------------------------------===//
130+
// exp2 builtins overloads
131+
//===----------------------------------------------------------------------===//
132+
133+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(exp2)
134+
135+
//===----------------------------------------------------------------------===//
136+
// floor builtins overloads
137+
//===----------------------------------------------------------------------===//
138+
139+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(floor)
140+
141+
//===----------------------------------------------------------------------===//
142+
// frac builtins overloads
143+
//===----------------------------------------------------------------------===//
144+
145+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(frac)
146+
147+
//===----------------------------------------------------------------------===//
148+
// isinf builtins overloads
149+
//===----------------------------------------------------------------------===//
150+
151+
constexpr bool isinf(double V) { return isinf((float)V); }
152+
constexpr bool2 isinf(double2 V) { return isinf((float2)V); }
153+
constexpr bool3 isinf(double3 V) { return isinf((float3)V); }
154+
constexpr bool4 isinf(double4 V) { return isinf((float4)V); }
155+
156+
//===----------------------------------------------------------------------===//
157+
// lerp builtins overloads
158+
//===----------------------------------------------------------------------===//
159+
160+
_DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(lerp)
161+
162+
//===----------------------------------------------------------------------===//
163+
// log builtins overloads
164+
//===----------------------------------------------------------------------===//
165+
166+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(log)
167+
168+
//===----------------------------------------------------------------------===//
169+
// log10 builtins overloads
170+
//===----------------------------------------------------------------------===//
171+
172+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(log10)
173+
174+
//===----------------------------------------------------------------------===//
175+
// log2 builtins overloads
176+
//===----------------------------------------------------------------------===//
177+
178+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(log2)
179+
180+
//===----------------------------------------------------------------------===//
181+
// max builtins overloads
43182
//===----------------------------------------------------------------------===//
44183

45184
template <typename T, uint N>
@@ -55,7 +194,7 @@ max(T p0, vector<T, N> p1) {
55194
}
56195

57196
//===----------------------------------------------------------------------===//
58-
// min builtin overloads
197+
// min builtins overloads
59198
//===----------------------------------------------------------------------===//
60199

61200
template <typename T, uint N>
@@ -70,5 +209,77 @@ min(T p0, vector<T, N> p1) {
70209
return min((vector<T, N>)p0, p1);
71210
}
72211

212+
//===----------------------------------------------------------------------===//
213+
// normalize builtins overloads
214+
//===----------------------------------------------------------------------===//
215+
216+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(normalize)
217+
218+
//===----------------------------------------------------------------------===//
219+
// pow builtins overloads
220+
//===----------------------------------------------------------------------===//
221+
222+
_DXC_COMPAT_BINARY_DOUBLE_OVERLOADS(pow)
223+
224+
//===----------------------------------------------------------------------===//
225+
// rsqrt builtins overloads
226+
//===----------------------------------------------------------------------===//
227+
228+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(rsqrt)
229+
230+
//===----------------------------------------------------------------------===//
231+
// round builtins overloads
232+
//===----------------------------------------------------------------------===//
233+
234+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(round)
235+
236+
//===----------------------------------------------------------------------===//
237+
// sin builtins overloads
238+
//===----------------------------------------------------------------------===//
239+
240+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(sin)
241+
242+
//===----------------------------------------------------------------------===//
243+
// sinh builtins overloads
244+
//===----------------------------------------------------------------------===//
245+
246+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(sinh)
247+
248+
//===----------------------------------------------------------------------===//
249+
// sqrt builtins overloads
250+
//===----------------------------------------------------------------------===//
251+
252+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(sqrt)
253+
254+
//===----------------------------------------------------------------------===//
255+
// step builtins overloads
256+
//===----------------------------------------------------------------------===//
257+
258+
_DXC_COMPAT_BINARY_DOUBLE_OVERLOADS(step)
259+
260+
//===----------------------------------------------------------------------===//
261+
// tan builtins overloads
262+
//===----------------------------------------------------------------------===//
263+
264+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(tan)
265+
266+
//===----------------------------------------------------------------------===//
267+
// tanh builtins overloads
268+
//===----------------------------------------------------------------------===//
269+
270+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(tanh)
271+
272+
//===----------------------------------------------------------------------===//
273+
// trunc builtins overloads
274+
//===----------------------------------------------------------------------===//
275+
276+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(trunc)
277+
278+
//===----------------------------------------------------------------------===//
279+
// radians builtins overloads
280+
//===----------------------------------------------------------------------===//
281+
282+
_DXC_COMPAT_UNARY_DOUBLE_OVERLOADS(radians)
283+
73284
} // namespace hlsl
74285
#endif // _HLSL_COMPAT_OVERLOADS_H_

clang/test/CodeGenHLSL/builtins/acos.hlsl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,27 @@ float3 test_acos_float3 ( float3 p0 ) {
5757
float4 test_acos_float4 ( float4 p0 ) {
5858
return acos ( p0 );
5959
}
60+
61+
// CHECK-LABEL: test_acos_double
62+
// CHECK: call reassoc nnan ninf nsz arcp afn float @llvm.acos.f32
63+
float test_acos_double ( double p0 ) {
64+
return acos ( p0 );
65+
}
66+
67+
// CHECK-LABEL: test_acos_double2
68+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.acos.v2f32
69+
float2 test_acos_double2 ( double2 p0 ) {
70+
return acos ( p0 );
71+
}
72+
73+
// CHECK-LABEL: test_acos_double3
74+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.acos.v3f32
75+
float3 test_acos_double3 ( double3 p0 ) {
76+
return acos ( p0 );
77+
}
78+
79+
// CHECK-LABEL: test_acos_double4
80+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.acos.v4f32
81+
float4 test_acos_double4 ( double4 p0 ) {
82+
return acos ( p0 );
83+
}

clang/test/CodeGenHLSL/builtins/asin.hlsl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,27 @@ float3 test_asin_float3 ( float3 p0 ) {
5757
float4 test_asin_float4 ( float4 p0 ) {
5858
return asin ( p0 );
5959
}
60+
61+
// CHECK-LABEL: test_asin_double
62+
// CHECK: call reassoc nnan ninf nsz arcp afn float @llvm.asin.f32
63+
float test_asin_double ( double p0 ) {
64+
return asin ( p0 );
65+
}
66+
67+
// CHECK-LABEL: test_asin_double2
68+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.asin.v2f32
69+
float2 test_asin_double2 ( double2 p0 ) {
70+
return asin ( p0 );
71+
}
72+
73+
// CHECK-LABEL: test_asin_double3
74+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.asin.v3f32
75+
float3 test_asin_double3 ( double3 p0 ) {
76+
return asin ( p0 );
77+
}
78+
79+
// CHECK-LABEL: test_asin_double4
80+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.asin.v4f32
81+
float4 test_asin_double4 ( double4 p0 ) {
82+
return asin ( p0 );
83+
}

clang/test/CodeGenHLSL/builtins/atan.hlsl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,27 @@ float3 test_atan_float3 ( float3 p0 ) {
5757
float4 test_atan_float4 ( float4 p0 ) {
5858
return atan ( p0 );
5959
}
60+
61+
// CHECK-LABEL: test_atan_double
62+
// CHECK: call reassoc nnan ninf nsz arcp afn float @llvm.atan.f32
63+
float test_atan_double ( double p0 ) {
64+
return atan ( p0 );
65+
}
66+
67+
// CHECK-LABEL: test_atan_double2
68+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.atan.v2f32
69+
float2 test_atan_double2 ( double2 p0 ) {
70+
return atan ( p0 );
71+
}
72+
73+
// CHECK-LABEL: test_atan_double3
74+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.atan.v3f32
75+
float3 test_atan_double3 ( double3 p0 ) {
76+
return atan ( p0 );
77+
}
78+
79+
// CHECK-LABEL: test_atan_double4
80+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.atan.v4f32
81+
float4 test_atan_double4 ( double4 p0 ) {
82+
return atan ( p0 );
83+
}

clang/test/CodeGenHLSL/builtins/atan2.hlsl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,27 @@ float3 test_atan2_float3 (float3 p0, float3 p1) {
5757
float4 test_atan2_float4 (float4 p0, float4 p1) {
5858
return atan2(p0, p1);
5959
}
60+
61+
// CHECK-LABEL: test_atan2_double
62+
// CHECK: call reassoc nnan ninf nsz arcp afn float @llvm.atan2.f32
63+
float test_atan2_double (double p0, double p1) {
64+
return atan2(p0, p1);
65+
}
66+
67+
// CHECK-LABEL: test_atan2_double2
68+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.atan2.v2f32
69+
float2 test_atan2_double2 (double2 p0, double2 p1) {
70+
return atan2(p0, p1);
71+
}
72+
73+
// CHECK-LABEL: test_atan2_double3
74+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.atan2.v3f32
75+
float3 test_atan2_double3 (double3 p0, double3 p1) {
76+
return atan2(p0, p1);
77+
}
78+
79+
// CHECK-LABEL: test_atan2_double4
80+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.atan2.v4f32
81+
float4 test_atan2_double4 (double4 p0, double4 p1) {
82+
return atan2(p0, p1);
83+
}

clang/test/CodeGenHLSL/builtins/ceil.hlsl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,16 @@ float3 test_ceil_float3(float3 p0) { return ceil(p0); }
4040
// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x float> @_Z16test_ceil_float4
4141
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.ceil.v4f32(
4242
float4 test_ceil_float4(float4 p0) { return ceil(p0); }
43+
44+
// CHECK-LABEL: define noundef nofpclass(nan inf) float {{.*}}test_ceil_double
45+
// CHECK: call reassoc nnan ninf nsz arcp afn float @llvm.ceil.f32(
46+
float test_ceil_double(double p0) { return ceil(p0); }
47+
// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x float> {{.*}}test_ceil_double2
48+
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x float> @llvm.ceil.v2f32(
49+
float2 test_ceil_double2(double2 p0) { return ceil(p0); }
50+
// CHECK-LABEL: define noundef nofpclass(nan inf) <3 x float> {{.*}}test_ceil_double3
51+
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x float> @llvm.ceil.v3f32(
52+
float3 test_ceil_double3(double3 p0) { return ceil(p0); }
53+
// CHECK-LABEL: define noundef nofpclass(nan inf) <4 x float> {{.*}}test_ceil_double4
54+
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x float> @llvm.ceil.v4f32(
55+
float4 test_ceil_double4(double4 p0) { return ceil(p0); }

0 commit comments

Comments
 (0)