Skip to content

Commit 8ba5e86

Browse files
committed
subgroup: drop the non_uniform from all subgroup functions, matching glsl
1 parent 64c0e68 commit 8ba5e86

23 files changed

+166
-209
lines changed

crates/spirv-std/src/arch/subgroup.rs

Lines changed: 50 additions & 87 deletions
Large diffs are not rendered by default.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_ballot::subgroup_ballot
4+
5+
use spirv_std::spirv;
6+
7+
unsafe fn subgroup_ballot(predicate: bool) -> bool {
8+
let ballot = spirv_std::arch::subgroup_ballot(predicate);
9+
spirv_std::arch::subgroup_inverse_ballot(ballot)
10+
}
11+
12+
#[spirv(compute(threads(1, 1, 1)))]
13+
pub fn main() {
14+
unsafe {
15+
subgroup_ballot(true);
16+
}
17+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_ballot_bit_count::subgroup_ballot_bit_count
4+
5+
use spirv_std::arch::{GroupOperation, SubgroupMask};
6+
use spirv_std::spirv;
7+
8+
unsafe fn subgroup_ballot_bit_count(ballot: SubgroupMask) -> u32 {
9+
spirv_std::arch::subgroup_ballot_bit_count::<{ GroupOperation::Reduce as u32 }>(ballot)
10+
}
11+
12+
#[spirv(compute(threads(1, 1, 1)))]
13+
pub fn main() {
14+
unsafe {
15+
subgroup_ballot_bit_count(spirv_std::arch::subgroup_ballot(true));
16+
}
17+
}

tests/ui/arch/subgroup/subgroup_non_uniform_ballot_bit_count.stderr renamed to tests/ui/arch/subgroup/subgroup_ballot_bit_count.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
%1 = OpFunction %2 None %3
22
%4 = OpFunctionParameter %5
33
%6 = OpLabel
4-
OpLine %7 496 8
4+
OpLine %7 491 8
55
%8 = OpGroupNonUniformBallotBitCount %2 %9 Reduce %4
66
OpNoLine
77
OpReturnValue %8
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformBallot,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_broadcast_first::subgroup_broadcast_first
4+
5+
use glam::Vec3;
6+
use spirv_std::spirv;
7+
8+
unsafe fn subgroup_broadcast_first(vec: Vec3) -> Vec3 {
9+
spirv_std::arch::subgroup_broadcast_first::<Vec3>(vec)
10+
}
11+
12+
#[spirv(compute(threads(1, 1, 1)))]
13+
pub fn main() {
14+
unsafe {
15+
subgroup_broadcast_first(Vec3::new(1., 2., 3.));
16+
}
17+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+GroupNonUniform,+ext:SPV_KHR_vulkan_memory_model
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_elect::subgroup_elect
4+
5+
use spirv_std::spirv;
6+
7+
unsafe fn subgroup_elect() -> bool {
8+
spirv_std::arch::subgroup_elect()
9+
}
10+
11+
#[spirv(compute(threads(1, 1, 1)))]
12+
pub fn main() {
13+
unsafe {
14+
subgroup_elect();
15+
}
16+
}
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
// build-pass
22
// compile-flags: -C target-feature=+GroupNonUniform,+GroupNonUniformArithmetic,+GroupNonUniformClustered,+ext:SPV_KHR_vulkan_memory_model
3-
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_non_uniform_i_add_clustered::subgroup_non_uniform_i_add_clustered
3+
// compile-flags: -C llvm-args=--disassemble-fn=subgroup_i_add_clustered::subgroup_i_add_clustered
44

55
use glam::UVec3;
66
use spirv_std::arch::{GroupOperation, SubgroupMask};
77
use spirv_std::spirv;
88

9-
unsafe fn subgroup_non_uniform_i_add_clustered(value: u32) -> u32 {
10-
spirv_std::arch::subgroup_non_uniform_i_add_clustered::<8, _>(value)
9+
unsafe fn subgroup_i_add_clustered(value: u32) -> u32 {
10+
spirv_std::arch::subgroup_i_add_clustered::<8, _>(value)
1111
}
1212

1313
#[spirv(compute(threads(32, 1, 1)))]
1414
pub fn main(#[spirv(local_invocation_id)] local_invocation_id: UVec3) {
1515
unsafe {
16-
subgroup_non_uniform_i_add_clustered(local_invocation_id.x);
16+
subgroup_i_add_clustered(local_invocation_id.x);
1717
}
1818
}

0 commit comments

Comments
 (0)