Skip to content

Commit 894d036

Browse files
[naga spv-out msl-out hlsl-out] Make infinite loop workaround count down instead of up (#7372)
[naga spv-out msl-out hlsl-out] Make infinite loop workaround count down instead of up To avoid generating code containing infinite loops, and therefore incurring the wrath of undefined behaviour, we insert a counter into each loop that will break after 2^64 iterations. This was previously implemented as two u32 variables counting up from zero. We have been informed that this construct can cause certain Intel drivers to hang. Instead, we must count down from u32::MAX. Counting down is more fun, anyway. Co-authored-by: Erich Gubler <erichdongubler@gmail.com>
1 parent f89ede7 commit 894d036

33 files changed

+255
-250
lines changed

naga/src/back/hlsl/writer.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
170170
}
171171

172172
let loop_bound_name = self.namer.call("loop_bound");
173-
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u, 0u);");
174-
let level = level.next();
175173
let max = u32::MAX;
174+
// Count down from u32::MAX rather than up from 0 to avoid hang on
175+
// certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
176+
let decl = format!("{level}uint2 {loop_bound_name} = uint2({max}u, {max}u);");
177+
let level = level.next();
176178
let break_and_inc = format!(
177-
"{level}if (all({loop_bound_name} == uint2({max}u, {max}u))) {{ break; }}
178-
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
179+
"{level}if (all({loop_bound_name} == uint2(0u, 0u))) {{ break; }}
180+
{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
179181
);
180182

181183
Some((decl, break_and_inc))

naga/src/back/msl/writer.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -850,12 +850,13 @@ impl<W: Write> Writer<W> {
850850
}
851851

852852
let loop_bound_name = self.namer.call("loop_bound");
853-
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u);");
853+
// Count down from u32::MAX rather than up from 0 to avoid hang on
854+
// certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
855+
let decl = format!("{level}uint2 {loop_bound_name} = uint2({}u);", u32::MAX);
854856
let level = level.next();
855-
let max = u32::MAX;
856857
let break_and_inc = format!(
857-
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2({max}u))) {{ break; }}
858-
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
858+
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2(0u))) {{ break; }}
859+
{level}{loop_bound_name} -= uint2({loop_bound_name}.y == 0u, 1u);"
859860
);
860861

861862
Some((decl, break_and_inc))

naga/src/back/spv/block.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ impl BlockContext<'_> {
310310
uint2_ptr_type_id,
311311
loop_counter_var_id,
312312
spirv::StorageClass::Function,
313-
Some(zero_uint2_const_id),
313+
Some(max_uint2_const_id),
314314
),
315315
};
316316
self.function.force_loop_bounding_vars.push(var);
@@ -331,14 +331,14 @@ impl BlockContext<'_> {
331331
None,
332332
));
333333

334-
// If both the high and low u32s have reached u32::MAX then break. ie
335-
// if (all(eq(loop_counter, vec2(u32::MAX)))) { break; }
334+
// If both the high and low u32s have reached 0 then break. ie
335+
// if (all(eq(loop_counter, vec2(0)))) { break; }
336336
let eq_id = self.gen_id();
337337
block.body.push(Instruction::binary(
338338
spirv::Op::IEqual,
339339
bool2_type_id,
340340
eq_id,
341-
max_uint2_const_id,
341+
zero_uint2_const_id,
342342
load_id,
343343
));
344344
let all_eq_id = self.gen_id();
@@ -360,9 +360,11 @@ impl BlockContext<'_> {
360360
);
361361
block = Block::new(inc_counter_block_id);
362362

363-
// To simulate a 64-bit counter we always increment the low u32, and increment
363+
// To simulate a 64-bit counter we always decrement the low u32, and decrement
364364
// the high u32 when the low u32 overflows. ie
365-
// counter += vec2(select(0u, 1u, counter.y == u32::MAX), 1u);
365+
// counter -= vec2(select(0u, 1u, counter.y == 0), 1u);
366+
// Count down from u32::MAX rather than up from 0 to avoid hang on
367+
// certain Intel drivers. See <https://github.com/gfx-rs/wgpu/issues/7319>.
366368
let low_id = self.gen_id();
367369
block.body.push(Instruction::composite_extract(
368370
uint_type_id,
@@ -376,7 +378,7 @@ impl BlockContext<'_> {
376378
bool_type_id,
377379
low_overflow_id,
378380
low_id,
379-
max_uint_const_id,
381+
zero_uint_const_id,
380382
));
381383
let carry_bit_id = self.gen_id();
382384
block.body.push(Instruction::select(
@@ -386,19 +388,19 @@ impl BlockContext<'_> {
386388
one_uint_const_id,
387389
zero_uint_const_id,
388390
));
389-
let increment_id = self.gen_id();
391+
let decrement_id = self.gen_id();
390392
block.body.push(Instruction::composite_construct(
391393
uint2_type_id,
392-
increment_id,
394+
decrement_id,
393395
&[carry_bit_id, one_uint_const_id],
394396
));
395397
let result_id = self.gen_id();
396398
block.body.push(Instruction::binary(
397-
spirv::Op::IAdd,
399+
spirv::Op::ISub,
398400
uint2_type_id,
399401
result_id,
400402
load_id,
401-
increment_id,
403+
decrement_id,
402404
));
403405
block
404406
.body

