Skip to content

Commit 198762e

Browse files
feat: Add 32-bit floating-point atomics (SHADER_FLOAT32_ATOMIC) (#6234)
* feat: Add 32-bit floating-point atomics * Current supported platforms: Metal * Platforms to support in the future: Vulkan Related issues or PRs: * #1020 * Add changelog * Edit changelog * feat: Add 32-bit float atomics support for Vulkan (SPIR-V shaders) * atomicSub for f32 in the previous commits is removed. * Update test * chore: doc type link * refactor: Revise float atomics on msl and spv * Make branches tidy * Also revise old codes * Ensure the implementations are supported by Metal and Vulkan backends * refactor: Renaming flt32 atomics to float32 atomics * chore: Add link to Vulkan feature * fix: cargo fmt * chore: hack comment * Revert changelog * Fix: Cargo advisory * Update wgpu-hal/src/metal/adapter.rs Co-authored-by: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> * Update naga/src/lib.rs Co-authored-by: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> * Adjust feature flag position --------- Co-authored-by: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com>
1 parent 6e2394b commit 198762e

File tree

20 files changed

+633
-148
lines changed

20 files changed

+633
-148
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@ By @wumpf in [#6849](https://github.com/gfx-rs/wgpu/pull/6849).
163163
- Allow for statically linking DXC rather than including separate `.dll` files. By @DouglasDwyer in [#6574](https://github.com/gfx-rs/wgpu/pull/6574).
164164
- `DeviceType` and `AdapterInfo` now impl `Hash` by @cwfitzgerald in [#6868](https://github.com/gfx-rs/wgpu/pull/6868)
165165

166+
##### Vulkan
167+
168+
- Allow using some 32-bit floating-point atomic operations (load, store, add, sub, exchange) in shaders. It requires the extension `VK_EXT_shader_atomic_float`. By @AsherJingkongChen in [#6234](https://github.com/gfx-rs/wgpu/pull/6234).
169+
170+
##### Metal
171+
172+
- Allow using some 32-bit floating-point atomic operations (load, store, add, sub, exchange) in shaders. It requires Metal 3.0+ with Apple 7, 8, 9 or Mac 2. By @AsherJingkongChen in [#6234](https://github.com/gfx-rs/wgpu/pull/6234).
173+
166174
#### Changes
167175

168176
##### Naga

naga/src/back/spv/block.rs

Lines changed: 120 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2730,62 +2730,115 @@ impl BlockContext<'_> {
27302730
let value_id = self.cached[value];
27312731
let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
27322732

2733+
let crate::TypeInner::Scalar(scalar) = *value_inner else {
2734+
return Err(Error::FeatureNotImplemented(
2735+
"Atomics with non-scalar values",
2736+
));
2737+
};
2738+
27332739
let instruction = match *fun {
2734-
crate::AtomicFunction::Add => Instruction::atomic_binary(
2735-
spirv::Op::AtomicIAdd,
2736-
result_type_id,
2737-
id,
2738-
pointer_id,
2739-
scope_constant_id,
2740-
semantics_id,
2741-
value_id,
2742-
),
2743-
crate::AtomicFunction::Subtract => Instruction::atomic_binary(
2744-
spirv::Op::AtomicISub,
2745-
result_type_id,
2746-
id,
2747-
pointer_id,
2748-
scope_constant_id,
2749-
semantics_id,
2750-
value_id,
2751-
),
2752-
crate::AtomicFunction::And => Instruction::atomic_binary(
2753-
spirv::Op::AtomicAnd,
2754-
result_type_id,
2755-
id,
2756-
pointer_id,
2757-
scope_constant_id,
2758-
semantics_id,
2759-
value_id,
2760-
),
2761-
crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
2762-
spirv::Op::AtomicOr,
2763-
result_type_id,
2764-
id,
2765-
pointer_id,
2766-
scope_constant_id,
2767-
semantics_id,
2768-
value_id,
2769-
),
2770-
crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
2771-
spirv::Op::AtomicXor,
2772-
result_type_id,
2773-
id,
2774-
pointer_id,
2775-
scope_constant_id,
2776-
semantics_id,
2777-
value_id,
2778-
),
2740+
crate::AtomicFunction::Add => {
2741+
let spirv_op = match scalar.kind {
2742+
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2743+
spirv::Op::AtomicIAdd
2744+
}
2745+
crate::ScalarKind::Float => spirv::Op::AtomicFAddEXT,
2746+
_ => unimplemented!(),
2747+
};
2748+
Instruction::atomic_binary(
2749+
spirv_op,
2750+
result_type_id,
2751+
id,
2752+
pointer_id,
2753+
scope_constant_id,
2754+
semantics_id,
2755+
value_id,
2756+
)
2757+
}
2758+
crate::AtomicFunction::Subtract => {
2759+
let (spirv_op, value_id) = match scalar.kind {
2760+
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2761+
(spirv::Op::AtomicISub, value_id)
2762+
}
2763+
crate::ScalarKind::Float => {
2764+
// HACK: SPIR-V doesn't have a atomic subtraction,
2765+
// so we add the negated value instead.
2766+
let neg_result_id = self.gen_id();
2767+
block.body.push(Instruction::unary(
2768+
spirv::Op::FNegate,
2769+
result_type_id,
2770+
neg_result_id,
2771+
value_id,
2772+
));
2773+
(spirv::Op::AtomicFAddEXT, neg_result_id)
2774+
}
2775+
_ => unimplemented!(),
2776+
};
2777+
Instruction::atomic_binary(
2778+
spirv_op,
2779+
result_type_id,
2780+
id,
2781+
pointer_id,
2782+
scope_constant_id,
2783+
semantics_id,
2784+
value_id,
2785+
)
2786+
}
2787+
crate::AtomicFunction::And => {
2788+
let spirv_op = match scalar.kind {
2789+
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2790+
spirv::Op::AtomicAnd
2791+
}
2792+
_ => unimplemented!(),
2793+
};
2794+
Instruction::atomic_binary(
2795+
spirv_op,
2796+
result_type_id,
2797+
id,
2798+
pointer_id,
2799+
scope_constant_id,
2800+
semantics_id,
2801+
value_id,
2802+
)
2803+
}
2804+
crate::AtomicFunction::InclusiveOr => {
2805+
let spirv_op = match scalar.kind {
2806+
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2807+
spirv::Op::AtomicOr
2808+
}
2809+
_ => unimplemented!(),
2810+
};
2811+
Instruction::atomic_binary(
2812+
spirv_op,
2813+
result_type_id,
2814+
id,
2815+
pointer_id,
2816+
scope_constant_id,
2817+
semantics_id,
2818+
value_id,
2819+
)
2820+
}
2821+
crate::AtomicFunction::ExclusiveOr => {
2822+
let spirv_op = match scalar.kind {
2823+
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2824+
spirv::Op::AtomicXor
2825+
}
2826+
_ => unimplemented!(),
2827+
};
2828+
Instruction::atomic_binary(
2829+
spirv_op,
2830+
result_type_id,
2831+
id,
2832+
pointer_id,
2833+
scope_constant_id,
2834+
semantics_id,
2835+
value_id,
2836+
)
2837+
}
27792838
crate::AtomicFunction::Min => {
2780-
let spirv_op = match *value_inner {
2781-
crate::TypeInner::Scalar(crate::Scalar {
2782-
kind: crate::ScalarKind::Sint,
2783-
width: _,
2784-
}) => spirv::Op::AtomicSMin,
2785-
crate::TypeInner::Scalar(crate::Scalar {
2786-
kind: crate::ScalarKind::Uint,
2787-
width: _,
2788-
}) => spirv::Op::AtomicUMin,
2839+
let spirv_op = match scalar.kind {
2840+
crate::ScalarKind::Sint => spirv::Op::AtomicSMin,
2841+
crate::ScalarKind::Uint => spirv::Op::AtomicUMin,
27892842
_ => unimplemented!(),
27902843
};
27912844
Instruction::atomic_binary(
@@ -2799,15 +2852,9 @@ impl BlockContext<'_> {
27992852
)
28002853
}
28012854
crate::AtomicFunction::Max => {
2802-
let spirv_op = match *value_inner {
2803-
crate::TypeInner::Scalar(crate::Scalar {
2804-
kind: crate::ScalarKind::Sint,
2805-
width: _,
2806-
}) => spirv::Op::AtomicSMax,
2807-
crate::TypeInner::Scalar(crate::Scalar {
2808-
kind: crate::ScalarKind::Uint,
2809-
width: _,
2810-
}) => spirv::Op::AtomicUMax,
2855+
let spirv_op = match scalar.kind {
2856+
crate::ScalarKind::Sint => spirv::Op::AtomicSMax,
2857+
crate::ScalarKind::Uint => spirv::Op::AtomicUMax,
28112858
_ => unimplemented!(),
28122859
};
28132860
Instruction::atomic_binary(
@@ -2832,20 +2879,21 @@ impl BlockContext<'_> {
28322879
)
28332880
}
28342881
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
2835-
let scalar_type_id = match *value_inner {
2836-
crate::TypeInner::Scalar(scalar) => {
2837-
self.get_type_id(LookupType::Local(LocalType::Numeric(
2838-
NumericType::Scalar(scalar),
2839-
)))
2840-
}
2841-
_ => unimplemented!(),
2842-
};
2882+
let scalar_type_id = self.get_type_id(LookupType::Local(
2883+
LocalType::Numeric(NumericType::Scalar(scalar)),
2884+
));
28432885
let bool_type_id = self.get_type_id(LookupType::Local(
28442886
LocalType::Numeric(NumericType::Scalar(crate::Scalar::BOOL)),
28452887
));
28462888

28472889
let cas_result_id = self.gen_id();
28482890
let equality_result_id = self.gen_id();
2891+
let equality_operator = match scalar.kind {
2892+
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
2893+
spirv::Op::IEqual
2894+
}
2895+
_ => unimplemented!(),
2896+
};
28492897
let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
28502898
cas_instr.set_type(scalar_type_id);
28512899
cas_instr.set_result(cas_result_id);
@@ -2857,7 +2905,7 @@ impl BlockContext<'_> {
28572905
cas_instr.add_operand(self.cached[cmp]);
28582906
block.body.push(cas_instr);
28592907
block.body.push(Instruction::binary(
2860-
spirv::Op::IEqual,
2908+
equality_operator,
28612909
bool_type_id,
28622910
equality_result_id,
28632911
cas_result_id,

naga/src/back/spv/writer.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,16 @@ impl Writer {
857857
crate::TypeInner::Atomic(crate::Scalar { width: 8, kind: _ }) => {
858858
self.require_any("64 bit integer atomics", &[spirv::Capability::Int64Atomics])?;
859859
}
860+
crate::TypeInner::Atomic(crate::Scalar {
861+
width: 4,
862+
kind: crate::ScalarKind::Float,
863+
}) => {
864+
self.require_any(
865+
"32 bit floating-point atomics",
866+
&[spirv::Capability::AtomicFloat32AddEXT],
867+
)?;
868+
self.use_extension("SPV_EXT_shader_atomic_float_add");
869+
}
860870
_ => {}
861871
}
862872
Ok(())

naga/src/front/spv/mod.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
6767
spirv::Capability::Int64,
6868
spirv::Capability::Int64Atomics,
6969
spirv::Capability::Float16,
70+
spirv::Capability::AtomicFloat32AddEXT,
7071
spirv::Capability::Float64,
7172
spirv::Capability::Geometry,
7273
spirv::Capability::MultiView,
@@ -78,6 +79,7 @@ pub const SUPPORTED_EXTENSIONS: &[&str] = &[
7879
"SPV_KHR_storage_buffer_storage_class",
7980
"SPV_KHR_vulkan_memory_model",
8081
"SPV_KHR_multiview",
82+
"SPV_EXT_shader_atomic_float_add",
8183
];
8284
pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"];
8385

@@ -4339,7 +4341,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
43394341
| Op::AtomicUMax
43404342
| Op::AtomicAnd
43414343
| Op::AtomicOr
4342-
| Op::AtomicXor => self.parse_atomic_expr_with_value(
4344+
| Op::AtomicXor
4345+
| Op::AtomicFAddEXT => self.parse_atomic_expr_with_value(
43434346
inst,
43444347
&mut emitter,
43454348
ctx,
@@ -4348,15 +4351,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
43484351
body_idx,
43494352
match inst.op {
43504353
Op::AtomicExchange => crate::AtomicFunction::Exchange { compare: None },
4351-
Op::AtomicIAdd => crate::AtomicFunction::Add,
4354+
Op::AtomicIAdd | Op::AtomicFAddEXT => crate::AtomicFunction::Add,
43524355
Op::AtomicISub => crate::AtomicFunction::Subtract,
43534356
Op::AtomicSMin => crate::AtomicFunction::Min,
43544357
Op::AtomicUMin => crate::AtomicFunction::Min,
43554358
Op::AtomicSMax => crate::AtomicFunction::Max,
43564359
Op::AtomicUMax => crate::AtomicFunction::Max,
43574360
Op::AtomicAnd => crate::AtomicFunction::And,
43584361
Op::AtomicOr => crate::AtomicFunction::InclusiveOr,
4359-
_ => crate::AtomicFunction::ExclusiveOr,
4362+
Op::AtomicXor => crate::AtomicFunction::ExclusiveOr,
4363+
_ => unreachable!(),
43604364
},
43614365
)?,
43624366

naga/src/lib.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1949,14 +1949,18 @@ pub enum Statement {
19491949
/// If [`SHADER_INT64_ATOMIC_MIN_MAX`] or [`SHADER_INT64_ATOMIC_ALL_OPS`] are
19501950
/// enabled, this may also be [`I64`] or [`U64`].
19511951
///
1952+
/// If [`SHADER_FLOAT32_ATOMIC`] is enabled, this may be [`F32`].
1953+
///
19521954
/// [`Pointer`]: TypeInner::Pointer
19531955
/// [`Atomic`]: TypeInner::Atomic
19541956
/// [`I32`]: Scalar::I32
19551957
/// [`U32`]: Scalar::U32
19561958
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
19571959
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
1960+
/// [`SHADER_FLOAT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLOAT32_ATOMIC
19581961
/// [`I64`]: Scalar::I64
19591962
/// [`U64`]: Scalar::U64
1963+
/// [`F32`]: Scalar::F32
19601964
pointer: Handle<Expression>,
19611965

19621966
/// Function to run on the atomic value.
@@ -1967,14 +1971,24 @@ pub enum Statement {
19671971
/// value here.
19681972
///
19691973
/// - The [`SHADER_INT64_ATOMIC_MIN_MAX`] capability allows
1970-
/// [`AtomicFunction::Min`] and [`AtomicFunction::Max`] here.
1974+
/// [`AtomicFunction::Min`] and [`AtomicFunction::Max`]
1975+
/// in the [`Storage`] address space here.
19711976
///
19721977
/// - If neither of those capabilities are present, then 64-bit scalar
19731978
/// atomics are not allowed.
19741979
///
1980+
/// If [`pointer`] refers to a 32-bit floating-point atomic value, then:
1981+
///
1982+
/// - The [`SHADER_FLOAT32_ATOMIC`] capability allows [`AtomicFunction::Add`],
1983+
/// [`AtomicFunction::Subtract`], and [`AtomicFunction::Exchange { compare: None }`]
1984+
/// in the [`Storage`] address space here.
1985+
///
1986+
/// [`AtomicFunction::Exchange { compare: None }`]: AtomicFunction::Exchange
19751987
/// [`pointer`]: Statement::Atomic::pointer
1988+
/// [`Storage`]: AddressSpace::Storage
19761989
/// [`SHADER_INT64_ATOMIC_MIN_MAX`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_MIN_MAX
19771990
/// [`SHADER_INT64_ATOMIC_ALL_OPS`]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS
1991+
/// [`SHADER_FLOAT32_ATOMIC`]: crate::valid::Capabilities::SHADER_FLOAT32_ATOMIC
19781992
fun: AtomicFunction,
19791993

19801994
/// Value to use in the function.

0 commit comments

Comments
 (0)