Skip to content

Commit bb46a7f

Browse files
authored
[naga hlsl-out, glsl-out] Support atomicCompareExchangeWeak (#7658)
1 parent 921c6ab commit bb46a7f

22 files changed

+1218
-513
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Bottom level categories:
5555

5656
- When emitting GLSL, Uniform and Storage Buffer memory layouts are now emitted even if no explicit binding is given. By @cloone8 in [#7579](https://github.com/gfx-rs/wgpu/pull/7579).
5757
- Add support for [quad operations](https://www.w3.org/TR/WGSL/#quad-builtin-functions) (requires `SUBGROUP` feature to be enabled). By @dzamkov and @valaphee in [#7683](https://github.com/gfx-rs/wgpu/pull/7683).
58+
- Add support for `atomicCompareExchangeWeak` in HLSL and GLSL backends. By @cryvosh in [#7658](https://github.com/gfx-rs/wgpu/pull/7658).
5859

5960
### Bug Fixes
6061

naga/src/back/glsl/mod.rs

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -788,26 +788,28 @@ impl<'a, W: Write> Writer<'a, W> {
788788
// you can't make a struct without adding all of its members first.
789789
for (handle, ty) in self.module.types.iter() {
790790
if let TypeInner::Struct { ref members, .. } = ty.inner {
791+
let struct_name = &self.names[&NameKey::Type(handle)];
792+
791793
// Structures ending with runtime-sized arrays can only be
792794
// rendered as shader storage blocks in GLSL, not stand-alone
793795
// struct types.
794796
if !self.module.types[members.last().unwrap().ty]
795797
.inner
796798
.is_dynamically_sized(&self.module.types)
797799
{
798-
let name = &self.names[&NameKey::Type(handle)];
799-
write!(self.out, "struct {name} ")?;
800+
write!(self.out, "struct {struct_name} ")?;
800801
self.write_struct_body(handle, members)?;
801802
writeln!(self.out, ";")?;
802803
}
803804
}
804805
}
805806

806-
// Write functions to create special types.
807+
// Write functions for special types.
807808
for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() {
808809
match type_key {
809810
&crate::PredeclaredType::ModfResult { size, scalar }
810811
| &crate::PredeclaredType::FrexpResult { size, scalar } => {
812+
let struct_name = &self.names[&NameKey::Type(*struct_ty)];
811813
let arg_type_name_owner;
812814
let arg_type_name = if let Some(size) = size {
813815
arg_type_name_owner = format!(
@@ -836,8 +838,6 @@ impl<'a, W: Write> Writer<'a, W> {
836838
(FREXP_FUNCTION, "frexp", other_type_name)
837839
};
838840

839-
let struct_name = &self.names[&NameKey::Type(*struct_ty)];
840-
841841
writeln!(self.out)?;
842842
if !self.options.version.supports_frexp_function()
843843
&& matches!(type_key, &crate::PredeclaredType::FrexpResult { .. })
@@ -861,7 +861,9 @@ impl<'a, W: Write> Writer<'a, W> {
861861
)?;
862862
}
863863
}
864-
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
864+
&crate::PredeclaredType::AtomicCompareExchangeWeakResult(_) => {
865+
// Handled by the general struct writing loop earlier.
866+
}
865867
}
866868
}
867869

@@ -1482,6 +1484,18 @@ impl<'a, W: Write> Writer<'a, W> {
14821484
}
14831485
}
14841486
}
1487+
1488+
for statement in func.body.iter() {
1489+
match *statement {
1490+
crate::Statement::Atomic {
1491+
fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
1492+
..
1493+
} => {
1494+
self.need_bake_expressions.insert(cmp);
1495+
}
1496+
_ => {}
1497+
}
1498+
}
14851499
}
14861500

14871501
/// Helper method used to get a name for a global
@@ -2573,33 +2587,50 @@ impl<'a, W: Write> Writer<'a, W> {
25732587
result,
25742588
} => {
25752589
write!(self.out, "{level}")?;
2576-
if let Some(result) = result {
2577-
let res_name = Baked(result).to_string();
2578-
let res_ty = ctx.resolve_type(result, &self.module.types);
2579-
self.write_value_type(res_ty)?;
2580-
write!(self.out, " {res_name} = ")?;
2581-
self.named_expressions.insert(result, res_name);
2582-
}
25832590

2584-
let fun_str = fun.to_glsl();
2585-
write!(self.out, "atomic{fun_str}(")?;
2586-
self.write_expr(pointer, ctx)?;
2587-
write!(self.out, ", ")?;
2588-
// handle the special cases
25892591
match *fun {
2590-
crate::AtomicFunction::Subtract => {
2591-
// we just wrote `InterlockedAdd`, so negate the argument
2592-
write!(self.out, "-")?;
2592+
crate::AtomicFunction::Exchange {
2593+
compare: Some(compare_expr),
2594+
} => {
2595+
let result_handle = result.expect("CompareExchange must have a result");
2596+
let res_name = Baked(result_handle).to_string();
2597+
self.write_type(ctx.info[result_handle].ty.handle().unwrap())?;
2598+
write!(self.out, " {res_name};")?;
2599+
write!(self.out, " {res_name}.old_value = atomicCompSwap(")?;
2600+
self.write_expr(pointer, ctx)?;
2601+
write!(self.out, ", ")?;
2602+
self.write_expr(compare_expr, ctx)?;
2603+
write!(self.out, ", ")?;
2604+
self.write_expr(value, ctx)?;
2605+
writeln!(self.out, ");")?;
2606+
2607+
write!(
2608+
self.out,
2609+
"{level}{res_name}.exchanged = ({res_name}.old_value == "
2610+
)?;
2611+
self.write_expr(compare_expr, ctx)?;
2612+
writeln!(self.out, ");")?;
2613+
self.named_expressions.insert(result_handle, res_name);
25932614
}
2594-
crate::AtomicFunction::Exchange { compare: Some(_) } => {
2595-
return Err(Error::Custom(
2596-
"atomic CompareExchange is not implemented".to_string(),
2597-
));
2615+
_ => {
2616+
if let Some(result) = result {
2617+
let res_name = Baked(result).to_string();
2618+
self.write_type(ctx.info[result].ty.handle().unwrap())?;
2619+
write!(self.out, " {res_name} = ")?;
2620+
self.named_expressions.insert(result, res_name);
2621+
}
2622+
let fun_str = fun.to_glsl();
2623+
write!(self.out, "atomic{fun_str}(")?;
2624+
self.write_expr(pointer, ctx)?;
2625+
write!(self.out, ", ")?;
2626+
if let crate::AtomicFunction::Subtract = *fun {
2627+
// Emulate `atomicSub` with `atomicAdd` by negating the value.
2628+
write!(self.out, "-")?;
2629+
}
2630+
self.write_expr(value, ctx)?;
2631+
writeln!(self.out, ");")?;
25982632
}
2599-
_ => {}
26002633
}
2601-
self.write_expr(value, ctx)?;
2602-
writeln!(self.out, ");")?;
26032634
}
26042635
// Stores a value into an image.
26052636
Statement::ImageAtomic {

naga/src/back/hlsl/conv.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ impl crate::AtomicFunction {
222222
Self::Min => "Min",
223223
Self::Max => "Max",
224224
Self::Exchange { compare: None } => "Exchange",
225-
Self::Exchange { .. } => "", //TODO
225+
Self::Exchange { .. } => "CompareExchange",
226226
}
227227
}
228228
}

naga/src/back/hlsl/writer.rs

Lines changed: 80 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
269269
} => {
270270
self.need_bake_expressions.insert(argument);
271271
}
272+
crate::Statement::Atomic {
273+
fun: crate::AtomicFunction::Exchange { compare: Some(cmp) },
274+
..
275+
} => {
276+
self.need_bake_expressions.insert(cmp);
277+
}
272278
_ => {}
273279
}
274280
}
@@ -2358,79 +2364,78 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
23582364
result,
23592365
} => {
23602366
write!(self.out, "{level}")?;
2361-
let res_name = match result {
2362-
None => None,
2363-
Some(result) => {
2364-
let name = Baked(result).to_string();
2365-
match func_ctx.info[result].ty {
2366-
proc::TypeResolution::Handle(handle) => {
2367-
self.write_type(module, handle)?
2368-
}
2369-
proc::TypeResolution::Value(ref value) => {
2370-
self.write_value_type(module, value)?
2371-
}
2372-
};
2373-
write!(self.out, " {name}; ")?;
2374-
Some((result, name))
2375-
}
2367+
let res_var_info = if let Some(res_handle) = result {
2368+
let name = Baked(res_handle).to_string();
2369+
match func_ctx.info[res_handle].ty {
2370+
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
2371+
proc::TypeResolution::Value(ref value) => {
2372+
self.write_value_type(module, value)?
2373+
}
2374+
};
2375+
write!(self.out, " {name}; ")?;
2376+
self.named_expressions.insert(res_handle, name.clone());
2377+
Some((res_handle, name))
2378+
} else {
2379+
None
23762380
};
2377-
2378-
// Validation ensures that `pointer` has a `Pointer` type.
23792381
let pointer_space = func_ctx
23802382
.resolve_type(pointer, &module.types)
23812383
.pointer_space()
23822384
.unwrap();
2383-
23842385
let fun_str = fun.to_hlsl_suffix();
2386+
let compare_expr = match *fun {
2387+
crate::AtomicFunction::Exchange { compare: Some(cmp) } => Some(cmp),
2388+
_ => None,
2389+
};
23852390
match pointer_space {
23862391
crate::AddressSpace::WorkGroup => {
23872392
write!(self.out, "Interlocked{fun_str}(")?;
23882393
self.write_expr(module, pointer, func_ctx)?;
2394+
self.emit_hlsl_atomic_tail(
2395+
module,
2396+
func_ctx,
2397+
fun,
2398+
compare_expr,
2399+
value,
2400+
&res_var_info,
2401+
)?;
23892402
}
23902403
crate::AddressSpace::Storage { .. } => {
23912404
let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
2392-
// The call to `self.write_storage_address` wants
2393-
// mutable access to all of `self`, so temporarily take
2394-
// ownership of our reusable access chain buffer.
2395-
let chain = mem::take(&mut self.temp_access_chain);
23962405
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
23972406
let width = match func_ctx.resolve_type(value, &module.types) {
23982407
&TypeInner::Scalar(Scalar { width: 8, .. }) => "64",
23992408
_ => "",
24002409
};
24012410
write!(self.out, "{var_name}.Interlocked{fun_str}{width}(")?;
2411+
let chain = mem::take(&mut self.temp_access_chain);
24022412
self.write_storage_address(module, &chain, func_ctx)?;
24032413
self.temp_access_chain = chain;
2414+
self.emit_hlsl_atomic_tail(
2415+
module,
2416+
func_ctx,
2417+
fun,
2418+
compare_expr,
2419+
value,
2420+
&res_var_info,
2421+
)?;
24042422
}
24052423
ref other => {
24062424
return Err(Error::Custom(format!(
24072425
"invalid address space {other:?} for atomic statement"
24082426
)))
24092427
}
24102428
}
2411-
write!(self.out, ", ")?;
2412-
// handle the special cases
2413-
match *fun {
2414-
crate::AtomicFunction::Subtract => {
2415-
// we just wrote `InterlockedAdd`, so negate the argument
2416-
write!(self.out, "-")?;
2417-
}
2418-
crate::AtomicFunction::Exchange { compare: Some(_) } => {
2419-
return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
2429+
if let Some(cmp) = compare_expr {
2430+
if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
2431+
write!(
2432+
self.out,
2433+
"{level}{res_name}.exchanged = ({res_name}.old_value == "
2434+
)?;
2435+
self.write_expr(module, cmp, func_ctx)?;
2436+
writeln!(self.out, ");")?;
24202437
}
2421-
_ => {}
2422-
}
2423-
self.write_expr(module, value, func_ctx)?;
2424-
2425-
// The `original_value` out parameter is optional for all the
2426-
// `Interlocked` functions we generate other than
2427-
// `InterlockedExchange`.
2428-
if let Some((result, name)) = res_name {
2429-
write!(self.out, ", {name}")?;
2430-
self.named_expressions.insert(result, name);
24312438
}
2432-
2433-
writeln!(self.out, ");")?;
24342439
}
24352440
Statement::ImageAtomic {
24362441
image,
@@ -4312,6 +4317,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
43124317
}
43134318
Ok(())
43144319
}
4320+
4321+
/// Writes the shared tail of an HLSL atomic call: the optional compare value, the value (negated for `Subtract`), the optional result out-argument, and the closing `);`.
4322+
fn emit_hlsl_atomic_tail(
4323+
&mut self,
4324+
module: &Module,
4325+
func_ctx: &back::FunctionCtx<'_>,
4326+
fun: &crate::AtomicFunction,
4327+
compare_expr: Option<Handle<crate::Expression>>,
4328+
value: Handle<crate::Expression>,
4329+
res_var_info: &Option<(Handle<crate::Expression>, String)>,
4330+
) -> BackendResult {
4331+
if let Some(cmp) = compare_expr {
4332+
write!(self.out, ", ")?;
4333+
self.write_expr(module, cmp, func_ctx)?;
4334+
}
4335+
write!(self.out, ", ")?;
4336+
if let crate::AtomicFunction::Subtract = *fun {
4337+
// we just wrote `InterlockedAdd`, so negate the argument
4338+
write!(self.out, "-")?;
4339+
}
4340+
self.write_expr(module, value, func_ctx)?;
4341+
if let Some(&(_res_handle, ref res_name)) = res_var_info.as_ref() {
4342+
write!(self.out, ", ")?;
4343+
if compare_expr.is_some() {
4344+
write!(self.out, "{res_name}.old_value")?;
4345+
} else {
4346+
write!(self.out, "{res_name}")?;
4347+
}
4348+
}
4349+
writeln!(self.out, ");")?;
4350+
Ok(())
4351+
}
43154352
}
43164353

