@@ -198,7 +198,7 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
198
198
( "tessellation_evaluation" , TessellationEvaluation ) ,
199
199
( "geometry" , Geometry ) ,
200
200
( "fragment" , Fragment ) ,
201
- ( "gl_compute " , GLCompute ) ,
201
+ ( "compute " , GLCompute ) ,
202
202
( "kernel" , Kernel ) ,
203
203
( "task_nv" , TaskNV ) ,
204
204
( "mesh_nv" , MeshNV ) ,
@@ -218,6 +218,7 @@ enum ExecutionModeExtraDim {
218
218
X ,
219
219
Y ,
220
220
Z ,
221
+ Tuple ,
221
222
}
222
223
223
224
const EXECUTION_MODES : & [ ( & str , ExecutionMode , ExecutionModeExtraDim ) ] = {
@@ -240,9 +241,7 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
240
241
( "depth_greater" , DepthGreater , None ) ,
241
242
( "depth_less" , DepthLess , None ) ,
242
243
( "depth_unchanged" , DepthUnchanged , None ) ,
243
- ( "local_size_x" , LocalSize , X ) ,
244
- ( "local_size_y" , LocalSize , Y ) ,
245
- ( "local_size_z" , LocalSize , Z ) ,
244
+ ( "threads" , LocalSize , Tuple ) ,
246
245
( "local_size_hint_x" , LocalSizeHint , X ) ,
247
246
( "local_size_hint_y" , LocalSizeHint , Y ) ,
248
247
( "local_size_hint_z" , LocalSizeHint , Z ) ,
@@ -690,6 +689,40 @@ fn parse_attr_int_value(arg: &NestedMetaItem) -> Result<u32, ParseAttrError> {
690
689
}
691
690
}
692
691
692
+ fn parse_local_size_attr ( arg : & NestedMetaItem ) -> Result < [ u32 ; 3 ] , ParseAttrError > {
693
+ let arg = match arg. meta_item ( ) {
694
+ Some ( arg) => arg,
695
+ None => return Err ( ( arg. span ( ) , "attribute must have value" . to_string ( ) ) ) ,
696
+ } ;
697
+ match arg. meta_item_list ( ) {
698
+ Some ( tuple) if !tuple. is_empty ( ) && tuple. len ( ) < 4 => {
699
+ let mut local_size = [ 1 ; 3 ] ;
700
+ for ( idx, lit) in tuple. iter ( ) . enumerate ( ) {
701
+ match lit. literal ( ) {
702
+ Some ( & Lit {
703
+ kind : LitKind :: Int ( x, LitIntType :: Unsuffixed ) ,
704
+ ..
705
+ } ) if x <= u32:: MAX as u128 => local_size[ idx] = x as u32 ,
706
+ _ => return Err ( ( lit. span ( ) , "must be a u32 literal" . to_string ( ) ) ) ,
707
+ }
708
+ }
709
+ Ok ( local_size)
710
+ }
711
+ Some ( tuple) if tuple. is_empty ( ) => Err ( (
712
+ arg. span ,
713
+ "#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided" . to_string ( ) ,
714
+ ) ) ,
715
+ Some ( tuple) if tuple. len ( ) > 3 => Err ( (
716
+ arg. span ,
717
+ "#[spirv(compute(threads(x, y, z)))] is three dimensional" . to_string ( ) ,
718
+ ) ) ,
719
+ _ => Err ( (
720
+ arg. span ,
721
+ "#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided" . to_string ( ) ,
722
+ ) ) ,
723
+ }
724
+ }
725
+
693
726
// for a given entry, gather up the additional attributes
694
727
// in this case ExecutionMode's, some have extra arguments
695
728
// others are specified with x, y, or z components
@@ -715,30 +748,23 @@ fn parse_entry_attrs(
715
748
{
716
749
use ExecutionModeExtraDim :: * ;
717
750
let val = match extra_dim {
718
- None => Option :: None ,
751
+ None | Tuple => Option :: None ,
719
752
_ => Some ( parse_attr_int_value ( attr) ?) ,
720
753
} ;
721
754
match execution_mode {
722
755
OriginUpperLeft | OriginLowerLeft => {
723
756
origin_mode. replace ( * execution_mode) ;
724
757
}
725
758
LocalSize => {
726
- let val = val. unwrap ( ) ;
727
759
if local_size. is_none ( ) {
728
- local_size. replace ( [ 1 , 1 , 1 ] ) ;
729
- }
730
- let local_size = local_size. as_mut ( ) . unwrap ( ) ;
731
- match extra_dim {
732
- X => {
733
- local_size[ 0 ] = val;
734
- }
735
- Y => {
736
- local_size[ 1 ] = val;
737
- }
738
- Z => {
739
- local_size[ 2 ] = val;
740
- }
741
- _ => unreachable ! ( ) ,
760
+ local_size. replace ( parse_local_size_attr ( attr) ?) ;
761
+ } else {
762
+ return Err ( (
763
+ attr_name. span ,
764
+ String :: from (
765
+ "`#[spirv(compute(threads))]` may only be specified once" ,
766
+ ) ,
767
+ ) ) ;
742
768
}
743
769
}
744
770
LocalSizeHint => {
@@ -838,10 +864,18 @@ fn parse_entry_attrs(
838
864
. push ( ( origin_mode, ExecutionModeExtra :: new ( [ ] ) ) ) ;
839
865
}
840
866
GLCompute => {
841
- let local_size = local_size. unwrap_or ( [ 1 , 1 , 1 ] ) ;
842
- entry
843
- . execution_modes
844
- . push ( ( LocalSize , ExecutionModeExtra :: new ( local_size) ) ) ;
867
+ if let Some ( local_size) = local_size {
868
+ entry
869
+ . execution_modes
870
+ . push ( ( LocalSize , ExecutionModeExtra :: new ( local_size) ) ) ;
871
+ } else {
872
+ return Err ( (
873
+ arg. span ( ) ,
874
+ String :: from (
875
+ "The `threads` argument must be specified when using `#[spirv(compute)]`" ,
876
+ ) ,
877
+ ) ) ;
878
+ }
845
879
}
846
880
Kernel => {
847
881
if let Some ( local_size) = local_size {
0 commit comments