diff --git a/compute-shader-hello/src/main.rs b/compute-shader-hello/src/main.rs
index 7888e82..2334db4 100644
--- a/compute-shader-hello/src/main.rs
+++ b/compute-shader-hello/src/main.rs
@@ -22,6 +22,20 @@
 use wgpu::util::DeviceExt;
 
 use bytemuck;
 
+const N_DATA: usize = 1 << 25;
+const WG_SIZE: usize = 1 << 12;
+
+// Verify that the data is OEIS A000217
+fn verify(data: &[u32]) -> Option<usize> {
+    data.iter().enumerate().position(|(i, val)| {
+        let wrong = ((i * (i + 1)) / 2) as u32 != *val;
+        if wrong {
+            println!("diff @ {}: {} != {}", i, ((i * (i + 1)) / 2) as u32, *val);
+        }
+        wrong
+    })
+}
+
 async fn run() {
     let instance = wgpu::Instance::new(wgpu::Backends::PRIMARY);
     let adapter = instance.request_adapter(&Default::default()).await.unwrap();
@@ -30,7 +44,7 @@ async fn run() {
         .request_device(
             &wgpu::DeviceDescriptor {
                 label: None,
-                features: features & wgpu::Features::TIMESTAMP_QUERY,
+                features: features & (wgpu::Features::TIMESTAMP_QUERY | wgpu::Features::CLEAR_COMMANDS),
                 limits: Default::default(),
             },
             None,
@@ -54,13 +68,12 @@ async fn run() {
         source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
     });
     println!("shader compilation {:?}", start_instant.elapsed());
-    let input_f = &[1.0f32, 2.0f32];
-    let input : &[u8] = bytemuck::bytes_of(input_f);
+    let input_f: Vec<u32> = (0..N_DATA as u32).collect();
+    let input: &[u8] = bytemuck::cast_slice(&input_f);
     let input_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
         label: None,
         contents: input,
         usage: wgpu::BufferUsages::STORAGE
-            | wgpu::BufferUsages::COPY_DST
             | wgpu::BufferUsages::COPY_SRC,
     });
     let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
@@ -69,6 +82,15 @@ async fn run() {
         usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
         mapped_at_creation: false,
     });
+    const N_WG: usize = N_DATA / WG_SIZE;
+    const STATE_SIZE: usize = N_WG * 3 + 1;
+    // TODO: round this up
+    let state_buf = device.create_buffer(&wgpu::BufferDescriptor {
+        label: None,
+        size: 4 * STATE_SIZE as u64,
+        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
+        mapped_at_creation: false,
+    });
     // This works if the buffer is initialized, otherwise reads all 0, for some reason.
     let query_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
         label: None,
@@ -87,48 +109,66 @@ async fn run() {
     let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
         label: None,
         layout: &bind_group_layout,
-        entries: &[wgpu::BindGroupEntry {
-            binding: 0,
-            resource: input_buf.as_entire_binding(),
-        }],
+        entries: &[
+            wgpu::BindGroupEntry {
+                binding: 0,
+                resource: input_buf.as_entire_binding(),
+            },
+            wgpu::BindGroupEntry {
+                binding: 1,
+                resource: state_buf.as_entire_binding(),
+            },
+        ],
     });
-    let mut encoder = device.create_command_encoder(&Default::default());
-    if let Some(query_set) = &query_set {
-        encoder.write_timestamp(query_set, 0);
-    }
-    {
-        let mut cpass = encoder.begin_compute_pass(&Default::default());
-        cpass.set_pipeline(&pipeline);
-        cpass.set_bind_group(0, &bind_group, &[]);
-        cpass.dispatch(input_f.len() as u32, 1, 1);
-    }
-    if let Some(query_set) = &query_set {
-        encoder.write_timestamp(query_set, 1);
-    }
-    encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64);
-    if let Some(query_set) = &query_set {
-        encoder.resolve_query_set(query_set, 0..2, &query_buf, 0);
-    }
-    queue.submit(Some(encoder.finish()));
+    for i in 0..100 {
+        let mut encoder = device.create_command_encoder(&Default::default());
+        if let Some(query_set) = &query_set {
+            encoder.write_timestamp(query_set, 0);
+        }
+        encoder.clear_buffer(&state_buf, 0, None);
+        {
+            let mut cpass = encoder.begin_compute_pass(&Default::default());
+            cpass.set_pipeline(&pipeline);
+            cpass.set_bind_group(0, &bind_group, &[]);
+            cpass.dispatch(N_WG as u32, 1, 1);
+        }
+        if let Some(query_set) = &query_set {
+            encoder.write_timestamp(query_set, 1);
+        }
+        if i == 0 {
+            encoder.copy_buffer_to_buffer(&input_buf, 0, &output_buf, 0, input.len() as u64);
+        }
+        if let Some(query_set) = &query_set {
+            encoder.resolve_query_set(query_set, 0..2, &query_buf, 0);
+        }
+        queue.submit(Some(encoder.finish()));
 
-    let buf_slice = output_buf.slice(..);
-    let buf_future = buf_slice.map_async(wgpu::MapMode::Read);
-    let query_slice = query_buf.slice(..);
-    let _query_future = query_slice.map_async(wgpu::MapMode::Read);
-    println!("pre-poll {:?}", std::time::Instant::now());
-    device.poll(wgpu::Maintain::Wait);
-    println!("post-poll {:?}", std::time::Instant::now());
-    if buf_future.await.is_ok() {
-        let data_raw = &*buf_slice.get_mapped_range();
-        let data : &[f32] = bytemuck::cast_slice(data_raw);
-        println!("data: {:?}", &*data);
-    }
-    if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
-        let ts_period = queue.get_timestamp_period();
-        let ts_data_raw = &*query_slice.get_mapped_range();
-        let ts_data : &[u64] = bytemuck::cast_slice(ts_data_raw);
-        println!("compute shader elapsed: {:?}ms", (ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6);
+        let buf_slice = output_buf.slice(..);
+        let buf_future = buf_slice.map_async(wgpu::MapMode::Read);
+        let query_slice = query_buf.slice(..);
+        let query_future = query_slice.map_async(wgpu::MapMode::Read);
+        device.poll(wgpu::Maintain::Wait);
+        if buf_future.await.is_ok() {
+            if i == 0 {
+                let data_raw = &*buf_slice.get_mapped_range();
+                let data: &[u32] = bytemuck::cast_slice(data_raw);
+                println!("results correct: {:?}", verify(data));
+            }
+            output_buf.unmap();
+        }
+        if query_future.await.is_ok() {
+            if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
+                let ts_period = queue.get_timestamp_period();
+                let ts_data_raw = &*query_slice.get_mapped_range();
+                let ts_data: &[u64] = bytemuck::cast_slice(ts_data_raw);
+                println!(
"compute shader elapsed: {:?}ms", + (ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6 + ); + } + } + query_buf.unmap(); } } diff --git a/compute-shader-hello/src/shader.wgsl b/compute-shader-hello/src/shader.wgsl index 468c96e..233586d 100644 --- a/compute-shader-hello/src/shader.wgsl +++ b/compute-shader-hello/src/shader.wgsl @@ -16,14 +16,122 @@ [[block]] struct DataBuf { - data: [[stride(4)]] array; + data: [[stride(4)]] array; +}; + +[[block]] +struct StateBuf { + state: [[stride(4)]] array>; }; [[group(0), binding(0)]] -var v_indices: DataBuf; +var main_buf: DataBuf; + +[[group(0), binding(1)]] +var state_buf: StateBuf; + +let FLAG_NOT_READY = 0u; +let FLAG_AGGREGATE_READY = 1u; +let FLAG_PREFIX_READY = 2u; + +let workgroup_size: u32 = 512u; +let N_SEQ = 8u; + +var part_id: u32; +var scratch: array; +var shared_prefix: u32; +var shared_flag: u32; + +[[stage(compute), workgroup_size(512)]] +fn main([[builtin(local_invocation_id)]] local_id: vec3) { + if (local_id.x == 0u) { + part_id = atomicAdd(&state_buf.state[0], 1u); + } + workgroupBarrier(); + let my_part_id = part_id; + let mem_base = my_part_id * workgroup_size; + var local: array; + var el = main_buf.data[(mem_base + local_id.x) * N_SEQ]; + local[0] = el; + for (var i: u32 = 1u; i < N_SEQ; i = i + 1u) { + el = el + main_buf.data[(mem_base + local_id.x) * N_SEQ + i]; + local[i] = el; + } + scratch[local_id.x] = el; + // This must be lg2(workgroup_size) + for (var i: u32 = 0u; i < 9u; i = i + 1u) { + workgroupBarrier(); + if (local_id.x >= (1u << i)) { + el = el + scratch[local_id.x - (1u << i)]; + } + workgroupBarrier(); + scratch[local_id.x] = el; + } + var exclusive_prefix = 0u; + + var flag = FLAG_AGGREGATE_READY; + if (local_id.x == workgroup_size - 1u) { + atomicStore(&state_buf.state[my_part_id * 3u + 2u], el); + if (my_part_id == 0u) { + atomicStore(&state_buf.state[my_part_id * 3u + 3u], el); + flag = FLAG_PREFIX_READY; + } + } + // make sure these barriers are in uniform control flow + storageBarrier(); + if (local_id.x == workgroup_size - 1u) { + atomicStore(&state_buf.state[my_part_id * 3u + 1u], flag); + } + + if (my_part_id != 0u) { + // decoupled look-back + var look_back_ix = my_part_id - 1u; + loop { + if (local_id.x == workgroup_size - 1u) { + shared_flag = atomicOr(&state_buf.state[look_back_ix * 3u + 1u], 0u); + } + workgroupBarrier(); + flag = shared_flag; + storageBarrier(); + if (flag == FLAG_PREFIX_READY) { + if (local_id.x == workgroup_size - 1u) { + let their_prefix = atomicOr(&state_buf.state[look_back_ix * 3u + 3u], 0u); + exclusive_prefix = their_prefix + exclusive_prefix; + } + break; + } elseif (flag == FLAG_AGGREGATE_READY) { + if (local_id.x == workgroup_size - 1u) { + let their_agg = atomicOr(&state_buf.state[look_back_ix * 3u + 2u], 0u); + exclusive_prefix = their_agg + exclusive_prefix; + } + look_back_ix = look_back_ix - 1u; + } + // else spin + } + + // compute inclusive prefix + if (local_id.x == workgroup_size - 1u) { + let inclusive_prefix = exclusive_prefix + el; + shared_prefix = exclusive_prefix; + atomicStore(&state_buf.state[my_part_id * 3u + 3u], inclusive_prefix); + } + storageBarrier(); + if (local_id.x == workgroup_size - 1u) { + atomicStore(&state_buf.state[my_part_id * 3u + 1u], FLAG_PREFIX_READY); + } + } + var prefix = 0u; + workgroupBarrier(); + if (my_part_id != 0u) { + prefix = shared_prefix; + } -[[stage(compute), workgroup_size(1)]] -fn main([[builtin(global_invocation_id)]] global_id: vec3) { - // TODO: a more interesting computation than this. 
-    v_indices.data[global_id.x] = v_indices.data[global_id.x] + 42.0;
+    // do final output
+    for (var i: u32 = 0u; i < N_SEQ; i = i + 1u) {
+        var old = 0u;
+        if (local_id.x > 0u) {
+            old = scratch[local_id.x - 1u];
+        }
+        main_buf.data[(mem_base + local_id.x) * N_SEQ + i] = prefix + old + local[i];
+    }
 }
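
As a reading aid (not part of the patch), a minimal CPU-side sketch in Rust of the two conventions the diff relies on: the state-buffer layout implied by STATE_SIZE = N_WG * 3 + 1 and the my_part_id * 3u + k indexing (word 0 is the atomically bumped partition counter; each partition then owns a flag word, an aggregate word, and an inclusive-prefix word), and the value verify checks, element i of the inclusive scan of 0, 1, 2, ... being the triangular number i * (i + 1) / 2 truncated to u32 (OEIS A000217). The helper names below are hypothetical:

    // Hypothetical index helpers matching the shader's state[] layout.
    fn flag_ix(part: usize) -> usize { part * 3 + 1 }
    fn aggregate_ix(part: usize) -> usize { part * 3 + 2 }
    fn prefix_ix(part: usize) -> usize { part * 3 + 3 }

    // CPU reference for the GPU result: element i is i * (i + 1) / 2, wrapped to u32.
    fn reference_scan(n: u32) -> Vec<u32> {
        (0..n)
            .scan(0u32, |acc, x| {
                *acc = acc.wrapping_add(x);
                Some(*acc)
            })
            .collect()
    }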