Skip to content

Commit 31a0be3

Browse files
dvc94cheddyb
authored andcommitted
Add reduce example.
1 parent f638590 commit 31a0be3

File tree

4 files changed

+94
-0
lines changed

4 files changed

+94
-0
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ members = [
55
"examples/runners/ash",
66
"examples/runners/wgpu",
77
"examples/runners/wgpu/builder",
8+
"examples/shaders/reduce",
89
"examples/shaders/sky-shader",
910
"examples/shaders/simplest-shader",
1011
"examples/shaders/compute-shader",

examples/shaders/reduce/Cargo.toml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[package]
2+
name = "reduce"
3+
version = "0.4.0-alpha.12"
4+
authors = ["Embark <opensource@embark-studios.com>"]
5+
edition = "2018"
6+
license = "MIT OR Apache-2.0"
7+
publish = false
8+
9+
[lib]
10+
crate-type = ["dylib", "lib"]
11+
12+
[dependencies]
13+
spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] }

examples/shaders/reduce/src/lib.rs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#![cfg_attr(
2+
target_arch = "spirv",
3+
no_std,
4+
feature(register_attr),
5+
register_attr(spirv)
6+
)]
7+
// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds.
8+
#![deny(warnings)]
9+
use spirv_std::glam::UVec3;
10+
#[cfg(not(target_arch = "spirv"))]
11+
use spirv_std::macros::spirv;
12+
#[cfg(target_arch = "spirv")]
13+
use spirv_std::memory::Scope;
14+
15+
#[doc(alias = "OpGroupNonUniformIAdd")]
16+
#[cfg(target_arch = "spirv")]
17+
#[inline]
18+
pub unsafe fn subgroup_add(value: u32) -> u32 {
19+
const EXECUTION: u32 = Scope::Subgroup as _;
20+
let mut result = 0;
21+
asm! {
22+
"%u32 = OpTypeInt 32 0",
23+
"%execution = OpConstant %u32 {execution}",
24+
"%result = OpGroupNonUniformIAdd _ %execution Reduce {value}",
25+
"OpStore {result} %result",
26+
execution = const EXECUTION,
27+
value = in(reg) value,
28+
result = in(reg) &mut result,
29+
}
30+
result
31+
}
32+
33+
#[cfg(not(target_arch = "spirv"))]
34+
pub unsafe fn subgroup_add(_value: u32) -> u32 {
35+
panic!()
36+
}
37+
38+
#[spirv(compute(threads(256)))]
39+
pub fn main(
40+
#[spirv(global_invocation_id)] global_invocation_id: UVec3,
41+
#[spirv(local_invocation_id)] local_invocation_id: UVec3,
42+
#[spirv(subgroup_local_invocation_id)] subgroup_local_invocation_id: u32,
43+
#[spirv(workgroup_id)] workgroup_id: UVec3,
44+
#[spirv(subgroup_id)] subgroup_id: u32,
45+
#[spirv(num_subgroups)] num_subgroups: u32,
46+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] input: &[u32],
47+
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] output: &mut [u32],
48+
#[spirv(workgroup)] shared: &mut [u32; 256],
49+
) {
50+
let global_invocation_id_x = global_invocation_id.x as usize;
51+
let local_invocation_id_x = local_invocation_id.x as usize;
52+
let workgroup_id_x = workgroup_id.x as usize;
53+
54+
let mut sum = 0;
55+
if global_invocation_id_x < input.len() {
56+
sum = input[global_invocation_id_x];
57+
}
58+
sum = unsafe { subgroup_add(sum) };
59+
if subgroup_local_invocation_id == 0 {
60+
shared[subgroup_id as usize] = sum;
61+
}
62+
unsafe { spirv_std::arch::workgroup_memory_barrier_with_group_sync() };
63+
let mut sum = 0;
64+
if subgroup_id == 0 {
65+
if subgroup_local_invocation_id < num_subgroups {
66+
sum = shared[subgroup_local_invocation_id as usize];
67+
}
68+
sum = unsafe { subgroup_add(sum) };
69+
}
70+
if local_invocation_id_x == 0 {
71+
output[workgroup_id_x] = sum;
72+
}
73+
}

0 commit comments

Comments
 (0)