Skip to content

Commit 4f91b78

Browse files
committed
Change computenbody shader to better match GLSL
This enables much higher performance
1 parent cc5f6d2 commit 4f91b78

File tree

2 files changed

+40
-17
lines changed

2 files changed

+40
-17
lines changed
Binary file not shown.

shaders/rust/computenbody/particle_calculate/src/lib.rs

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#![cfg_attr(target_arch = "spirv", no_std)]
22

33
use spirv_std::{
4-
glam::{vec3, vec4, Vec3, Vec4},
4+
glam::{vec3, vec4, Vec4, Vec4Swizzles},
55
spirv,
6+
arch::workgroup_memory_barrier_with_group_sync,
67
num_traits::Float,
78
};
89

@@ -23,40 +24,62 @@ pub struct UBO {
2324
pub soften: f32,
2425
}
2526

27+
const SHARED_DATA_SIZE: usize = 512;
28+
2629
#[spirv(compute(threads(256)))]
2730
pub fn main_cs(
2831
#[spirv(global_invocation_id)] global_id: spirv_std::glam::UVec3,
32+
#[spirv(local_invocation_id)] local_id: spirv_std::glam::UVec3,
33+
#[spirv(workgroup)] shared_data: &mut [Vec4; SHARED_DATA_SIZE],
2934
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] particles: &mut [Particle],
3035
#[spirv(uniform, descriptor_set = 0, binding = 1)] ubo: &UBO,
3136
) {
3237
let index = global_id.x as usize;
38+
let local_index = local_id.x as usize;
3339

3440
if index >= ubo.particle_count as usize {
3541
return;
3642
}
3743

3844
let position = vec4(particles[index].pos[0], particles[index].pos[1], particles[index].pos[2], particles[index].pos[3]);
3945
let mut velocity = vec4(particles[index].vel[0], particles[index].vel[1], particles[index].vel[2], particles[index].vel[3]);
40-
let mut acceleration = vec4(0.0, 0.0, 0.0, 0.0);
46+
let mut acceleration = vec3(0.0, 0.0, 0.0);
4147

42-
// Calculate forces from all other particles (simplified O(N²) approach)
43-
for i in 0..ubo.particle_count as usize {
44-
if i == index {
45-
continue; // Skip self-interaction
48+
// Process particles in chunks of SHARED_DATA_SIZE
49+
let mut i = 0u32;
50+
while i < ubo.particle_count {
51+
// Load particle data into shared memory
52+
if i + (local_index as u32) < ubo.particle_count {
53+
let particle_idx = i as usize + local_index;
54+
shared_data[local_index] = vec4(
55+
particles[particle_idx].pos[0],
56+
particles[particle_idx].pos[1],
57+
particles[particle_idx].pos[2],
58+
particles[particle_idx].pos[3]
59+
);
60+
} else {
61+
shared_data[local_index] = vec4(0.0, 0.0, 0.0, 0.0);
62+
}
63+
64+
// Ensure all threads have loaded their data
65+
unsafe {
66+
workgroup_memory_barrier_with_group_sync();
4667
}
4768

48-
let other = vec4(particles[i].pos[0], particles[i].pos[1], particles[i].pos[2], particles[i].pos[3]);
49-
let len = vec3(other.x - position.x, other.y - position.y, other.z - position.z);
50-
let distance_sq = len.dot(len) + ubo.soften;
51-
let distance = distance_sq.sqrt();
52-
let force_magnitude = ubo.gravity * other.w / distance_sq.powf(ubo.power / 2.0);
69+
// Calculate forces from particles in shared memory
70+
for j in 0..256 { // gl_WorkGroupSize.x = 256
71+
let other = shared_data[j];
72+
let len = other.xyz() - position.xyz();
73+
let distance_sq = len.dot(len) + ubo.soften;
74+
acceleration += ubo.gravity * len * other.w / distance_sq.powf(ubo.power * 0.5);
75+
}
76+
77+
// Synchronize before next iteration
78+
unsafe {
79+
workgroup_memory_barrier_with_group_sync();
80+
}
5381

54-
acceleration = acceleration + vec4(
55-
len.x * force_magnitude,
56-
len.y * force_magnitude,
57-
len.z * force_magnitude,
58-
0.0
59-
);
82+
i += SHARED_DATA_SIZE as u32;
6083
}
6184

6285
// Update velocity with acceleration

0 commit comments

Comments
 (0)