43174354
pub(super) struct MatrixType {

naga/tests/in/wgsl/atomicCompareExchange-int64.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
god_mode = true
2-
targets = "SPIRV | WGSL"
2+
targets = "SPIRV | HLSL | WGSL"
33

44
[hlsl]
5+
shader_model = "V6_6"
56
fake_missing_bindings = true
67
push_constants_target = { register = 0, space = 0 }
78
restrict_indexing = true
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
targets = "SPIRV | METAL | WGSL"
1+
targets = "SPIRV | METAL | GLSL | HLSL | WGSL"

naga/tests/in/wgsl/atomicOps-int64.wgsl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,12 @@ fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {
129129
atomicExchange(&workgroup_struct.atomic_scalar, 1lu);
130130
atomicExchange(&workgroup_struct.atomic_arr[1], 1li);
131131

132-
// // TODO: https://github.com/gpuweb/gpuweb/issues/2021
133-
// atomicCompareExchangeWeak(&storage_atomic_scalar, 1lu);
134-
// atomicCompareExchangeWeak(&storage_atomic_arr[1], 1li);
135-
// atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1lu);
136-
// atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1li);
137-
// atomicCompareExchangeWeak(&workgroup_atomic_scalar, 1lu);
138-
// atomicCompareExchangeWeak(&workgroup_atomic_arr[1], 1li);
139-
// atomicCompareExchangeWeak(&workgroup_struct.atomic_scalar, 1lu);
140-
// atomicCompareExchangeWeak(&workgroup_struct.atomic_arr[1], 1li);
132+
let cas_res_0 = atomicCompareExchangeWeak(&storage_atomic_scalar, 1lu, 2lu);
133+
let cas_res_1 = atomicCompareExchangeWeak(&storage_atomic_arr[1], 1li, 2li);
134+
let cas_res_2 = atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1lu, 2lu);
135+
let cas_res_3 = atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1li, 2li);
136+
let cas_res_4 = atomicCompareExchangeWeak(&workgroup_atomic_scalar, 1lu, 2lu);
137+
let cas_res_5 = atomicCompareExchangeWeak(&workgroup_atomic_arr[1], 1li, 2li);
138+
let cas_res_6 = atomicCompareExchangeWeak(&workgroup_struct.atomic_scalar, 1lu, 2lu);
139+
let cas_res_7 = atomicCompareExchangeWeak(&workgroup_struct.atomic_arr[1], 1li, 2li);
141140
}

