Skip to content

Commit 23b81da

Browse files
authored
[hlsl-out] polyfill float remainder operator (#7750)
1 parent 28af245 commit 23b81da

File tree

3 files changed

+26
-21
lines changed

3 files changed

+26
-21
lines changed

naga/src/back/hlsl/help.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,7 @@ impl<W: Write> super::Writer<'_, W> {
13181318
crate::BinaryOperator::Modulo,
13191319
Some(
13201320
scalar @ crate::Scalar {
1321-
kind: ScalarKind::Sint | ScalarKind::Uint,
1321+
kind: ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float,
13221322
..
13231323
},
13241324
),
@@ -1367,6 +1367,14 @@ impl<W: Write> super::Writer<'_, W> {
13671367
ScalarKind::Uint => {
13681368
writeln!(self.out, "{level}return lhs % (rhs == 0u ? 1u : rhs);")?
13691369
}
1370+
// HLSL's fmod has the same definition as WGSL's % operator but due
1371+
// to its implementation in DXC it is not as accurate as the WGSL spec
1372+
// requires it to be. See:
1373+
// - https://shader-playground.timjones.io/0c8572816dbb6fc4435cc5d016a978a7
1374+
// - https://github.com/llvm/llvm-project/blob/50f9b8acafdca48e87e6b8e393c1f116a2d193ee/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h#L78-L81
1375+
ScalarKind::Float => {
1376+
writeln!(self.out, "{level}return lhs - rhs * trunc(lhs / rhs);")?
1377+
}
13701378
_ => unreachable!(),
13711379
}
13721380
writeln!(self.out, "}}")?;

naga/src/back/hlsl/writer.rs

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2924,7 +2924,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
29242924
right,
29252925
} if matches!(
29262926
func_ctx.resolve_type(expr, &module.types).scalar_kind(),
2927-
Some(ScalarKind::Sint | ScalarKind::Uint)
2927+
Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float)
29282928
) =>
29292929
{
29302930
write!(self.out, "{MOD_FUNCTION}(")?;
@@ -2934,21 +2934,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
29342934
write!(self.out, ")")?;
29352935
}
29362936

2937-
// While HLSL supports float operands with the % operator it is only
2938-
// defined in cases where both sides are either positive or negative.
2939-
Expression::Binary {
2940-
op: crate::BinaryOperator::Modulo,
2941-
left,
2942-
right,
2943-
} if func_ctx.resolve_type(left, &module.types).scalar_kind()
2944-
== Some(ScalarKind::Float) =>
2945-
{
2946-
write!(self.out, "fmod(")?;
2947-
self.write_expr(module, left, func_ctx)?;
2948-
write!(self.out, ", ")?;
2949-
self.write_expr(module, right, func_ctx)?;
2950-
write!(self.out, ")")?;
2951-
}
29522937
Expression::Binary { op, left, right } => {
29532938
write!(self.out, "(")?;
29542939
self.write_expr(module, left, func_ctx)?;

naga/tests/out/hlsl/wgsl-operators.hlsl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ uint naga_mod(uint lhs, uint rhs) {
9090
return lhs % (rhs == 0u ? 1u : rhs);
9191
}
9292

93+
float naga_mod(float lhs, float rhs) {
94+
return lhs - rhs * trunc(lhs / rhs);
95+
}
96+
9397
int2 naga_mod(int2 lhs, int2 rhs) {
9498
int2 divisor = ((lhs == int(-2147483647 - 1) & rhs == -1) | (rhs == 0)) ? 1 : rhs;
9599
return lhs - (lhs / divisor) * divisor;
@@ -99,6 +103,10 @@ uint3 naga_mod(uint3 lhs, uint3 rhs) {
99103
return lhs % (rhs == 0u ? 1u : rhs);
100104
}
101105

106+
float4 naga_mod(float4 lhs, float4 rhs) {
107+
return lhs - rhs * trunc(lhs / rhs);
108+
}
109+
102110
uint2 naga_div(uint2 lhs, uint2 rhs) {
103111
return lhs / (rhs == 0u ? 1u : rhs);
104112
}
@@ -107,6 +115,10 @@ uint2 naga_mod(uint2 lhs, uint2 rhs) {
107115
return lhs % (rhs == 0u ? 1u : rhs);
108116
}
109117

118+
float2 naga_mod(float2 lhs, float2 rhs) {
119+
return lhs - rhs * trunc(lhs / rhs);
120+
}
121+
110122
float3x3 ZeroValuefloat3x3() {
111123
return (float3x3)0;
112124
}
@@ -153,10 +165,10 @@ void arithmetic()
153165
float4 div5_ = ((2.0).xxxx / (1.0).xxxx);
154166
int rem0_ = naga_mod(int(2), int(1));
155167
uint rem1_ = naga_mod(2u, 1u);
156-
float rem2_ = fmod(2.0, 1.0);
168+
float rem2_ = naga_mod(2.0, 1.0);
157169
int2 rem3_ = naga_mod((int(2)).xx, (int(1)).xx);
158170
uint3 rem4_ = naga_mod((2u).xxx, (1u).xxx);
159-
float4 rem5_ = fmod((2.0).xxxx, (1.0).xxxx);
171+
float4 rem5_ = naga_mod((2.0).xxxx, (1.0).xxxx);
160172
{
161173
int2 add0_1 = asint(asuint((int(2)).xx) + asuint((int(1)).xx));
162174
int2 add1_1 = asint(asuint((int(2)).xx) + asuint((int(1)).xx));
@@ -186,8 +198,8 @@ void arithmetic()
186198
int2 rem1_1 = naga_mod((int(2)).xx, (int(1)).xx);
187199
uint2 rem2_1 = naga_mod((2u).xx, (1u).xx);
188200
uint2 rem3_1 = naga_mod((2u).xx, (1u).xx);
189-
float2 rem4_1 = fmod((2.0).xx, (1.0).xx);
190-
float2 rem5_1 = fmod((2.0).xx, (1.0).xx);
201+
float2 rem4_1 = naga_mod((2.0).xx, (1.0).xx);
202+
float2 rem5_1 = naga_mod((2.0).xx, (1.0).xx);
191203
}
192204
float3x3 add = (ZeroValuefloat3x3() + ZeroValuefloat3x3());
193205
float3x3 sub = (ZeroValuefloat3x3() - ZeroValuefloat3x3());

0 commit comments

Comments
 (0)