Skip to content

Commit f815534

Browse files
Firestar99LegNeato
authored andcommitted
mesh shaders: added per_primitive_ext output attribute
1 parent 3500363 commit f815534

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub enum SpirvAttribute {
9292
DescriptorSet(u32),
9393
Binding(u32),
9494
Flat,
95+
PerPrimitiveExt,
9596
Invariant,
9697
InputAttachmentIndex(u32),
9798
SpecConstant(SpecConstant),
@@ -128,6 +129,7 @@ pub struct AggregatedSpirvAttributes {
128129
pub binding: Option<Spanned<u32>>,
129130
pub flat: Option<Spanned<()>>,
130131
pub invariant: Option<Spanned<()>>,
132+
pub per_primitive_ext: Option<Spanned<()>>,
131133
pub input_attachment_index: Option<Spanned<u32>>,
132134
pub spec_constant: Option<Spanned<SpecConstant>>,
133135

@@ -213,6 +215,12 @@ impl AggregatedSpirvAttributes {
213215
Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"),
214216
Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"),
215217
Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"),
218+
PerPrimitiveExt => try_insert(
219+
&mut self.per_primitive_ext,
220+
(),
221+
span,
222+
"#[spirv(per_primitive_ext)]",
223+
),
216224
InputAttachmentIndex(value) => try_insert(
217225
&mut self.input_attachment_index,
218226
value,
@@ -314,6 +322,7 @@ impl CheckSpirvAttrVisitor<'_> {
314322
| SpirvAttribute::Binding(_)
315323
| SpirvAttribute::Flat
316324
| SpirvAttribute::Invariant
325+
| SpirvAttribute::PerPrimitiveExt
317326
| SpirvAttribute::InputAttachmentIndex(_)
318327
| SpirvAttribute::SpecConstant(_) => match target {
319328
Target::Param => {

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,27 @@ impl<'tcx> CodegenCx<'tcx> {
732732
self.emit_global()
733733
.decorate(var_id.unwrap(), Decoration::Invariant, std::iter::empty());
734734
}
735+
if let Some(per_primitive_ext) = attrs.per_primitive_ext {
736+
if storage_class != Ok(StorageClass::Output) {
737+
self.tcx.dcx().span_fatal(
738+
per_primitive_ext.span,
739+
"`#[spirv(per_primitive_ext)]` is only valid on Output variables",
740+
);
741+
}
742+
if !(execution_model == ExecutionModel::MeshEXT
743+
|| execution_model == ExecutionModel::MeshNV)
744+
{
745+
self.tcx.dcx().span_fatal(
746+
per_primitive_ext.span,
747+
"`#[spirv(per_primitive_ext)]` is only valid in mesh shaders",
748+
);
749+
}
750+
self.emit_global().decorate(
751+
var_id.unwrap(),
752+
Decoration::PerPrimitiveEXT,
753+
std::iter::empty(),
754+
);
755+
}
735756

736757
let is_subpass_input = match self.lookup_type(value_spirv_type) {
737758
SpirvType::Image {

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ impl Symbols {
355355
("block", SpirvAttribute::Block),
356356
("flat", SpirvAttribute::Flat),
357357
("invariant", SpirvAttribute::Invariant),
358+
("per_primitive_ext", SpirvAttribute::PerPrimitiveExt),
358359
(
359360
"sampled_image",
360361
SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage),
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// build-pass
2+
// only-vulkan1.2
3+
// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader
4+
5+
use spirv_std::arch::set_mesh_outputs_ext;
6+
use spirv_std::glam::{UVec3, Vec4};
7+
use spirv_std::spirv;
8+
9+
#[spirv(mesh_ext(
10+
threads(1),
11+
output_vertices = 3,
12+
output_primitives_ext = 1,
13+
output_triangles_ext
14+
))]
15+
pub fn main(
16+
#[spirv(position)] positions: &mut [Vec4; 3],
17+
out_per_vertex: &mut [u32; 3],
18+
#[spirv(per_primitive_ext)] out_per_primitive: &mut [u32; 1],
19+
#[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1],
20+
) {
21+
unsafe {
22+
set_mesh_outputs_ext(3, 1);
23+
}
24+
25+
positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0);
26+
positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0);
27+
positions[2] = Vec4::new(0.0, -0.5, 0.0, 1.0);
28+
out_per_vertex[0] = 0;
29+
out_per_vertex[1] = 1;
30+
out_per_vertex[2] = 2;
31+
32+
indices[0] = UVec3::new(0, 1, 2);
33+
out_per_primitive[0] = 42;
34+
}

0 commit comments

Comments
 (0)