naga/tests/out/hlsl/boids.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
4141
vPos = _e8;
4242
float2 _e14 = asfloat(particlesSrc.Load2(8+index*16+0));
4343
vVel = _e14;
44-
uint2 loop_bound = uint2(0u, 0u);
44+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
4545
bool loop_init = true;
4646
while(true) {
47-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
48-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
47+
if (all(loop_bound == uint2(0u, 0u))) { break; }
48+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
4949
if (!loop_init) {
5050
uint _e91 = i;
5151
i = (_e91 + 1u);

naga/tests/out/hlsl/break-if.hlsl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
void breakIfEmpty()
22
{
3-
uint2 loop_bound = uint2(0u, 0u);
3+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
44
bool loop_init = true;
55
while(true) {
6-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
7-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
6+
if (all(loop_bound == uint2(0u, 0u))) { break; }
7+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
88
if (!loop_init) {
99
if (true) {
1010
break;
@@ -20,11 +20,11 @@ void breakIfEmptyBody(bool a)
2020
bool b = (bool)0;
2121
bool c = (bool)0;
2222

23-
uint2 loop_bound_1 = uint2(0u, 0u);
23+
uint2 loop_bound_1 = uint2(4294967295u, 4294967295u);
2424
bool loop_init_1 = true;
2525
while(true) {
26-
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
27-
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
26+
if (all(loop_bound_1 == uint2(0u, 0u))) { break; }
27+
loop_bound_1 -= uint2(loop_bound_1.y == 0u, 1u);
2828
if (!loop_init_1) {
2929
b = a;
3030
bool _e2 = b;
@@ -44,11 +44,11 @@ void breakIf(bool a_1)
4444
bool d = (bool)0;
4545
bool e = (bool)0;
4646

47-
uint2 loop_bound_2 = uint2(0u, 0u);
47+
uint2 loop_bound_2 = uint2(4294967295u, 4294967295u);
4848
bool loop_init_2 = true;
4949
while(true) {
50-
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
51-
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
50+
if (all(loop_bound_2 == uint2(0u, 0u))) { break; }
51+
loop_bound_2 -= uint2(loop_bound_2.y == 0u, 1u);
5252
if (!loop_init_2) {
5353
bool _e5 = e;
5454
if ((a_1 == _e5)) {
@@ -67,11 +67,11 @@ void breakIfSeparateVariable()
6767
{
6868
uint counter = 0u;
6969

70-
uint2 loop_bound_3 = uint2(0u, 0u);
70+
uint2 loop_bound_3 = uint2(4294967295u, 4294967295u);
7171
bool loop_init_3 = true;
7272
while(true) {
73-
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
74-
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
73+
if (all(loop_bound_3 == uint2(0u, 0u))) { break; }
74+
loop_bound_3 -= uint2(loop_bound_3.y == 0u, 1u);
7575
if (!loop_init_3) {
7676
uint _e5 = counter;
7777
if ((_e5 == 5u)) {

naga/tests/out/hlsl/collatz.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ uint collatz_iterations(uint n_base)
1414
uint i = 0u;
1515

1616
n = n_base;
17-
uint2 loop_bound = uint2(0u, 0u);
17+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
1818
while(true) {
19-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
20-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
19+
if (all(loop_bound == uint2(0u, 0u))) { break; }
20+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
2121
uint _e4 = n;
2222
if ((_e4 > 1u)) {
2323
} else {

naga/tests/out/hlsl/control-flow.hlsl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ void switch_const_expr_case_selectors()
6464

6565
void loop_switch_continue(int x)
6666
{
67-
uint2 loop_bound = uint2(0u, 0u);
67+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
6868
while(true) {
69-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
70-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
69+
if (all(loop_bound == uint2(0u, 0u))) { break; }
70+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
7171
bool should_continue = false;
7272
switch(x) {
7373
case 1: {
@@ -87,10 +87,10 @@ void loop_switch_continue(int x)
8787

8888
void loop_switch_continue_nesting(int x_1, int y, int z)
8989
{
90-
uint2 loop_bound_1 = uint2(0u, 0u);
90+
uint2 loop_bound_1 = uint2(4294967295u, 4294967295u);
9191
while(true) {
92-
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
93-
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
92+
if (all(loop_bound_1 == uint2(0u, 0u))) { break; }
93+
loop_bound_1 -= uint2(loop_bound_1.y == 0u, 1u);
9494
bool should_continue_1 = false;
9595
switch(x_1) {
9696
case 1: {
@@ -104,10 +104,10 @@ void loop_switch_continue_nesting(int x_1, int y, int z)
104104
break;
105105
}
106106
default: {
107-
uint2 loop_bound_2 = uint2(0u, 0u);
107+
uint2 loop_bound_2 = uint2(4294967295u, 4294967295u);
108108
while(true) {
109-
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
110-
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
109+
if (all(loop_bound_2 == uint2(0u, 0u))) { break; }
110+
loop_bound_2 -= uint2(loop_bound_2.y == 0u, 1u);
111111
bool should_continue_2 = false;
112112
switch(z) {
113113
case 1: {
@@ -146,10 +146,10 @@ void loop_switch_continue_nesting(int x_1, int y, int z)
146146
continue;
147147
}
148148
}
149-
uint2 loop_bound_3 = uint2(0u, 0u);
149+
uint2 loop_bound_3 = uint2(4294967295u, 4294967295u);
150150
while(true) {
151-
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
152-
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
151+
if (all(loop_bound_3 == uint2(0u, 0u))) { break; }
152+
loop_bound_3 -= uint2(loop_bound_3.y == 0u, 1u);
153153
bool should_continue_4 = false;
154154
do {
155155
do {
@@ -171,10 +171,10 @@ void loop_switch_omit_continue_variable_checks(int x_2, int y_1, int z_1, int w)
171171
{
172172
int pos_1 = int(0);
173173

174-
uint2 loop_bound_4 = uint2(0u, 0u);
174+
uint2 loop_bound_4 = uint2(4294967295u, 4294967295u);
175175
while(true) {
176-
if (all(loop_bound_4 == uint2(4294967295u, 4294967295u))) { break; }
177-
loop_bound_4 += uint2(loop_bound_4.y == 4294967295u, 1u);
176+
if (all(loop_bound_4 == uint2(0u, 0u))) { break; }
177+
loop_bound_4 -= uint2(loop_bound_4.y == 0u, 1u);
178178
bool should_continue_5 = false;
179179
switch(x_2) {
180180
case 1: {
@@ -186,10 +186,10 @@ void loop_switch_omit_continue_variable_checks(int x_2, int y_1, int z_1, int w)
186186
}
187187
}
188188
}
189-
uint2 loop_bound_5 = uint2(0u, 0u);
189+
uint2 loop_bound_5 = uint2(4294967295u, 4294967295u);
190190
while(true) {
191-
if (all(loop_bound_5 == uint2(4294967295u, 4294967295u))) { break; }
192-
loop_bound_5 += uint2(loop_bound_5.y == 4294967295u, 1u);
191+
if (all(loop_bound_5 == uint2(0u, 0u))) { break; }
192+
loop_bound_5 -= uint2(loop_bound_5.y == 0u, 1u);
193193
bool should_continue_6 = false;
194194
switch(x_2) {
195195
case 1: {

naga/tests/out/hlsl/do-while.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
void fb1_(inout bool cond)
22
{
3-
uint2 loop_bound = uint2(0u, 0u);
3+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
44
bool loop_init = true;
55
while(true) {
6-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
7-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
6+
if (all(loop_bound == uint2(0u, 0u))) { break; }
7+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
88
if (!loop_init) {
99
bool _e1 = cond;
1010
if (!(_e1)) {

naga/tests/out/hlsl/ray-query.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructu
8484
RayQuery<RAY_FLAG_NONE> rq_1;
8585

8686
rq_1.TraceRayInline(acs, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir)));
87-
uint2 loop_bound = uint2(0u, 0u);
87+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
8888
while(true) {
89-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
90-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
89+
if (all(loop_bound == uint2(0u, 0u))) { break; }
90+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
9191
const bool _e9 = rq_1.Proceed();
9292
if (_e9) {
9393
} else {

naga/tests/out/hlsl/shadow.hlsl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ float4 fs_main(FragmentInput_fs_main fragmentinput_fs_main) : SV_Target0
9595
uint i = 0u;
9696

9797
float3 normal_1 = normalize(in_.world_normal);
98-
uint2 loop_bound = uint2(0u, 0u);
98+
uint2 loop_bound = uint2(4294967295u, 4294967295u);
9999
bool loop_init = true;
100100
while(true) {
101-
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
102-
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
101+
if (all(loop_bound == uint2(0u, 0u))) { break; }
102+
loop_bound -= uint2(loop_bound.y == 0u, 1u);
103103
if (!loop_init) {
104104
uint _e40 = i;
105105
i = (_e40 + 1u);
@@ -134,11 +134,11 @@ float4 fs_main_without_storage(FragmentInput_fs_main_without_storage fragmentinp
134134
uint i_1 = 0u;
135135

136136
float3 normal_2 = normalize(in_1.world_normal);
137-
uint2 loop_bound_1 = uint2(0u, 0u);
137+
uint2 loop_bound_1 = uint2(4294967295u, 4294967295u);
138138
bool loop_init_1 = true;
139139
while(true) {
140-
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
141-
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
140+
if (all(loop_bound_1 == uint2(0u, 0u))) { break; }
141+
loop_bound_1 -= uint2(loop_bound_1.y == 0u, 1u);
142142
if (!loop_init_1) {
143143
uint _e40 = i_1;
144144
i_1 = (_e40 + 1u);

0 commit comments

Comments
 (0)