diff --git a/compute-shader-hello/Cargo.lock b/compute-shader-hello/Cargo.lock index b9e26f0..db28b50 100644 --- a/compute-shader-hello/Cargo.lock +++ b/compute-shader-hello/Cargo.lock @@ -204,6 +204,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "fixedbitset" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "398ea4fabe40b9b0d885340a2a991a44c8a645624075ad966d21f88688e2b69e" + [[package]] name = "foreign-types" version = "0.3.2" @@ -445,6 +451,7 @@ dependencies = [ "indexmap", "log", "num-traits", + "petgraph", "spirv", "thiserror", ] @@ -502,6 +509,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "petgraph" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a13a2fa9d0b63e5f22328828741e523766fff0ee9e779316902290dff3f824f" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "pollster" version = "0.2.4" @@ -753,6 +770,7 @@ dependencies = [ "arrayvec", "js-sys", "log", + "naga", "parking_lot", "raw-window-handle", "smallvec", diff --git a/compute-shader-hello/Cargo.toml b/compute-shader-hello/Cargo.toml index 2916859..95869b3 100644 --- a/compute-shader-hello/Cargo.toml +++ b/compute-shader-hello/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" resolver = "2" [dependencies] -wgpu = "0.11.0" +wgpu = { version = "0.11.0", features = ["spirv"] } env_logger = "0.8" pollster = "0.2" bytemuck = { version = "1.7", features = ["derive"] } diff --git a/compute-shader-hello/src/main.rs b/compute-shader-hello/src/main.rs index 7888e82..9921620 100644 --- a/compute-shader-hello/src/main.rs +++ b/compute-shader-hello/src/main.rs @@ -16,21 +16,23 @@ //! A simple application to run a compute shader. 
-use std::time::Instant; - -use wgpu::util::DeviceExt; - use bytemuck; +const USE_SPIRV: bool = false; + async fn run() { let instance = wgpu::Instance::new(wgpu::Backends::PRIMARY); let adapter = instance.request_adapter(&Default::default()).await.unwrap(); let features = adapter.features(); + let mut feature_mask = wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::CLEAR_COMMANDS; + if USE_SPIRV { + feature_mask |= wgpu::Features::SPIRV_SHADER_PASSTHROUGH; + } let (device, queue) = adapter .request_device( &wgpu::DeviceDescriptor { label: None, - features: features & wgpu::Features::TIMESTAMP_QUERY, + features: features & feature_mask, limits: Default::default(), }, None, @@ -47,91 +49,121 @@ async fn run() { None }; - let start_instant = Instant::now(); + let source = if USE_SPIRV { + wgpu::util::make_spirv(include_bytes!("shader.spv")) + } else { + wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()) + }; let cs_module = device.create_shader_module(&wgpu::ShaderModuleDescriptor { label: None, - //source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()), - source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()), + source, }); - println!("shader compilation {:?}", start_instant.elapsed()); - let input_f = &[1.0f32, 2.0f32]; - let input : &[u8] = bytemuck::bytes_of(input_f); - let input_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + + let data_buf = device.create_buffer(&wgpu::BufferDescriptor { label: None, - contents: input, - usage: wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::COPY_SRC, + size: 0x80000, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, }); - let output_buf = device.create_buffer(&wgpu::BufferDescriptor { + let config_buf = device.create_buffer(&wgpu::BufferDescriptor { label: None, - size: input.len() as u64, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + size: 8, + 
usage: wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::MAP_READ + | wgpu::BufferUsages::COPY_DST, mapped_at_creation: false, }); - // This works if the buffer is initialized, otherwise reads all 0, for some reason. - let query_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: None, + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + }); + let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: None, - contents: &[0; 16], - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], }); - let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, - layout: None, + layout: Some(&compute_pipeline_layout), module: &cs_module, entry_point: "main", }); - let bind_group_layout = pipeline.get_bind_group_layout(0); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { label: None, layout: &bind_group_layout, - entries: &[wgpu::BindGroupEntry { - binding: 0, - resource: input_buf.as_entire_binding(), - }], + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: data_buf.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: config_buf.as_entire_binding(), + }, + ], }); - let mut encoder = device.create_command_encoder(&Default::default()); - if let 
Some(query_set) = &query_set { - encoder.write_timestamp(query_set, 0); - } - { - let mut cpass = encoder.begin_compute_pass(&Default::default()); - cpass.set_pipeline(&pipeline); - cpass.set_bind_group(0, &bind_group, &[]); - cpass.dispatch(input_f.len() as u32, 1, 1); - } - if let Some(query_set) = &query_set { - encoder.write_timestamp(query_set, 1); - } - encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64); - if let Some(query_set) = &query_set { - encoder.resolve_query_set(query_set, 0..2, &query_buf, 0); - } - queue.submit(Some(encoder.finish())); + let mut failures = 0; + for i in 0..1000 { + let mut encoder = device.create_command_encoder(&Default::default()); + if let Some(query_set) = &query_set { + encoder.write_timestamp(query_set, 0); + } + encoder.clear_buffer(&config_buf, 0, None); + encoder.clear_buffer(&data_buf, 0, None); + { + let mut cpass = encoder.begin_compute_pass(&Default::default()); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch(256, 1, 1); + } + queue.submit(Some(encoder.finish())); - let buf_slice = output_buf.slice(..); - let buf_future = buf_slice.map_async(wgpu::MapMode::Read); - let query_slice = query_buf.slice(..); - let _query_future = query_slice.map_async(wgpu::MapMode::Read); - println!("pre-poll {:?}", std::time::Instant::now()); - device.poll(wgpu::Maintain::Wait); - println!("post-poll {:?}", std::time::Instant::now()); - if buf_future.await.is_ok() { - let data_raw = &*buf_slice.get_mapped_range(); - let data : &[f32] = bytemuck::cast_slice(data_raw); - println!("data: {:?}", &*data); + let buf_slice = config_buf.slice(..); + let buf_future = buf_slice.map_async(wgpu::MapMode::Read); + device.poll(wgpu::Maintain::Wait); + if buf_future.await.is_ok() { + let data_raw = buf_slice.get_mapped_range(); + let data: &[u32] = bytemuck::cast_slice(&*data_raw); + if data[1] != 0 { + if failures == 0 { + println!("first failing iteration {}, failures: {}", 
i, data[1]); + } + failures += data[1]; + } + std::mem::drop(data_raw); + config_buf.unmap(); + } } - if features.contains(wgpu::Features::TIMESTAMP_QUERY) { - let ts_period = queue.get_timestamp_period(); - let ts_data_raw = &*query_slice.get_mapped_range(); - let ts_data : &[u64] = bytemuck::cast_slice(ts_data_raw); - println!("compute shader elapsed: {:?}ms", (ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6); + if failures != 0 { + println!("{} total failures", failures); } } fn main() { + env_logger::init(); pollster::block_on(run()); } diff --git a/compute-shader-hello/src/shader.wgsl b/compute-shader-hello/src/shader.wgsl index 468c96e..eb10001 100644 --- a/compute-shader-hello/src/shader.wgsl +++ b/compute-shader-hello/src/shader.wgsl @@ -14,16 +14,52 @@ // // Also licensed under MIT license, at your choice. +struct Element { + data: atomic<u32>; + flag: atomic<u32>; +}; [[block]] struct DataBuf { - data: [[stride(4)]] array<f32>; + data: [[stride(8)]] array<Element>; +}; + +[[block]] +struct ControlBuf { + strategy: u32; + failures: atomic<u32>; +}; [[group(0), binding(0)]] -var<storage, read_write> v_indices: DataBuf; +var<storage, read_write> data_buf: DataBuf; + +[[group(0), binding(1)]] +var<storage, read_write> control_buf: ControlBuf; -[[stage(compute), workgroup_size(1)]] +// Put the flag in quite a different place than the data, which +// should increase the number of failures, as they likely won't +// be on the same cache line. +fn permute_flag_ix(data_ix: u32) -> u32 { + return (data_ix * 419u) & 0xffffu; +} + +[[stage(compute), workgroup_size(256)]] fn main([[builtin(global_invocation_id)]] global_id: vec3<u32>) { - // TODO: a more interesting computation than this. 
- v_indices.data[global_id.x] = v_indices.data[global_id.x] + 42.0; + let ix = global_id.x; + + let wr_flag_ix = permute_flag_ix(ix); + atomicStore(&data_buf.data[ix].data, 1u); + storageBarrier(); // release semantics for writing flag + atomicStore(&data_buf.data[wr_flag_ix].flag, 1u); + + // Read from a different workgroup + let read_ix = (ix * 4099u) & 0xffffu; + let read_flag_ix = permute_flag_ix(read_ix); + + let flag = atomicLoad(&data_buf.data[read_flag_ix].flag); + storageBarrier(); // acquire semantics for reading flag + let data = atomicLoad(&data_buf.data[read_ix].data); + if (flag > data) { + let unused = atomicAdd(&control_buf.failures, 1u); + } }