Skip to content

Commit 39848e8

Browse files
committed
[d3d12] add a shader cache to avoid calling into DXC/FXC
1 parent 8d3ade9 commit 39848e8

File tree

2 files changed

+93
-42
lines changed

2 files changed

+93
-42
lines changed

wgpu-hal/src/dx12/device.rs

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::{
2626
},
2727
dx12::{
2828
borrow_optional_interface_temporarily, shader_compilation, suballocation,
29-
DynamicStorageBufferOffsets, Event,
29+
DynamicStorageBufferOffsets, Event, ShaderCacheKey, ShaderCacheValue,
3030
},
3131
AccelerationStructureEntries, TlasInstance,
3232
};
@@ -203,6 +203,7 @@ impl super::Device {
203203
null_rtv_handle,
204204
mem_allocator,
205205
dxc_container,
206+
shader_cache: Default::default(),
206207
counters: Default::default(),
207208
})
208209
}
@@ -304,63 +305,98 @@ impl super::Device {
304305
};
305306

306307
//TODO: reuse the writer
307-
let mut source = String::new();
308-
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
309-
let reflection_info = {
308+
let (source, entry_point) = {
309+
let mut source = String::new();
310+
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
311+
310312
profiling::scope!("naga::back::hlsl::write");
311-
writer
313+
let mut reflection_info = writer
312314
.write(&module, &info, frag_ep.as_ref())
313-
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?
315+
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
316+
317+
assert_eq!(reflection_info.entry_point_names.len(), 1);
318+
319+
let entry_point = reflection_info
320+
.entry_point_names
321+
.pop()
322+
.unwrap()
323+
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
324+
325+
(source, entry_point)
314326
};
315327

328+
log::info!(
329+
"Naga generated shader for {:?} at {:?}:\n{}",
330+
entry_point,
331+
naga_stage,
332+
source
333+
);
334+
335+
let key = ShaderCacheKey {
336+
source,
337+
entry_point,
338+
stage: naga_stage,
339+
shader_model: naga_options.shader_model,
340+
};
341+
342+
{
343+
let mut shader_cache = self.shader_cache.lock();
344+
let nr_of_shaders_compiled = shader_cache.nr_of_shaders_compiled;
345+
if let Some(value) = shader_cache.entries.get_mut(&key) {
346+
value.last_used = nr_of_shaders_compiled;
347+
return Ok(value.shader.clone());
348+
}
349+
}
350+
351+
let source_name = stage.module.raw_name.as_deref();
352+
316353
let full_stage = format!(
317354
"{}_{}",
318355
naga_stage.to_hlsl_str(),
319356
naga_options.shader_model.to_str()
320357
);
321358

322-
let raw_ep = reflection_info.entry_point_names[0]
323-
.as_ref()
324-
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
325-
326-
let source_name = stage.module.raw_name.as_deref();
327-
328359
// Compile with DXC if available, otherwise fall back to FXC
329-
let result = if let Some(ref dxc_container) = self.dxc_container {
360+
let compiled_shader = if let Some(ref dxc_container) = self.dxc_container {
330361
shader_compilation::compile_dxc(
331362
self,
332-
&source,
363+
&key.source,
333364
source_name,
334-
raw_ep,
365+
&key.entry_point,
335366
stage_bit,
336367
&full_stage,
337368
dxc_container,
338-
)
369+
)?
339370
} else {
340371
shader_compilation::compile_fxc(
341372
self,
342-
&source,
373+
&key.source,
343374
source_name,
344-
raw_ep,
375+
&key.entry_point,
345376
stage_bit,
346377
&full_stage,
347-
)
378+
)?
348379
};
349380

