16
16
17
17
//! A simple application to run a compute shader.
18
18
19
- mod encode;
20
-
21
19
use std:: time:: Instant ;
22
20
23
21
use wgpu:: util:: DeviceExt ;
24
22
25
- use encode :: Codable ;
23
+ use bytemuck ;
26
24
27
25
async fn run ( ) {
28
- let instance = wgpu:: Instance :: new ( wgpu:: BackendBit :: PRIMARY ) ;
26
+ let instance = wgpu:: Instance :: new ( wgpu:: Backends :: PRIMARY ) ;
29
27
let adapter = instance. request_adapter ( & Default :: default ( ) ) . await . unwrap ( ) ;
30
28
let features = adapter. features ( ) ;
31
29
let ( device, queue) = adapter
@@ -39,17 +37,11 @@ async fn run() {
39
37
)
40
38
. await
41
39
. unwrap ( ) ;
42
- let mut shader_flags = wgpu:: ShaderFlags :: VALIDATION ;
43
- if matches ! (
44
- adapter. get_info( ) . backend,
45
- wgpu:: Backend :: Vulkan | wgpu:: Backend :: Metal | wgpu:: Backend :: Gl
46
- ) {
47
- shader_flags |= wgpu:: ShaderFlags :: EXPERIMENTAL_TRANSLATION ;
48
- }
49
40
let query_set = if features. contains ( wgpu:: Features :: TIMESTAMP_QUERY ) {
50
41
Some ( device. create_query_set ( & wgpu:: QuerySetDescriptor {
51
42
count : 2 ,
52
43
ty : wgpu:: QueryType :: Timestamp ,
44
+ label : None ,
53
45
} ) )
54
46
} else {
55
47
None
@@ -60,28 +52,28 @@ async fn run() {
60
52
label : None ,
61
53
//source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()),
62
54
source : wgpu:: ShaderSource :: Wgsl ( include_str ! ( "shader.wgsl" ) . into ( ) ) ,
63
- flags : shader_flags,
64
55
} ) ;
65
56
println ! ( "shader compilation {:?}" , start_instant. elapsed( ) ) ;
66
- let input: Vec < u8 > = Codable :: encode_vec ( & [ 1.0f32 , 2.0f32 ] ) ;
57
+ let input_f = & [ 1.0f32 , 2.0f32 ] ;
58
+ let input : & [ u8 ] = bytemuck:: bytes_of ( input_f) ;
67
59
let input_buf = device. create_buffer_init ( & wgpu:: util:: BufferInitDescriptor {
68
60
label : None ,
69
- contents : & input,
70
- usage : wgpu:: BufferUsage :: STORAGE
71
- | wgpu:: BufferUsage :: COPY_DST
72
- | wgpu:: BufferUsage :: COPY_SRC ,
61
+ contents : input,
62
+ usage : wgpu:: BufferUsages :: STORAGE
63
+ | wgpu:: BufferUsages :: COPY_DST
64
+ | wgpu:: BufferUsages :: COPY_SRC ,
73
65
} ) ;
74
66
let output_buf = device. create_buffer ( & wgpu:: BufferDescriptor {
75
67
label : None ,
76
68
size : input. len ( ) as u64 ,
77
- usage : wgpu:: BufferUsage :: MAP_READ | wgpu:: BufferUsage :: COPY_DST ,
69
+ usage : wgpu:: BufferUsages :: MAP_READ | wgpu:: BufferUsages :: COPY_DST ,
78
70
mapped_at_creation : false ,
79
71
} ) ;
80
72
// This works if the buffer is initialized, otherwise reads all 0, for some reason.
81
73
let query_buf = device. create_buffer_init ( & wgpu:: util:: BufferInitDescriptor {
82
74
label : None ,
83
75
contents : & [ 0 ; 16 ] ,
84
- usage : wgpu:: BufferUsage :: MAP_READ | wgpu:: BufferUsage :: COPY_DST ,
76
+ usage : wgpu:: BufferUsages :: MAP_READ | wgpu:: BufferUsages :: COPY_DST ,
85
77
} ) ;
86
78
87
79
let pipeline = device. create_compute_pipeline ( & wgpu:: ComputePipelineDescriptor {
@@ -109,7 +101,7 @@ async fn run() {
109
101
let mut cpass = encoder. begin_compute_pass ( & Default :: default ( ) ) ;
110
102
cpass. set_pipeline ( & pipeline) ;
111
103
cpass. set_bind_group ( 0 , & bind_group, & [ ] ) ;
112
- cpass. dispatch ( 1 , 1 , 1 ) ;
104
+ cpass. dispatch ( input_f . len ( ) as u32 , 1 , 1 ) ;
113
105
}
114
106
if let Some ( query_set) = & query_set {
115
107
encoder. write_timestamp ( query_set, 1 ) ;
@@ -128,32 +120,18 @@ async fn run() {
128
120
device. poll ( wgpu:: Maintain :: Wait ) ;
129
121
println ! ( "post-poll {:?}" , std:: time:: Instant :: now( ) ) ;
130
122
if buf_future. await . is_ok ( ) {
131
- let data = buf_slice. get_mapped_range ( ) ;
123
+ let data_raw = & * buf_slice. get_mapped_range ( ) ;
124
+ let data : & [ f32 ] = bytemuck:: cast_slice ( data_raw) ;
132
125
println ! ( "data: {:?}" , & * data) ;
133
126
}
134
127
if features. contains ( wgpu:: Features :: TIMESTAMP_QUERY ) {
135
128
let ts_period = queue. get_timestamp_period ( ) ;
136
- let ts_data: Vec < u64 > = Codable :: decode_vec ( & * query_slice. get_mapped_range ( ) ) ;
137
- let ts_data = ts_data
138
- . iter ( )
139
- . map ( |ts| * ts as f64 * ts_period as f64 * 1e-6 )
140
- . collect :: < Vec < _ > > ( ) ;
141
- println ! ( "compute shader elapsed: {:?}ms" , ts_data[ 1 ] - ts_data[ 0 ] ) ;
129
+ let ts_data_raw = & * query_slice. get_mapped_range ( ) ;
130
+ let ts_data : & [ u64 ] = bytemuck:: cast_slice ( ts_data_raw) ;
131
+ println ! ( "compute shader elapsed: {:?}ms" , ( ts_data[ 1 ] - ts_data[ 0 ] ) as f64 * ts_period as f64 * 1e-6 ) ;
142
132
}
143
133
}
144
134
145
- #[ allow( unused) ]
146
- fn bytes_to_u32 ( bytes : & [ u8 ] ) -> Vec < u32 > {
147
- bytes
148
- . chunks_exact ( 4 )
149
- . map ( |b| {
150
- let mut bytes = [ 0 ; 4 ] ;
151
- bytes. copy_from_slice ( b) ;
152
- u32:: from_le_bytes ( bytes)
153
- } )
154
- . collect ( )
155
- }
156
-
157
135
fn main ( ) {
158
136
pollster:: block_on ( run ( ) ) ;
159
137
}
0 commit comments