1
1
#![ cfg_attr( target_arch = "spirv" , no_std) ]
2
2
3
3
use spirv_std:: {
4
- glam:: { vec3, vec4, Vec3 , Vec4 } ,
4
+ glam:: { vec3, vec4, Vec4 , Vec4Swizzles } ,
5
5
spirv,
6
+ arch:: workgroup_memory_barrier_with_group_sync,
6
7
num_traits:: Float ,
7
8
} ;
8
9
@@ -23,40 +24,62 @@ pub struct UBO {
23
24
pub soften : f32 ,
24
25
}
25
26
27
+ const SHARED_DATA_SIZE : usize = 512 ;
28
+
26
29
#[ spirv( compute( threads( 256 ) ) ) ]
27
30
pub fn main_cs (
28
31
#[ spirv( global_invocation_id) ] global_id : spirv_std:: glam:: UVec3 ,
32
+ #[ spirv( local_invocation_id) ] local_id : spirv_std:: glam:: UVec3 ,
33
+ #[ spirv( workgroup) ] shared_data : & mut [ Vec4 ; SHARED_DATA_SIZE ] ,
29
34
#[ spirv( storage_buffer, descriptor_set = 0 , binding = 0 ) ] particles : & mut [ Particle ] ,
30
35
#[ spirv( uniform, descriptor_set = 0 , binding = 1 ) ] ubo : & UBO ,
31
36
) {
32
37
let index = global_id. x as usize ;
38
+ let local_index = local_id. x as usize ;
33
39
34
40
if index >= ubo. particle_count as usize {
35
41
return ;
36
42
}
37
43
38
44
let position = vec4 ( particles[ index] . pos [ 0 ] , particles[ index] . pos [ 1 ] , particles[ index] . pos [ 2 ] , particles[ index] . pos [ 3 ] ) ;
39
45
let mut velocity = vec4 ( particles[ index] . vel [ 0 ] , particles[ index] . vel [ 1 ] , particles[ index] . vel [ 2 ] , particles[ index] . vel [ 3 ] ) ;
40
- let mut acceleration = vec4 ( 0.0 , 0.0 , 0.0 , 0.0 ) ;
46
+ let mut acceleration = vec3 ( 0.0 , 0.0 , 0.0 ) ;
41
47
42
- // Calculate forces from all other particles (simplified O(N²) approach)
43
- for i in 0 ..ubo. particle_count as usize {
44
- if i == index {
45
- continue ; // Skip self-interaction
48
+ // Process particles in chunks of SHARED_DATA_SIZE
49
+ let mut i = 0u32 ;
50
+ while i < ubo. particle_count {
51
+ // Load particle data into shared memory
52
+ if i + ( local_index as u32 ) < ubo. particle_count {
53
+ let particle_idx = i as usize + local_index;
54
+ shared_data[ local_index] = vec4 (
55
+ particles[ particle_idx] . pos [ 0 ] ,
56
+ particles[ particle_idx] . pos [ 1 ] ,
57
+ particles[ particle_idx] . pos [ 2 ] ,
58
+ particles[ particle_idx] . pos [ 3 ]
59
+ ) ;
60
+ } else {
61
+ shared_data[ local_index] = vec4 ( 0.0 , 0.0 , 0.0 , 0.0 ) ;
62
+ }
63
+
64
+ // Ensure all threads have loaded their data
65
+ unsafe {
66
+ workgroup_memory_barrier_with_group_sync ( ) ;
46
67
}
47
68
48
- let other = vec4 ( particles[ i] . pos [ 0 ] , particles[ i] . pos [ 1 ] , particles[ i] . pos [ 2 ] , particles[ i] . pos [ 3 ] ) ;
49
- let len = vec3 ( other. x - position. x , other. y - position. y , other. z - position. z ) ;
50
- let distance_sq = len. dot ( len) + ubo. soften ;
51
- let distance = distance_sq. sqrt ( ) ;
52
- let force_magnitude = ubo. gravity * other. w / distance_sq. powf ( ubo. power / 2.0 ) ;
69
+ // Calculate forces from particles in shared memory
70
+ for j in 0 ..256 { // gl_WorkGroupSize.x = 256
71
+ let other = shared_data[ j] ;
72
+ let len = other. xyz ( ) - position. xyz ( ) ;
73
+ let distance_sq = len. dot ( len) + ubo. soften ;
74
+ acceleration += ubo. gravity * len * other. w / distance_sq. powf ( ubo. power * 0.5 ) ;
75
+ }
76
+
77
+ // Synchronize before next iteration
78
+ unsafe {
79
+ workgroup_memory_barrier_with_group_sync ( ) ;
80
+ }
53
81
54
- acceleration = acceleration + vec4 (
55
- len. x * force_magnitude,
56
- len. y * force_magnitude,
57
- len. z * force_magnitude,
58
- 0.0
59
- ) ;
82
+ i += SHARED_DATA_SIZE as u32 ;
60
83
}
61
84
62
85
// Update velocity with acceleration
0 commit comments