350-
let log_level = if result.is_ok() {
351-
log::Level::Info
352-
} else {
353-
log::Level::Error
354-
};
381+
{
382+
let mut shader_cache = self.shader_cache.lock();
383+
shader_cache.nr_of_shaders_compiled += 1;
384+
let nr_of_shaders_compiled = shader_cache.nr_of_shaders_compiled;
385+
let value = ShaderCacheValue {
386+
last_used: nr_of_shaders_compiled,
387+
shader: compiled_shader.clone(),
388+
};
389+
shader_cache.entries.insert(key, value);
355390

356-
log::log!(
357-
log_level,
358-
"Naga generated shader for {:?} at {:?}:\n{}",
359-
raw_ep,
360-
naga_stage,
361-
source
362-
);
363-
result
391+
// Retain all entries that have been used since we compiled the last 100 shaders.
392+
if shader_cache.entries.len() > 200 {
393+
shader_cache
394+
.entries
395+
.retain(|_, v| v.last_used >= nr_of_shaders_compiled - 100);
396+
}
397+
}
398+
399+
Ok(compiled_shader)
364400
}
365401

366402
pub fn raw_device(&self) -> &Direct3D12::ID3D12Device {
@@ -1831,11 +1867,6 @@ impl crate::Device for super::Device {
18311867
}
18321868
.map_err(|err| crate::PipelineError::Linkage(shader_stages, err.to_string()))?;
18331869

1834-
unsafe { blob_vs.destroy() };
1835-
if let Some(blob_fs) = blob_fs {
1836-
unsafe { blob_fs.destroy() };
1837-
};
1838-
18391870
if let Some(label) = desc.label {
18401871
raw.set_name(label)?;
18411872
}
@@ -1893,8 +1924,6 @@ impl crate::Device for super::Device {
18931924
}
18941925
};
18951926

1896-
unsafe { blob_cs.destroy() };
1897-
18981927
let raw: Direct3D12::ID3D12PipelineState = pair.map_err(|err| {
18991928
crate::PipelineError::Linkage(wgt::ShaderStages::COMPUTE, err.to_string())
19001929
})?;

wgpu-hal/src/dx12/mod.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ mod suballocation;
8484
mod types;
8585
mod view;
8686

87-
use alloc::{borrow::ToOwned as _, sync::Arc, vec::Vec};
87+
use alloc::{borrow::ToOwned as _, string::String, sync::Arc, vec::Vec};
8888
use core::{ffi, fmt, mem, num::NonZeroU32, ops::Deref};
8989

9090
use arrayvec::ArrayVec;
91+
use hashbrown::HashMap;
9192
use parking_lot::{Mutex, RwLock};
9293
use suballocation::Allocator;
9394
use windows::{
@@ -656,6 +657,7 @@ pub struct Device {
656657
null_rtv_handle: descriptor::Handle,
657658
mem_allocator: Allocator,
658659
dxc_container: Option<Arc<shader_compilation::DxcContainer>>,
660+
shader_cache: Mutex<ShaderCache>,
659661
counters: Arc<wgt::HalCounters>,
660662
}
661663

@@ -1077,6 +1079,28 @@ pub struct ShaderModule {
10771079

10781080
impl crate::DynShaderModule for ShaderModule {}
10791081

1082+
#[derive(Default)]
1083+
pub struct ShaderCache {
1084+
nr_of_shaders_compiled: u32,
1085+
entries: HashMap<ShaderCacheKey, ShaderCacheValue>,
1086+
}
1087+
1088+
#[derive(PartialEq, Eq, Hash)]
1089+
pub(super) struct ShaderCacheKey {
1090+
source: String,
1091+
entry_point: String,
1092+
stage: naga::ShaderStage,
1093+
shader_model: naga::back::hlsl::ShaderModel,
1094+
}
1095+
1096+
pub(super) struct ShaderCacheValue {
1097+
/// This is the value of [`ShaderCache::nr_of_shaders_compiled`]
1098+
/// at the time the cache entry was last used.
1099+
last_used: u32,
1100+
shader: CompiledShader,
1101+
}
1102+
1103+
#[derive(Clone)]
10801104
pub(super) enum CompiledShader {
10811105
Dxc(Direct3D::Dxc::IDxcBlob),
10821106
Fxc(Direct3D::ID3DBlob),
@@ -1095,8 +1119,6 @@ impl CompiledShader {
10951119
},
10961120
}
10971121
}
1098-
1099-
unsafe fn destroy(self) {}
11001122
}
11011123

11021124
#[derive(Debug)]

0 commit comments

Comments
 (0)