Skip to content

Commit e40e66d

Browse files
DXIL & HLSL passthrough (#7831)
Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
1 parent 4c08c37 commit e40e66d

File tree

11 files changed

+259
-56
lines changed

11 files changed

+259
-56
lines changed

wgpu-core/src/device/global.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,18 @@ impl Global {
984984
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
985985
}
986986
}
987+
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
988+
pipeline::ShaderModuleDescriptor {
989+
label: inner.label.clone(),
990+
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
991+
}
992+
}
993+
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
994+
pipeline::ShaderModuleDescriptor {
995+
label: inner.label.clone(),
996+
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
997+
}
998+
}
987999
},
9881000
data,
9891001
});

wgpu-core/src/device/resource.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,6 +1801,22 @@ impl Device {
18011801
num_workgroups: inner.num_workgroups,
18021802
}
18031803
}
1804+
pipeline::ShaderModuleDescriptorPassthrough::Dxil(inner) => {
1805+
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
1806+
hal::ShaderInput::Dxil {
1807+
shader: inner.source,
1808+
entry_point: inner.entry_point.clone(),
1809+
num_workgroups: inner.num_workgroups,
1810+
}
1811+
}
1812+
pipeline::ShaderModuleDescriptorPassthrough::Hlsl(inner) => {
1813+
self.require_features(wgt::Features::HLSL_DXIL_SHADER_PASSTHROUGH)?;
1814+
hal::ShaderInput::Hlsl {
1815+
shader: inner.source,
1816+
entry_point: inner.entry_point.clone(),
1817+
num_workgroups: inner.num_workgroups,
1818+
}
1819+
}
18041820
};
18051821

