
Commit 0594d07

Atomic coherency test
This overwrites the compute-shader-hello example to be a test of atomic coherency. My understanding is that, with the barriers in place, it should run with 0 failures, even with strategy 0. With strategy 1 (atomicOr) as a workaround, it seems to be working.
1 parent 89e76e7 commit 0594d07
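
For intuition, the shader in this commit is a GPU variant of the classic message-passing litmus test. Below is a minimal CPU-side sketch of the same pattern (not part of the commit; names are illustrative), where Rust's Release/Acquire orderings play the role the storageBarrier() calls are hoped to play in the shader:

```rust
use std::sync::atomic::{AtomicU32, Ordering};
use std::thread;

static DATA: AtomicU32 = AtomicU32::new(0);
static FLAG: AtomicU32 = AtomicU32::new(0);

fn main() {
    let writer = thread::spawn(|| {
        DATA.store(1, Ordering::Relaxed);
        // Release: the data store cannot become visible after the flag store.
        FLAG.store(1, Ordering::Release);
    });
    let reader = thread::spawn(|| {
        // Acquire: if the flag is observed, the matching data store must be too.
        let flag = FLAG.load(Ordering::Acquire);
        let data = DATA.load(Ordering::Relaxed);
        // Observing the flag without the data is the coherency failure the shader counts.
        if flag > data {
            println!("failure: flag={} data={}", flag, data);
        }
    });
    writer.join().unwrap();
    reader.join().unwrap();
}
```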

4 files changed (+153 −38 lines)


compute-shader-hello/Cargo.lock

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default.

compute-shader-hello/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ edition = "2018"
 resolver = "2"
 
 [dependencies]
-wgpu = "0.11.0"
+wgpu = { version = "0.11.0", features = ["spirv"] }
 env_logger = "0.8"
 pollster = "0.2"
 bytemuck = { version = "1.7", features = ["derive"] }

compute-shader-hello/src/main.rs

Lines changed: 85 additions & 32 deletions
@@ -22,15 +22,25 @@ use wgpu::util::DeviceExt;
 
 use bytemuck;
 
+// A strategy of 0 is just atomic loads.
+// A strategy of 1 replaces the flag load with an atomicOr.
+const STRATEGY: u32 = 0;
+
+const USE_SPIRV: bool = false;
+
 async fn run() {
     let instance = wgpu::Instance::new(wgpu::Backends::PRIMARY);
     let adapter = instance.request_adapter(&Default::default()).await.unwrap();
     let features = adapter.features();
+    let mut feature_mask = wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::CLEAR_COMMANDS;
+    if USE_SPIRV {
+        feature_mask |= wgpu::Features::SPIRV_SHADER_PASSTHROUGH;
+    }
     let (device, queue) = adapter
         .request_device(
             &wgpu::DeviceDescriptor {
                 label: None,
-                features: features & wgpu::Features::TIMESTAMP_QUERY,
+                features: features & feature_mask,
                 limits: Default::default(),
             },
             None,
@@ -48,26 +58,33 @@ async fn run() {
     };
 
     let start_instant = Instant::now();
-    let cs_module = device.create_shader_module(&wgpu::ShaderModuleDescriptor {
-        label: None,
-        //source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()),
-        source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
-    });
+    let cs_module = if USE_SPIRV {
+        let shader_src: &[u32] = bytemuck::cast_slice(include_bytes!("shader.spv"));
+        unsafe {
+            device.create_shader_module_spirv(&wgpu::ShaderModuleDescriptorSpirV {
+                label: None,
+                source: std::borrow::Cow::Owned(shader_src.into()),
+            })
+        }
+    } else {
+        device.create_shader_module(&wgpu::ShaderModuleDescriptor {
+            label: None,
+            source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
+        })
+    };
+
+
     println!("shader compilation {:?}", start_instant.elapsed());
-    let input_f = &[1.0f32, 2.0f32];
-    let input : &[u8] = bytemuck::bytes_of(input_f);
-    let input_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
+    let data_buf = device.create_buffer(&wgpu::BufferDescriptor {
         label: None,
-        contents: input,
-        usage: wgpu::BufferUsages::STORAGE
-            | wgpu::BufferUsages::COPY_DST
-            | wgpu::BufferUsages::COPY_SRC,
+        size: 0x80000,
+        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
+        mapped_at_creation: false,
     });
-    let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
+    let config_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
         label: None,
-        size: input.len() as u64,
-        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
-        mapped_at_creation: false,
+        contents: bytemuck::bytes_of(&[STRATEGY, 0]),
+        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::MAP_READ,
     });
     // This works if the buffer is initialized, otherwise reads all 0, for some reason.
     let query_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
@@ -76,62 +93,98 @@ async fn run() {
         usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
     });
 
+    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+        label: None,
+        entries: &[wgpu::BindGroupLayoutEntry {
+            binding: 0,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: false },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        },
+        wgpu::BindGroupLayoutEntry {
+            binding: 1,
+            visibility: wgpu::ShaderStages::COMPUTE,
+            ty: wgpu::BindingType::Buffer {
+                ty: wgpu::BufferBindingType::Storage { read_only: false },
+                has_dynamic_offset: false,
+                min_binding_size: None,
+            },
+            count: None,
+        }],
+    });
+    let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+        label: None,
+        bind_group_layouts: &[&bind_group_layout],
+        push_constant_ranges: &[],
+    });
     let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
         label: None,
-        layout: None,
+        layout: Some(&compute_pipeline_layout),
         module: &cs_module,
         entry_point: "main",
     });
 
-    let bind_group_layout = pipeline.get_bind_group_layout(0);
     let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
         label: None,
         layout: &bind_group_layout,
-        entries: &[wgpu::BindGroupEntry {
-            binding: 0,
-            resource: input_buf.as_entire_binding(),
-        }],
+        entries: &[
+            wgpu::BindGroupEntry {
+                binding: 0,
+                resource: data_buf.as_entire_binding(),
+            },
+            wgpu::BindGroupEntry {
+                binding: 1,
+                resource: config_buf.as_entire_binding(),
+            },
+        ],
     });
 
     let mut encoder = device.create_command_encoder(&Default::default());
     if let Some(query_set) = &query_set {
         encoder.write_timestamp(query_set, 0);
     }
+    encoder.clear_buffer(&data_buf, 0, None);
     {
         let mut cpass = encoder.begin_compute_pass(&Default::default());
         cpass.set_pipeline(&pipeline);
         cpass.set_bind_group(0, &bind_group, &[]);
-        cpass.dispatch(input_f.len() as u32, 1, 1);
+        cpass.dispatch(256, 1, 1);
     }
     if let Some(query_set) = &query_set {
         encoder.write_timestamp(query_set, 1);
     }
-    encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64);
+    //encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64);
     if let Some(query_set) = &query_set {
         encoder.resolve_query_set(query_set, 0..2, &query_buf, 0);
     }
     queue.submit(Some(encoder.finish()));
 
-    let buf_slice = output_buf.slice(..);
+    let buf_slice = config_buf.slice(..);
     let buf_future = buf_slice.map_async(wgpu::MapMode::Read);
     let query_slice = query_buf.slice(..);
     let _query_future = query_slice.map_async(wgpu::MapMode::Read);
-    println!("pre-poll {:?}", std::time::Instant::now());
     device.poll(wgpu::Maintain::Wait);
-    println!("post-poll {:?}", std::time::Instant::now());
     if buf_future.await.is_ok() {
         let data_raw = &*buf_slice.get_mapped_range();
-        let data : &[f32] = bytemuck::cast_slice(data_raw);
-        println!("data: {:?}", &*data);
+        let data: &[u32] = bytemuck::cast_slice(data_raw);
+        println!("failures with strategy {}: {}", data[0], data[1]);
     }
     if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
         let ts_period = queue.get_timestamp_period();
         let ts_data_raw = &*query_slice.get_mapped_range();
-        let ts_data : &[u64] = bytemuck::cast_slice(ts_data_raw);
-        println!("compute shader elapsed: {:?}ms", (ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6);
+        let ts_data: &[u64] = bytemuck::cast_slice(ts_data_raw);
+        println!(
+            "compute shader elapsed: {:?}ms",
+            (ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6
+        );
     }
 }
 
 fn main() {
+    env_logger::init();
     pollster::block_on(run());
 }
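
A quick sanity check on the magic numbers above (not part of the commit): the dispatch launches 256 workgroups of 256 invocations, and each invocation owns one (data, flag) pair of u32 values, which is exactly what the 0x80000-byte data_buf holds:

```rust
// Sanity arithmetic for the sizes used in the diff above.
const WORKGROUPS: u64 = 256; // cpass.dispatch(256, 1, 1)
const WORKGROUP_SIZE: u64 = 256; // workgroup_size(256) in shader.wgsl
const INVOCATIONS: u64 = WORKGROUPS * WORKGROUP_SIZE; // 65536

fn main() {
    // Each invocation addresses a (data, flag) pair at indices ix * 2 and
    // permute_flag_ix(ix) * 2 + 1, both bounded by 2 * 65536 u32s,
    // so the buffer needs 65536 * 2 * 4 = 0x80000 bytes, matching `size: 0x80000`.
    let data_buf_size = INVOCATIONS * 2 * std::mem::size_of::<u32>() as u64;
    assert_eq!(data_buf_size, 0x80000);
}
```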

compute-shader-hello/src/shader.wgsl

Lines changed: 49 additions & 5 deletions
@@ -16,14 +16,58 @@
 
 [[block]]
 struct DataBuf {
-    data: [[stride(4)]] array<f32>;
+    data: [[stride(4)]] array<atomic<u32>>;
+};
+
+[[block]]
+struct ControlBuf {
+    strategy: u32;
+    failures: atomic<u32>;
 };
 
 [[group(0), binding(0)]]
-var<storage, read_write> v_indices: DataBuf;
+var<storage, read_write> data_buf: DataBuf;
+
+[[group(0), binding(1)]]
+var<storage, read_write> control_buf: ControlBuf;
 
-[[stage(compute), workgroup_size(1)]]
+// Put the flag in quite a different place than the data, which
+// should increase the number of failures, as they likely won't
+// be on the same cache line.
+fn permute_flag_ix(data_ix: u32) -> u32 {
+    return (data_ix * 31u) & 0xffffu;
+}
+
+[[stage(compute), workgroup_size(256)]]
 fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) {
-    // TODO: a more interesting computation than this.
-    v_indices.data[global_id.x] = v_indices.data[global_id.x] + 42.0;
+    let ix = global_id.x;
+    // Originally this was passed in, but is now hardcoded, as D3DCompiler
+    // thinks control flow becomes nonuniform if it's read from input.
+    let n_iter = 1024u;
+    let strategy = control_buf.strategy;
+    var failures = 0u;
+    for (var i: u32 = 0u; i < n_iter; i = i + 1u) {
+        let wr_flag_ix = permute_flag_ix(ix);
+        atomicStore(&data_buf.data[ix * 2u], i + 1u);
+        storageBarrier(); // release semantics for writing flag
+        atomicStore(&data_buf.data[wr_flag_ix * 2u + 1u], i + 1u);
+
+        // Read from a different workgroup
+        let read_ix = ((ix & 0xffu) << 8u) | (ix >> 8u);
+        let read_flag_ix = permute_flag_ix(read_ix);
+
+        let flag = atomicLoad(&data_buf.data[read_flag_ix * 2u + 1u]);
+        //let flag = atomicOr(&data_buf.data[read_flag_ix * 2u + 1u], 0u);
+        storageBarrier(); // acquire semantics for reading flag
+        var data = 0u;
+        if (strategy == 0u) {
+            data = atomicLoad(&data_buf.data[read_ix * 2u]);
+        } else {
+            data = atomicOr(&data_buf.data[read_ix * 2u], 0u);
+        }
+        if (flag > data) {
+            failures = failures + 1u;
+        }
+    }
+    let unused = atomicAdd(&control_buf.failures, failures);
 }
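
To see why the shader's comments call this a cross-workgroup read, it helps to trace the index arithmetic. Here is a hypothetical host-side mirror of the two index functions (not in the commit), with a worked example:

```rust
// read_ix swaps the two bytes of the 16-bit invocation index, so the high byte
// (the workgroup id, given workgroup_size(256)) comes from another thread's low byte.
fn read_ix(ix: u32) -> u32 {
    ((ix & 0xff) << 8) | (ix >> 8)
}

// permute_flag_ix scatters each flag far from its data, as in the shader.
fn permute_flag_ix(data_ix: u32) -> u32 {
    (data_ix * 31) & 0xffff
}

fn main() {
    // Invocation 0x0102 (workgroup 1, local thread 2) reads the pair written by
    // invocation 0x0201 (workgroup 2, local thread 1): a genuinely cross-workgroup read.
    assert_eq!(read_ix(0x0102), 0x0201);
    // That writer's flag lives at a distant index (0x0201 * 31 = 0x3e1f), so the
    // flag and its data almost certainly sit on different cache lines.
    assert_eq!(permute_flag_ix(0x0201), 0x3e1f);
}
```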
