@@ -733,20 +733,31 @@ impl<'tcx> CodegenCx<'tcx> {
733
733
. decorate ( var_id. unwrap ( ) , Decoration :: Invariant , std:: iter:: empty ( ) ) ;
734
734
}
735
735
if let Some ( per_primitive_ext) = attrs. per_primitive_ext {
736
- if storage_class != Ok ( StorageClass :: Output ) {
737
- self . tcx . dcx ( ) . span_fatal (
738
- per_primitive_ext. span ,
739
- "`#[spirv(per_primitive_ext)]` is only valid on Output variables" ,
740
- ) ;
741
- }
742
- if !( execution_model == ExecutionModel :: MeshEXT
743
- || execution_model == ExecutionModel :: MeshNV )
744
- {
745
- self . tcx . dcx ( ) . span_fatal (
746
- per_primitive_ext. span ,
747
- "`#[spirv(per_primitive_ext)]` is only valid in mesh shaders" ,
748
- ) ;
736
+ match execution_model {
737
+ ExecutionModel :: Fragment => {
738
+ if storage_class != Ok ( StorageClass :: Input ) {
739
+ self . tcx . dcx ( ) . span_fatal (
740
+ per_primitive_ext. span ,
741
+ "`#[spirv(per_primitive_ext)]` in fragment shaders is only valid on Input variables" ,
742
+ ) ;
743
+ }
744
+ }
745
+ ExecutionModel :: MeshNV | ExecutionModel :: MeshEXT => {
746
+ if storage_class != Ok ( StorageClass :: Output ) {
747
+ self . tcx . dcx ( ) . span_fatal (
748
+ per_primitive_ext. span ,
749
+ "`#[spirv(per_primitive_ext)]` in mesh shaders is only valid on Output variables" ,
750
+ ) ;
751
+ }
752
+ }
753
+ _ => {
754
+ self . tcx . dcx ( ) . span_fatal (
755
+ per_primitive_ext. span ,
756
+ "`#[spirv(per_primitive_ext)]` is only valid in fragment or mesh shaders" ,
757
+ ) ;
758
+ }
749
759
}
760
+
750
761
self . emit_global ( ) . decorate (
751
762
var_id. unwrap ( ) ,
752
763
Decoration :: PerPrimitiveEXT ,
0 commit comments