18061822
let hal_desc = hal::ShaderModuleDescriptor {

wgpu-hal/src/dx12/device.rs

Lines changed: 106 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use alloc::borrow::ToOwned;
12
use alloc::{
23
borrow::Cow,
34
string::{String, ToString as _},
@@ -264,27 +265,8 @@ impl super::Device {
264265
naga_stage: naga::ShaderStage,
265266
fragment_stage: Option<&crate::ProgrammableStage<super::ShaderModule>>,
266267
) -> Result<super::CompiledShader, crate::PipelineError> {
267-
use naga::back::hlsl;
268-
269-
let frag_ep = fragment_stage
270-
.map(|fs_stage| {
271-
hlsl::FragmentEntryPoint::new(&fs_stage.module.naga.module, fs_stage.entry_point)
272-
.ok_or(crate::PipelineError::EntryPoint(
273-
naga::ShaderStage::Fragment,
274-
))
275-
})
276-
.transpose()?;
277-
278268
let stage_bit = auxil::map_naga_stage(naga_stage);
279269

280-
let (module, info) = naga::back::pipeline_constants::process_overrides(
281-
&stage.module.naga.module,
282-
&stage.module.naga.info,
283-
Some((naga_stage, stage.entry_point)),
284-
stage.constants,
285-
)
286-
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?;
287-
288270
let needs_temp_options = stage.zero_initialize_workgroup_memory
289271
!= layout.naga_options.zero_initialize_workgroup_memory
290272
|| stage.module.runtime_checks.bounds_checks != layout.naga_options.restrict_indexing
@@ -301,43 +283,90 @@ impl super::Device {
301283
&layout.naga_options
302284
};
303285

304-
let pipeline_options = hlsl::PipelineOptions {
305-
entry_point: Some((naga_stage, stage.entry_point.to_string())),
306-
};
286+
let key = match &stage.module.source {
287+
super::ShaderModuleSource::Naga(naga_shader) => {
288+
use naga::back::hlsl;
307289

308-
//TODO: reuse the writer
309-
let (source, entry_point) = {
310-
let mut source = String::new();
311-
let mut writer = hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
290+
let frag_ep = match fragment_stage {
291+
Some(crate::ProgrammableStage {
292+
module:
293+
super::ShaderModule {
294+
source: super::ShaderModuleSource::Naga(naga_shader),
295+
..
296+
},
297+
entry_point,
298+
..
299+
}) => Some(
300+
hlsl::FragmentEntryPoint::new(&naga_shader.module, entry_point).ok_or(
301+
crate::PipelineError::EntryPoint(naga::ShaderStage::Fragment),
302+
),
303+
),
304+
_ => None,
305+
}
306+
.transpose()?;
307+
let (module, info) = naga::back::pipeline_constants::process_overrides(
308+
&naga_shader.module,
309+
&naga_shader.info,
310+
Some((naga_stage, stage.entry_point)),
311+
stage.constants,
312+
)
313+
.map_err(|e| {
314+
crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}"))
315+
})?;
312316

313-
profiling::scope!("naga::back::hlsl::write");
314-
let mut reflection_info = writer
315-
.write(&module, &info, frag_ep.as_ref())
316-
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
317+
let pipeline_options = hlsl::PipelineOptions {
318+
entry_point: Some((naga_stage, stage.entry_point.to_string())),
319+
};
317320

318-
assert_eq!(reflection_info.entry_point_names.len(), 1);
321+
//TODO: reuse the writer
322+
let (source, entry_point) = {
323+
let mut source = String::new();
324+
let mut writer =
325+
hlsl::Writer::new(&mut source, naga_options, &pipeline_options);
319326

320-
let entry_point = reflection_info
321-
.entry_point_names
322-
.pop()
323-
.unwrap()
324-
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
327+
profiling::scope!("naga::back::hlsl::write");
328+
let mut reflection_info = writer
329+
.write(&module, &info, frag_ep.as_ref())
330+
.map_err(|e| {
331+
crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}"))
332+
})?;
325333

326-
(source, entry_point)
327-
};
334+
assert_eq!(reflection_info.entry_point_names.len(), 1);
328335

329-
log::info!(
330-
"Naga generated shader for {:?} at {:?}:\n{}",
331-
entry_point,
332-
naga_stage,
333-
source
334-
);
336+
let entry_point = reflection_info
337+
.entry_point_names
338+
.pop()
339+
.unwrap()
340+
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
335341

336-
let key = ShaderCacheKey {
337-
source,
338-
entry_point,
339-
stage: naga_stage,
340-
shader_model: naga_options.shader_model,
342+
(source, entry_point)
343+
};
344+
log::info!(
345+
"Naga generated shader for {:?} at {:?}:\n{}",
346+
entry_point,
347+
naga_stage,
348+
source
349+
);
350+
351+
ShaderCacheKey {
352+
source,
353+
entry_point,
354+
stage: naga_stage,
355+
shader_model: naga_options.shader_model,
356+
}
357+
}
358+
super::ShaderModuleSource::HlslPassthrough(passthrough) => ShaderCacheKey {
359+
source: passthrough.shader.clone(),
360+
entry_point: passthrough.entry_point.clone(),
361+
stage: naga_stage,
362+
shader_model: naga_options.shader_model,
363+
},
364+
365+
super::ShaderModuleSource::DxilPassthrough(passthrough) => {
366+
return Ok(super::CompiledShader::Precompiled(
367+
passthrough.shader.clone(),
368+
))
369+
}
341370
};
342371

343372
{
@@ -351,11 +380,7 @@ impl super::Device {
351380

352381
let source_name = stage.module.raw_name.as_deref();
353382

354-
let full_stage = format!(
355-
"{}_{}",
356-
naga_stage.to_hlsl_str(),
357-
naga_options.shader_model.to_str()
358-
);
383+
let full_stage = format!("{}_{}", naga_stage.to_hlsl_str(), key.shader_model.to_str());
359384

360385
let compiled_shader = self.compiler_container.compile(
361386
self,
@@ -1671,7 +1696,7 @@ impl crate::Device for super::Device {
16711696
.and_then(|label| alloc::ffi::CString::new(label).ok());
16721697
match shader {
16731698
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
1674-
naga,
1699+
source: super::ShaderModuleSource::Naga(naga),
16751700
raw_name,
16761701
runtime_checks: desc.runtime_checks,
16771702
}),
@@ -1681,6 +1706,32 @@ impl crate::Device for super::Device {
16811706
crate::ShaderInput::Msl { .. } => {
16821707
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
16831708
}
1709+
crate::ShaderInput::Dxil {
1710+
shader,
1711+
entry_point,
1712+
num_workgroups,
1713+
} => Ok(super::ShaderModule {
1714+
source: super::ShaderModuleSource::DxilPassthrough(super::DxilPassthroughShader {
1715+
shader: shader.to_vec(),
1716+
entry_point,
1717+
num_workgroups,
1718+
}),
1719+
raw_name,
1720+
runtime_checks: desc.runtime_checks,
1721+
}),
1722+
crate::ShaderInput::Hlsl {
1723+
shader,
1724+
entry_point,
1725+
num_workgroups,
1726+
} => Ok(super::ShaderModule {
1727+
source: super::ShaderModuleSource::HlslPassthrough(super::HlslPassthroughShader {
1728+
shader: shader.to_owned(),
1729+
entry_point,
1730+
num_workgroups,
1731+
}),
1732+
raw_name,
1733+
runtime_checks: desc.runtime_checks,
1734+
}),
16841735
}
16851736
}
16861737
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {

wgpu-hal/src/dx12/mod.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ impl crate::DynPipelineLayout for PipelineLayout {}
10771077

10781078
#[derive(Debug)]
10791079
pub struct ShaderModule {
1080-
naga: crate::NagaShader,
1080+
source: ShaderModuleSource,
10811081
raw_name: Option<alloc::ffi::CString>,
10821082
runtime_checks: wgt::ShaderRuntimeChecks,
10831083
}
@@ -1109,6 +1109,7 @@ pub(super) struct ShaderCacheValue {
11091109
pub(super) enum CompiledShader {
11101110
Dxc(Direct3D::Dxc::IDxcBlob),
11111111
Fxc(Direct3D::ID3DBlob),
1112+
Precompiled(Vec<u8>),
11121113
}
11131114

11141115
impl CompiledShader {
@@ -1122,6 +1123,10 @@ impl CompiledShader {
11221123
pShaderBytecode: unsafe { shader.GetBufferPointer() },
11231124
BytecodeLength: unsafe { shader.GetBufferSize() },
11241125
},
1126+
CompiledShader::Precompiled(shader) => Direct3D12::D3D12_SHADER_BYTECODE {
1127+
pShaderBytecode: shader.as_ptr().cast(),
1128+
BytecodeLength: shader.len(),
1129+
},
11251130
}
11261131
}
11271132
}
@@ -1490,3 +1495,23 @@ impl crate::Queue for Queue {
14901495
(1_000_000_000.0 / frequency as f64) as f32
14911496
}
14921497
}
1498+
#[derive(Debug)]
1499+
pub struct DxilPassthroughShader {
1500+
pub shader: Vec<u8>,
1501+
pub entry_point: String,
1502+
pub num_workgroups: (u32, u32, u32),
1503+
}
1504+
1505+
#[derive(Debug)]
1506+
pub struct HlslPassthroughShader {
1507+
pub shader: String,
1508+
pub entry_point: String,
1509+
pub num_workgroups: (u32, u32, u32),
1510+
}
1511+
1512+
#[derive(Debug)]
1513+
pub enum ShaderModuleSource {
1514+
Naga(crate::NagaShader),
1515+
DxilPassthrough(DxilPassthroughShader),
1516+
HlslPassthrough(HlslPassthroughShader),
1517+
}

wgpu-hal/src/gles/device.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,9 @@ impl crate::Device for super::Device {
13461346
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
13471347
}
13481348
crate::ShaderInput::Naga(naga) => naga,
1349+
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
1350+
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
1351+
}
13491352
},
13501353
label: desc.label.map(|str| str.to_string()),
13511354
id: self.shared.next_shader_id.fetch_add(1, Ordering::Relaxed),

wgpu-hal/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,16 @@ pub enum ShaderInput<'a> {
21042104
num_workgroups: (u32, u32, u32),
21052105
},
21062106
SpirV(&'a [u32]),
2107+
Dxil {
2108+
shader: &'a [u8],
2109+
entry_point: String,
2110+
num_workgroups: (u32, u32, u32),
2111+
},
2112+
Hlsl {
2113+
shader: &'a str,
2114+
entry_point: String,
2115+
num_workgroups: (u32, u32, u32),
2116+
},
21072117
}
21082118

