@@ -243,30 +243,28 @@ impl<A: hal::Api> PendingWrites<A> {
    }
}

-impl<A: HalApi> super::Device<A> {
-    fn prepare_staging_buffer(
-        &mut self,
-        size: wgt::BufferAddress,
-    ) -> Result<(StagingBuffer<A>, *mut u8), DeviceError> {
-        profiling::scope!("prepare_staging_buffer");
-        let stage_desc = hal::BufferDescriptor {
-            label: Some("(wgpu internal) Staging"),
-            size,
-            usage: hal::BufferUses::MAP_WRITE | hal::BufferUses::COPY_SRC,
-            memory_flags: hal::MemoryFlags::TRANSIENT,
-        };
-
-        let buffer = unsafe { self.raw.create_buffer(&stage_desc)? };
-        let mapping = unsafe { self.raw.map_buffer(&buffer, 0..size) }?;
-
-        let staging_buffer = StagingBuffer {
-            raw: buffer,
-            size,
-            is_coherent: mapping.is_coherent,
-        };
-
-        Ok((staging_buffer, mapping.ptr.as_ptr()))
-    }
+fn prepare_staging_buffer<A: HalApi>(
+    device: &mut A::Device,
+    size: wgt::BufferAddress,
+) -> Result<(StagingBuffer<A>, *mut u8), DeviceError> {
+    profiling::scope!("prepare_staging_buffer");
+    let stage_desc = hal::BufferDescriptor {
+        label: Some("(wgpu internal) Staging"),
+        size,
+        usage: hal::BufferUses::MAP_WRITE | hal::BufferUses::COPY_SRC,
+        memory_flags: hal::MemoryFlags::TRANSIENT,
+    };
+
+    let buffer = unsafe { device.create_buffer(&stage_desc)? };
+    let mapping = unsafe { device.map_buffer(&buffer, 0..size) }?;
+
+    let staging_buffer = StagingBuffer {
+        raw: buffer,
+        size,
+        is_coherent: mapping.is_coherent,
+    };
+
+    Ok((staging_buffer, mapping.ptr.as_ptr()))
}

impl<A: hal::Api> StagingBuffer<A> {
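One way to read this refactor: as a free function over `&mut A::Device`, `prepare_staging_buffer` borrows only the raw HAL device instead of the whole `super::Device`, so a caller can hold borrows of other `Device` fields (such as `pending_writes`) at the same time. Below is a minimal, self-contained sketch of that borrow-splitting idea; every type and name in it is a hypothetical stand-in, not wgpu's API.

struct RawDevice;

struct PendingWrites {
    temp_ids: Vec<u32>,
}

struct Device {
    raw: RawDevice,
    pending_writes: PendingWrites,
}

// The free-function form borrows only the field it actually needs.
fn prepare(_raw: &mut RawDevice) -> u32 {
    42
}

impl Device {
    // The method form needs exclusive access to the whole struct.
    fn prepare_via_method(&mut self) -> u32 {
        prepare(&mut self.raw)
    }
}

fn main() {
    let mut device = Device {
        raw: RawDevice,
        pending_writes: PendingWrites { temp_ids: Vec::new() },
    };

    // Fine on its own: nothing else borrows `device` here.
    let first = device.prepare_via_method();
    device.pending_writes.temp_ids.push(first);

    // Holding a mutable borrow of one field...
    let pending = &mut device.pending_writes;

    // ...the free function can still borrow the other field, because the
    // compiler tracks disjoint field borrows. A call to
    // `device.prepare_via_method()` here would not compile while `pending`
    // is alive, since `&mut self` re-borrows all of `device`.
    let second = prepare(&mut device.raw);
    pending.temp_ids.push(second);

    assert_eq!(device.pending_writes.temp_ids, vec![42, 42]);
}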
@@ -350,21 +348,31 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
            return Ok(());
        }

-        let (staging_buffer, staging_buffer_ptr) = device.prepare_staging_buffer(data_size)?;
+        // Platform validation requires that the staging buffer always be
+        // freed, even if an error occurs. All paths from here must call
+        // `device.pending_writes.consume`.
+        let (staging_buffer, staging_buffer_ptr) =
+            prepare_staging_buffer(&mut device.raw, data_size)?;

-        unsafe {
+        if let Err(flush_error) = unsafe {
            profiling::scope!("copy");
            ptr::copy_nonoverlapping(data.as_ptr(), staging_buffer_ptr, data.len());
-            staging_buffer.flush(&device.raw)?;
-        };
+            staging_buffer.flush(&device.raw)
+        } {
+            device.pending_writes.consume(staging_buffer);
+            return Err(flush_error.into());
+        }

-        self.queue_write_staging_buffer_impl(
+        let result = self.queue_write_staging_buffer_impl(
            device,
            device_token,
-            staging_buffer,
+            &staging_buffer,
            buffer_id,
            buffer_offset,
-        )
+        );
+
+        device.pending_writes.consume(staging_buffer);
+        result
    }

    pub fn queue_create_staging_buffer<A: HalApi>(
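The shape of the change above repeats in every write path of this patch: once the staging buffer exists, every exit, including error returns, must hand it to `device.pending_writes.consume` rather than letting it drop. A self-contained toy of that "consume on every path" discipline, using hypothetical stand-in types rather than wgpu's:

struct StagingBuffer {
    size: u64,
}

#[derive(Default)]
struct PendingWrites {
    // Buffers parked here get reclaimed later; in the patch, dropping a
    // staging buffer without going through this list is the leak being fixed.
    temp_buffers: Vec<StagingBuffer>,
}

impl PendingWrites {
    fn consume(&mut self, buf: StagingBuffer) {
        self.temp_buffers.push(buf);
    }
}

fn flush(succeed: bool) -> Result<(), String> {
    if succeed {
        Ok(())
    } else {
        Err("flush failed".to_string())
    }
}

fn write_buffer(pending: &mut PendingWrites, flush_succeeds: bool) -> Result<(), String> {
    let staging = StagingBuffer { size: 16 };

    // Error path: park the buffer *before* propagating the error.
    if let Err(e) = flush(flush_succeeds) {
        pending.consume(staging);
        return Err(e);
    }

    // Success path: the copy itself is elided; park the buffer afterwards.
    let result = Ok(());
    pending.consume(staging);
    result
}

fn main() {
    let mut pending = PendingWrites::default();
    assert!(write_buffer(&mut pending, false).is_err());
    assert!(write_buffer(&mut pending, true).is_ok());

    // Both calls parked their staging buffer, whatever the outcome.
    let parked: u64 = pending.temp_buffers.iter().map(|b| b.size).sum();
    assert_eq!(parked, 32);
}

A drop guard could encode the same invariant automatically; the sketch instead mirrors the patch's explicit `consume` calls on both paths.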
@@ -382,7 +390,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
            .map_err(|_| DeviceError::Invalid)?;

        let (staging_buffer, staging_buffer_ptr) =
-            device.prepare_staging_buffer(buffer_size.get())?;
+            prepare_staging_buffer(&mut device.raw, buffer_size.get())?;

        let fid = hub.staging_buffers.prepare(id_in);
        let id = fid.assign(staging_buffer, device_token);
@@ -413,15 +421,25 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
            .0
            .ok_or(TransferError::InvalidBuffer(buffer_id))?;

-        unsafe { staging_buffer.flush(&device.raw)? };
+        // At this point, we have taken ownership of the staging_buffer from the
+        // user. Platform validation requires that the staging buffer always
+        // be freed, even if an error occurs. All paths from here must call
+        // `device.pending_writes.consume`.
+        if let Err(flush_error) = unsafe { staging_buffer.flush(&device.raw) } {
+            device.pending_writes.consume(staging_buffer);
+            return Err(flush_error.into());
+        }

-        self.queue_write_staging_buffer_impl(
+        let result = self.queue_write_staging_buffer_impl(
            device,
            device_token,
-            staging_buffer,
+            &staging_buffer,
            buffer_id,
            buffer_offset,
-        )
+        );
+
+        device.pending_writes.consume(staging_buffer);
+        result
    }

    pub fn queue_validate_write_buffer<A: HalApi>(
@@ -481,7 +499,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
        &self,
        device: &mut super::Device<A>,
        device_token: &mut Token<super::Device<A>>,
-        staging_buffer: StagingBuffer<A>,
+        staging_buffer: &StagingBuffer<A>,
        buffer_id: id::BufferId,
        buffer_offset: u64,
    ) -> Result<(), QueueWriteError> {
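Passing `&StagingBuffer<A>` instead of the buffer by value is the other half of that discipline: the helper can read from the buffer but cannot move or drop it, so the caller remains the single owner that must call `consume`. A tiny illustration with hypothetical types, not wgpu's:

struct StagingBuffer {
    bytes: Vec<u8>,
}

// A shared borrow lets the callee read the buffer but never move or drop it,
// so ownership (and the duty to clean up) stays with the caller.
fn copy_from(staging: &StagingBuffer) -> usize {
    staging.bytes.len()
}

fn main() {
    let staging = StagingBuffer { bytes: vec![0u8; 16] };

    let copied = copy_from(&staging);

    // `staging` is still owned here; the caller decides its fate exactly once
    // (in the patch, by handing it to `device.pending_writes.consume`).
    assert_eq!(copied, staging.bytes.len());
}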
@@ -520,7 +538,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
            encoder.copy_buffer_to_buffer(&staging_buffer.raw, dst_raw, region.into_iter());
        }

-        device.pending_writes.consume(staging_buffer);
        device.pending_writes.dst_buffers.insert(buffer_id);

        // Ensure the overwritten bytes are marked as initialized so they don't need to be nulled prior to mapping or binding.
@@ -613,7 +630,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
        let block_rows_in_copy =
            (size.depth_or_array_layers - 1) * block_rows_per_image + height_blocks;
        let stage_size = stage_bytes_per_row as u64 * block_rows_in_copy as u64;
-        let (staging_buffer, staging_buffer_ptr) = device.prepare_staging_buffer(stage_size)?;

        let dst = texture_guard.get_mut(destination.texture).unwrap();
        if !dst.desc.usage.contains(wgt::TextureUsages::COPY_DST) {
@@ -676,12 +692,23 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
        validate_texture_copy_range(destination, &dst.desc, CopySide::Destination, size)?;
        dst.life_guard.use_at(device.active_submission_index + 1);

+        let dst_raw = dst
+            .inner
+            .as_raw()
+            .ok_or(TransferError::InvalidTexture(destination.texture))?;
+
        let bytes_per_row = if let Some(bytes_per_row) = data_layout.bytes_per_row {
            bytes_per_row.get()
        } else {
            width_blocks * format_desc.block_size as u32
        };

+        // Platform validation requires that the staging buffer always be
+        // freed, even if an error occurs. All paths from here must call
+        // `device.pending_writes.consume`.
+        let (staging_buffer, staging_buffer_ptr) =
+            prepare_staging_buffer(&mut device.raw, stage_size)?;
+
        if stage_bytes_per_row == bytes_per_row {
            profiling::scope!("copy aligned");
            // Fast path if the data is already being aligned optimally.
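This hunk also reorders the texture path: the fallible `dst_raw` lookup is hoisted above the point where the staging buffer is created, and the buffer itself is now created only after `bytes_per_row` is known, so an early `?` return cannot happen while an unconsumed staging buffer exists. A self-contained toy of that ordering rule, with hypothetical names rather than wgpu's:

struct Staging;

fn validate_destination(valid: bool) -> Result<(), String> {
    if valid {
        Ok(())
    } else {
        Err("invalid destination texture".to_string())
    }
}

fn write_texture(valid_dst: bool, parked: &mut Vec<Staging>) -> Result<(), String> {
    // Validate first: a `?` here returns before anything leak-prone exists.
    validate_destination(valid_dst)?;

    // Only now create the staging buffer; every remaining path parks it.
    let staging = Staging;
    parked.push(staging);
    Ok(())
}

fn main() {
    let mut parked = Vec::new();
    assert!(write_texture(false, &mut parked).is_err());
    assert!(write_texture(true, &mut parked).is_ok());

    // Only the successful call ever created a staging buffer to park.
    assert_eq!(parked.len(), 1);
}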
@@ -715,7 +742,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
            }
        }

-        unsafe { staging_buffer.flush(&device.raw) }?;
+        if let Err(e) = unsafe { staging_buffer.flush(&device.raw) } {
+            device.pending_writes.consume(staging_buffer);
+            return Err(e.into());
+        }

        let regions = (0..array_layer_count).map(|rel_array_layer| {
            let mut texture_base = dst_base.clone();
@@ -737,11 +767,6 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
            usage: hal::BufferUses::MAP_WRITE..hal::BufferUses::COPY_SRC,
        };

-        let dst_raw = dst
-            .inner
-            .as_raw()
-            .ok_or(TransferError::InvalidTexture(destination.texture))?;
-
        unsafe {
            encoder
                .transition_textures(transition.map(|pending| pending.into_hal(dst)).into_iter());