Skip to content

Commit eebb2d3

Browse files
authored
Require local size x dimension and remove gl_ (#495)
* Require local size and remove gl_ Removes the gl_ prefix from the compute shader attribute, shortens the thread dimension declaration to threads(x, y, z), requires the x size dimensions be specified, trailing ones may be elided for the y or z dimensions. * Implement review suggestions
1 parent a173208 commit eebb2d3

File tree

2 files changed

+60
-25
lines changed
  • crates/rustc_codegen_spirv/src
  • examples/shaders/compute-shader/src

2 files changed

+60
-25
lines changed

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = {
198198
("tessellation_evaluation", TessellationEvaluation),
199199
("geometry", Geometry),
200200
("fragment", Fragment),
201-
("gl_compute", GLCompute),
201+
("compute", GLCompute),
202202
("kernel", Kernel),
203203
("task_nv", TaskNV),
204204
("mesh_nv", MeshNV),
@@ -218,6 +218,7 @@ enum ExecutionModeExtraDim {
218218
X,
219219
Y,
220220
Z,
221+
Tuple,
221222
}
222223

223224
const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
@@ -240,9 +241,7 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = {
240241
("depth_greater", DepthGreater, None),
241242
("depth_less", DepthLess, None),
242243
("depth_unchanged", DepthUnchanged, None),
243-
("local_size_x", LocalSize, X),
244-
("local_size_y", LocalSize, Y),
245-
("local_size_z", LocalSize, Z),
244+
("threads", LocalSize, Tuple),
246245
("local_size_hint_x", LocalSizeHint, X),
247246
("local_size_hint_y", LocalSizeHint, Y),
248247
("local_size_hint_z", LocalSizeHint, Z),
@@ -690,6 +689,40 @@ fn parse_attr_int_value(arg: &NestedMetaItem) -> Result<u32, ParseAttrError> {
690689
}
691690
}
692691

692+
fn parse_local_size_attr(arg: &NestedMetaItem) -> Result<[u32; 3], ParseAttrError> {
693+
let arg = match arg.meta_item() {
694+
Some(arg) => arg,
695+
None => return Err((arg.span(), "attribute must have value".to_string())),
696+
};
697+
match arg.meta_item_list() {
698+
Some(tuple) if !tuple.is_empty() && tuple.len() < 4 => {
699+
let mut local_size = [1; 3];
700+
for (idx, lit) in tuple.iter().enumerate() {
701+
match lit.literal() {
702+
Some(&Lit {
703+
kind: LitKind::Int(x, LitIntType::Unsuffixed),
704+
..
705+
}) if x <= u32::MAX as u128 => local_size[idx] = x as u32,
706+
_ => return Err((lit.span(), "must be a u32 literal".to_string())),
707+
}
708+
}
709+
Ok(local_size)
710+
}
711+
Some(tuple) if tuple.is_empty() => Err((
712+
arg.span,
713+
"#[spirv(compute(threads(x, y, z)))] must have the x dimension specified, trailing ones may be elided".to_string(),
714+
)),
715+
Some(tuple) if tuple.len() > 3 => Err((
716+
arg.span,
717+
"#[spirv(compute(threads(x, y, z)))] is three dimensional".to_string(),
718+
)),
719+
_ => Err((
720+
arg.span,
721+
"#[spirv(compute(threads(x, y, z)))] must have 1 to 3 parameters, trailing ones may be elided".to_string(),
722+
)),
723+
}
724+
}
725+
693726
// for a given entry, gather up the additional attributes
694727
// in this case ExecutionMode's, some have extra arguments
695728
// others are specified with x, y, or z components
@@ -715,30 +748,23 @@ fn parse_entry_attrs(
715748
{
716749
use ExecutionModeExtraDim::*;
717750
let val = match extra_dim {
718-
None => Option::None,
751+
None | Tuple => Option::None,
719752
_ => Some(parse_attr_int_value(attr)?),
720753
};
721754
match execution_mode {
722755
OriginUpperLeft | OriginLowerLeft => {
723756
origin_mode.replace(*execution_mode);
724757
}
725758
LocalSize => {
726-
let val = val.unwrap();
727759
if local_size.is_none() {
728-
local_size.replace([1, 1, 1]);
729-
}
730-
let local_size = local_size.as_mut().unwrap();
731-
match extra_dim {
732-
X => {
733-
local_size[0] = val;
734-
}
735-
Y => {
736-
local_size[1] = val;
737-
}
738-
Z => {
739-
local_size[2] = val;
740-
}
741-
_ => unreachable!(),
760+
local_size.replace(parse_local_size_attr(attr)?);
761+
} else {
762+
return Err((
763+
attr_name.span,
764+
String::from(
765+
"`#[spirv(compute(threads))]` may only be specified once",
766+
),
767+
));
742768
}
743769
}
744770
LocalSizeHint => {
@@ -838,10 +864,18 @@ fn parse_entry_attrs(
838864
.push((origin_mode, ExecutionModeExtra::new([])));
839865
}
840866
GLCompute => {
841-
let local_size = local_size.unwrap_or([1, 1, 1]);
842-
entry
843-
.execution_modes
844-
.push((LocalSize, ExecutionModeExtra::new(local_size)));
867+
if let Some(local_size) = local_size {
868+
entry
869+
.execution_modes
870+
.push((LocalSize, ExecutionModeExtra::new(local_size)));
871+
} else {
872+
return Err((
873+
arg.span(),
874+
String::from(
875+
"The `threads` argument must be specified when using `#[spirv(compute)]`",
876+
),
877+
));
878+
}
845879
}
846880
Kernel => {
847881
if let Some(local_size) = local_size {

examples/shaders/compute-shader/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ extern crate spirv_std;
1313
#[macro_use]
1414
pub extern crate spirv_std_macros;
1515

16-
#[spirv(gl_compute)]
16+
// LocalSize/numthreads of (x = 32, y = 1, z = 1)
17+
#[spirv(compute(threads(32)))]
1718
pub fn main_cs() {}

0 commit comments

Comments
 (0)