21092119
pub struct ShaderModuleDescriptor<'a> {

wgpu-hal/src/metal/device.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,9 @@ impl crate::Device for super::Device {
10391039
crate::ShaderInput::SpirV(_) => {
10401040
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
10411041
}
1042+
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
1043+
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled for this backend")
1044+
}
10421045
}
10431046
}
10441047

wgpu-hal/src/vulkan/device.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,6 +1908,9 @@ impl crate::Device for super::Device {
19081908
crate::ShaderInput::Msl { .. } => {
19091909
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
19101910
}
1911+
crate::ShaderInput::Dxil { .. } | crate::ShaderInput::Hlsl { .. } => {
1912+
panic!("`Features::HLSL_DXIL_SHADER_PASSTHROUGH` is not enabled")
1913+
}
19111914
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
19121915
};
19131916

wgpu-types/src/features.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,16 @@ bitflags_array! {
12441244
///
12451245
/// [BlasTriangleGeometrySizeDescriptor::vertex_format]: super::BlasTriangleGeometrySizeDescriptor
12461246
const EXTENDED_ACCELERATION_STRUCTURE_VERTEX_FORMATS = 1 << 51;
1247+
1248+
/// Enables creating shader modules from DirectX HLSL or DXIL shaders (unsafe)
1249+
///
1250+
/// HLSL/DXIL data is not parsed or interpreted in any way
1251+
///
1252+
/// Supported platforms:
1253+
/// - DX12
1254+
///
1255+
/// This is a native only feature.
1256+
const HLSL_DXIL_SHADER_PASSTHROUGH = 1 << 53;
12471257
}
12481258

12491259
/// Features that are not guaranteed to be supported.

0 commit comments

Comments
 (0)