Skip to content

Commit fb9f91a

Browse files
committed
[wgpu-core] ray tracing: use error handling helpers
1 parent 214396a commit fb9f91a

File tree

5 files changed

+104
-242
lines changed

5 files changed

+104
-242
lines changed

wgpu-core/src/command/ray_tracing.rs

Lines changed: 38 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,7 @@ impl Global {
8181

8282
let device = &cmd_buf.device;
8383

84-
if !device
85-
.features
86-
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
87-
{
88-
return Err(BuildAccelerationStructureError::MissingFeature);
89-
}
84+
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
9085

9186
let build_command_index = NonZeroU64::new(
9287
device
@@ -199,18 +194,13 @@ impl Global {
199194
let mut tlas_buf_storage = Vec::new();
200195

201196
for entry in tlas_iter {
202-
let instance_buffer = match hub.buffers.get(entry.instance_buffer_id).get() {
203-
Ok(buffer) => buffer,
204-
Err(_) => {
205-
return Err(BuildAccelerationStructureError::InvalidBufferId);
206-
}
207-
};
197+
let instance_buffer = hub.buffers.get(entry.instance_buffer_id).get()?;
208198
let data = cmd_buf_data.trackers.buffers.set_single(
209199
&instance_buffer,
210200
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
211201
);
212202
tlas_buf_storage.push(TlasBufferStore {
213-
buffer: instance_buffer.clone(),
203+
buffer: instance_buffer,
214204
transition: data,
215205
entry: entry.clone(),
216206
});
@@ -221,14 +211,9 @@ impl Global {
221211
let instance_buffer = {
222212
let (instance_buffer, instance_pending) =
223213
(&mut tlas_buf.buffer, &mut tlas_buf.transition);
224-
let instance_raw = instance_buffer.raw.get(&snatch_guard).ok_or(
225-
BuildAccelerationStructureError::InvalidBuffer(instance_buffer.error_ident()),
226-
)?;
227-
if !instance_buffer.usage.contains(BufferUsages::TLAS_INPUT) {
228-
return Err(BuildAccelerationStructureError::MissingTlasInputUsageFlag(
229-
instance_buffer.error_ident(),
230-
));
231-
}
214+
let instance_raw = instance_buffer.try_raw(&snatch_guard)?;
215+
instance_buffer.check_usage(BufferUsages::TLAS_INPUT)?;
216+
232217
if let Some(barrier) = instance_pending
233218
.take()
234219
.map(|pending| pending.into_hal(instance_buffer, &snatch_guard))
@@ -238,11 +223,7 @@ impl Global {
238223
instance_raw
239224
};
240225

241-
let tlas = hub
242-
.tlas_s
243-
.get(entry.tlas_id)
244-
.get()
245-
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
226+
let tlas = hub.tlas_s.get(entry.tlas_id).get()?;
246227
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
247228
if let Some(queue) = device.get_queue() {
248229
queue.pending_writes.lock().insert_tlas(&tlas);
@@ -266,7 +247,7 @@ impl Global {
266247
tlas,
267248
entries: hal::AccelerationStructureEntries::Instances(
268249
hal::AccelerationStructureInstances {
269-
buffer: Some(instance_buffer.as_ref()),
250+
buffer: Some(instance_buffer),
270251
offset: 0,
271252
count: entry.instance_count,
272253
},
@@ -311,9 +292,7 @@ impl Global {
311292
mode: hal::AccelerationStructureBuildMode::Build,
312293
flags: tlas.flags,
313294
source_acceleration_structure: None,
314-
destination_acceleration_structure: tlas.raw(&snatch_guard).ok_or(
315-
BuildAccelerationStructureError::InvalidTlas(tlas.error_ident()),
316-
)?,
295+
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
317296
scratch_buffer: scratch_buffer.raw(),
318297
scratch_buffer_offset: *scratch_buffer_offset,
319298
})
@@ -381,12 +360,7 @@ impl Global {
381360

382361
let device = &cmd_buf.device;
383362

384-
if !device
385-
.features
386-
.contains(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)
387-
{
388-
return Err(BuildAccelerationStructureError::MissingFeature);
389-
}
363+
device.require_features(Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE)?;
390364

391365
let build_command_index = NonZeroU64::new(
392366
device
@@ -521,17 +495,14 @@ impl Global {
521495
let mut tlas_lock_store = Vec::<(Option<TlasPackage>, Arc<Tlas>)>::new();
522496

523497
for package in tlas_iter {
524-
let tlas = hub
525-
.tlas_s
526-
.get(package.tlas_id)
527-
.get()
528-
.map_err(|_| BuildAccelerationStructureError::InvalidTlasId)?;
498+
let tlas = hub.tlas_s.get(package.tlas_id).get()?;
529499
if let Some(queue) = device.get_queue() {
530500
queue.pending_writes.lock().insert_tlas(&tlas);
531501
}
502+
532503
cmd_buf_data.trackers.tlas_s.set_single(tlas.clone());
533504

534-
tlas_lock_store.push((Some(package), tlas.clone()))
505+
tlas_lock_store.push((Some(package), tlas))
535506
}
536507

537508
let mut scratch_buffer_tlas_size = 0;
@@ -558,12 +529,7 @@ impl Global {
558529
tlas.error_ident(),
559530
));
560531
}
561-
let blas = hub
562-
.blas_s
563-
.get(instance.blas_id)
564-
.get()
565-
.map_err(|_| BuildAccelerationStructureError::InvalidBlasIdForInstance)?
566-
.clone();
532+
let blas = hub.blas_s.get(instance.blas_id).get()?;
567533

568534
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
569535

@@ -581,7 +547,7 @@ impl Global {
581547
dependencies.push(blas.clone());
582548

583549
cmd_buf_data.blas_actions.push(BlasAction {
584-
blas: blas.clone(),
550+
blas,
585551
kind: crate::ray_tracing::BlasActionKind::Use,
586552
});
587553
}
@@ -659,13 +625,7 @@ impl Global {
659625
mode: hal::AccelerationStructureBuildMode::Build,
660626
flags: tlas.flags,
661627
source_acceleration_structure: None,
662-
destination_acceleration_structure: tlas
663-
.raw
664-
.get(&snatch_guard)
665-
.ok_or(BuildAccelerationStructureError::InvalidTlas(
666-
tlas.error_ident(),
667-
))?
668-
.as_ref(),
628+
destination_acceleration_structure: tlas.try_raw(&snatch_guard)?,
669629
scratch_buffer: scratch_buffer.raw(),
670630
scratch_buffer_offset: *scratch_buffer_offset,
671631
})
@@ -857,9 +817,7 @@ impl CommandBufferMutable {
857817
action.tlas.error_ident(),
858818
));
859819
}
860-
if blas.raw.get(snatch_guard).is_none() {
861-
return Err(ValidateTlasActionsError::InvalidBlas(blas.error_ident()));
862-
}
820+
blas.try_raw(snatch_guard)?;
863821
}
864822
}
865823
}
@@ -879,11 +837,7 @@ fn iter_blas<'a>(
879837
) -> Result<(), BuildAccelerationStructureError> {
880838
let mut temp_buffer = Vec::new();
881839
for entry in blas_iter {
882-
let blas = hub
883-
.blas_s
884-
.get(entry.blas_id)
885-
.get()
886-
.map_err(|_| BuildAccelerationStructureError::InvalidBlasId)?;
840+
let blas = hub.blas_s.get(entry.blas_id).get()?;
887841
cmd_buf_data.trackers.blas_s.set_single(blas.clone());
888842
if let Some(queue) = device.get_queue() {
889843
queue.pending_writes.lock().insert_blas(&blas);
@@ -966,19 +920,13 @@ fn iter_blas<'a>(
966920
blas.error_ident(),
967921
));
968922
}
969-
let vertex_buffer = match hub.buffers.get(mesh.vertex_buffer).get() {
970-
Ok(buffer) => buffer,
971-
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
972-
};
923+
let vertex_buffer = hub.buffers.get(mesh.vertex_buffer).get()?;
973924
let vertex_pending = cmd_buf_data.trackers.buffers.set_single(
974925
&vertex_buffer,
975926
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
976927
);
977928
let index_data = if let Some(index_id) = mesh.index_buffer {
978-
let index_buffer = match hub.buffers.get(index_id).get() {
979-
Ok(buffer) => buffer,
980-
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
981-
};
929+
let index_buffer = hub.buffers.get(index_id).get()?;
982930
if mesh.index_buffer_offset.is_none()
983931
|| mesh.size.index_count.is_none()
984932
|| mesh.size.index_count.is_none()
@@ -991,15 +939,12 @@ fn iter_blas<'a>(
991939
&index_buffer,
992940
BufferUses::BOTTOM_LEVEL_ACCELERATION_STRUCTURE_INPUT,
993941
);
994-
Some((index_buffer.clone(), data))
942+
Some((index_buffer, data))
995943
} else {
996944
None
997945
};
998946
let transform_data = if let Some(transform_id) = mesh.transform_buffer {
999-
let transform_buffer = match hub.buffers.get(transform_id).get() {
1000-
Ok(buffer) => buffer,
1001-
Err(_) => return Err(BuildAccelerationStructureError::InvalidBufferId),
1002-
};
947+
let transform_buffer = hub.buffers.get(transform_id).get()?;
1003948
if mesh.transform_buffer_offset.is_none() {
1004949
return Err(BuildAccelerationStructureError::MissingAssociatedData(
1005950
transform_buffer.error_ident(),
@@ -1014,7 +959,7 @@ fn iter_blas<'a>(
1014959
None
1015960
};
1016961
temp_buffer.push(TriangleBufferStore {
1017-
vertex_buffer: vertex_buffer.clone(),
962+
vertex_buffer,
1018963
vertex_transition: vertex_pending,
1019964
index_buffer_transition: index_data,
1020965
transform_buffer_transition: transform_data,
@@ -1024,7 +969,7 @@ fn iter_blas<'a>(
1024969
}
1025970

1026971
if let Some(last) = temp_buffer.last_mut() {
1027-
last.ending_blas = Some(blas.clone());
972+
last.ending_blas = Some(blas);
1028973
buf_storage.append(&mut temp_buffer);
1029974
}
1030975
}
@@ -1050,14 +995,9 @@ fn iter_buffers<'a, 'b>(
1050995
let mesh = &buf.geometry;
1051996
let vertex_buffer = {
1052997
let vertex_buffer = buf.vertex_buffer.as_ref();
1053-
let vertex_raw = vertex_buffer.raw.get(snatch_guard).ok_or(
1054-
BuildAccelerationStructureError::InvalidBuffer(vertex_buffer.error_ident()),
1055-
)?;
1056-
if !vertex_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
1057-
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
1058-
vertex_buffer.error_ident(),
1059-
));
1060-
}
998+
let vertex_raw = vertex_buffer.try_raw(snatch_guard)?;
999+
vertex_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1000+
10611001
if let Some(barrier) = buf
10621002
.vertex_transition
10631003
.take()
@@ -1077,10 +1017,7 @@ fn iter_buffers<'a, 'b>(
10771017
let vertex_buffer_offset = mesh.first_vertex as u64 * mesh.vertex_stride;
10781018
cmd_buf_data.buffer_memory_init_actions.extend(
10791019
vertex_buffer.initialization_status.read().create_action(
1080-
&hub.buffers
1081-
.get(mesh.vertex_buffer)
1082-
.get()
1083-
.map_err(|_| BuildAccelerationStructureError::InvalidBufferId)?,
1020+
&hub.buffers.get(mesh.vertex_buffer).get()?,
10841021
vertex_buffer_offset
10851022
..(vertex_buffer_offset
10861023
+ mesh.size.vertex_count as u64 * mesh.vertex_stride),
@@ -1092,14 +1029,9 @@ fn iter_buffers<'a, 'b>(
10921029
let index_buffer = if let Some((ref mut index_buffer, ref mut index_pending)) =
10931030
buf.index_buffer_transition
10941031
{
1095-
let index_raw = index_buffer.raw.get(snatch_guard).ok_or(
1096-
BuildAccelerationStructureError::InvalidBuffer(index_buffer.error_ident()),
1097-
)?;
1098-
if !index_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
1099-
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
1100-
index_buffer.error_ident(),
1101-
));
1102-
}
1032+
let index_raw = index_buffer.try_raw(snatch_guard)?;
1033+
index_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1034+
11031035
if let Some(barrier) = index_pending
11041036
.take()
11051037
.map(|pending| pending.into_hal(index_buffer, snatch_guard))
@@ -1155,14 +1087,9 @@ fn iter_buffers<'a, 'b>(
11551087
transform_buffer.error_ident(),
11561088
));
11571089
}
1158-
let transform_raw = transform_buffer.raw.get(snatch_guard).ok_or(
1159-
BuildAccelerationStructureError::InvalidBuffer(transform_buffer.error_ident()),
1160-
)?;
1161-
if !transform_buffer.usage.contains(BufferUsages::BLAS_INPUT) {
1162-
return Err(BuildAccelerationStructureError::MissingBlasInputUsageFlag(
1163-
transform_buffer.error_ident(),
1164-
));
1165-
}
1090+
let transform_raw = transform_buffer.try_raw(snatch_guard)?;
1091+
transform_buffer.check_usage(BufferUsages::BLAS_INPUT)?;
1092+
11661093
if let Some(barrier) = transform_pending
11671094
.take()
11681095
.map(|pending| pending.into_hal(transform_buffer, snatch_guard))
@@ -1199,7 +1126,7 @@ fn iter_buffers<'a, 'b>(
11991126
};
12001127

12011128
let triangles = hal::AccelerationStructureTriangles {
1202-
vertex_buffer: Some(vertex_buffer.as_ref()),
1129+
vertex_buffer: Some(vertex_buffer),
12031130
vertex_format: mesh.size.vertex_format,
12041131
first_vertex: mesh.first_vertex,
12051132
vertex_count: mesh.size.vertex_count,
@@ -1208,13 +1135,13 @@ fn iter_buffers<'a, 'b>(
12081135
dyn hal::DynBuffer,
12091136
> {
12101137
format: mesh.size.index_format.unwrap(),
1211-
buffer: Some(index_buffer.as_ref()),
1138+
buffer: Some(index_buffer),
12121139
offset: mesh.index_buffer_offset.unwrap() as u32,
12131140
count: mesh.size.index_count.unwrap(),
12141141
}),
12151142
transform: transform_buffer.map(|transform_buffer| {
12161143
hal::AccelerationStructureTriangleTransform {
1217-
buffer: transform_buffer.as_ref(),
1144+
buffer: transform_buffer,
12181145
offset: mesh.transform_buffer_offset.unwrap() as u32,
12191146
}
12201147
}),
@@ -1264,13 +1191,7 @@ fn map_blas<'a>(
12641191
mode: hal::AccelerationStructureBuildMode::Build,
12651192
flags: blas.flags,
12661193
source_acceleration_structure: None,
1267-
destination_acceleration_structure: blas
1268-
.raw
1269-
.get(snatch_guard)
1270-
.ok_or(BuildAccelerationStructureError::InvalidBlas(
1271-
blas.error_ident(),
1272-
))?
1273-
.as_ref(),
1194+
destination_acceleration_structure: blas.try_raw(snatch_guard)?,
12741195
scratch_buffer,
12751196
scratch_buffer_offset: *scratch_buffer_offset,
12761197
})

0 commit comments

Comments
 (0)