naga/tests/in/wgsl/atomicOps.wgsl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,12 @@ fn cs_main(@builtin(local_invocation_id) id: vec3<u32>) {
129129
atomicExchange(&workgroup_struct.atomic_scalar, 1u);
130130
atomicExchange(&workgroup_struct.atomic_arr[1], 1i);
131131

132-
// // TODO: https://github.com/gpuweb/gpuweb/issues/2021
133-
// atomicCompareExchangeWeak(&storage_atomic_scalar, 1u);
134-
// atomicCompareExchangeWeak(&storage_atomic_arr[1], 1i);
135-
// atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1u);
136-
// atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1i);
137-
// atomicCompareExchangeWeak(&workgroup_atomic_scalar, 1u);
138-
// atomicCompareExchangeWeak(&workgroup_atomic_arr[1], 1i);
139-
// atomicCompareExchangeWeak(&workgroup_struct.atomic_scalar, 1u);
140-
// atomicCompareExchangeWeak(&workgroup_struct.atomic_arr[1], 1i);
132+
let cas_res_0 = atomicCompareExchangeWeak(&storage_atomic_scalar, 1u, 2u);
133+
let cas_res_1 = atomicCompareExchangeWeak(&storage_atomic_arr[1], 1i, 2i);
134+
let cas_res_2 = atomicCompareExchangeWeak(&storage_struct.atomic_scalar, 1u, 2u);
135+
let cas_res_3 = atomicCompareExchangeWeak(&storage_struct.atomic_arr[1], 1i, 2i);
136+
let cas_res_4 = atomicCompareExchangeWeak(&workgroup_atomic_scalar, 1u, 2u);
137+
let cas_res_5 = atomicCompareExchangeWeak(&workgroup_atomic_arr[1], 1i, 2i);
138+
let cas_res_6 = atomicCompareExchangeWeak(&workgroup_struct.atomic_scalar, 1u, 2u);
139+
let cas_res_7 = atomicCompareExchangeWeak(&workgroup_struct.atomic_arr[1], 1i, 2i);
141140
}

0 commit comments

Comments
 (0)