Skip to content

Commit 340f4bb

Browse files
authored
Error when int doesn't have spirv(flat) (#815)
1 parent 0652153 commit 340f4bb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+127
-69
lines changed

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -517,11 +517,12 @@ impl<'tcx> CodegenCx<'tcx> {
517517
);
518518
}
519519

520-
self.check_for_bools(
520+
self.check_for_bad_types(
521521
hir_param.ty_span,
522522
var_ptr_spirv_type,
523523
storage_class,
524524
attrs.builtin.is_some(),
525+
attrs.flat.is_some(),
525526
);
526527

527528
// Assign locations from left to right, incrementing each storage class
@@ -563,7 +564,15 @@ impl<'tcx> CodegenCx<'tcx> {
563564
}
564565

565566
// Booleans are only allowed in some storage classes. Error if they're in others.
566-
fn check_for_bools(&self, span: Span, ty: Word, storage_class: StorageClass, is_builtin: bool) {
567+
// Integers and f64s must be decorated with `#[spirv(flat)]`.
568+
fn check_for_bad_types(
569+
&self,
570+
span: Span,
571+
ty: Word,
572+
storage_class: StorageClass,
573+
is_builtin: bool,
574+
is_flat: bool,
575+
) {
567576
// private and function are allowed here, but they can't happen.
568577
// SPIR-V technically allows all input/output variables to be booleans, not just builtins,
569578
// but has a note:
@@ -578,28 +587,49 @@ impl<'tcx> CodegenCx<'tcx> {
578587
{
579588
return;
580589
}
581-
if recurse(self, ty) {
590+
let mut has_bool = false;
591+
let mut must_be_flat = false;
592+
recurse(self, ty, &mut has_bool, &mut must_be_flat);
593+
if has_bool {
582594
self.tcx
583595
.sess
584596
.span_err(span, "entrypoint parameter cannot contain a boolean");
585597
}
586-
fn recurse(cx: &CodegenCx<'_>, ty: Word) -> bool {
598+
if matches!(storage_class, StorageClass::Input | StorageClass::Output)
599+
&& must_be_flat
600+
&& !is_flat
601+
{
602+
self.tcx
603+
.sess
604+
.span_err(span, "parameter must be decorated with #[spirv(flat)]");
605+
}
606+
fn recurse(cx: &CodegenCx<'_>, ty: Word, has_bool: &mut bool, must_be_flat: &mut bool) {
587607
match cx.lookup_type(ty) {
588-
SpirvType::Bool => true,
589-
SpirvType::Adt { field_types, .. } => field_types.iter().any(|&f| recurse(cx, f)),
608+
SpirvType::Bool => *has_bool = true,
609+
SpirvType::Integer(_, _) | SpirvType::Float(64) => *must_be_flat = true,
610+
SpirvType::Adt { field_types, .. } => {
611+
for f in field_types {
612+
recurse(cx, f, has_bool, must_be_flat);
613+
}
614+
}
590615
SpirvType::Vector { element, .. }
591616
| SpirvType::Matrix { element, .. }
592617
| SpirvType::Array { element, .. }
593618
| SpirvType::RuntimeArray { element }
594619
| SpirvType::Pointer { pointee: element }
595620
| SpirvType::InterfaceBlock {
596621
inner_type: element,
597-
} => recurse(cx, element),
622+
} => recurse(cx, element, has_bool, must_be_flat),
598623
SpirvType::Function {
599624
return_type,
600625
arguments,
601-
} => recurse(cx, return_type) || arguments.iter().any(|&a| recurse(cx, a)),
602-
_ => false,
626+
} => {
627+
recurse(cx, return_type, has_bool, must_be_flat);
628+
for a in arguments {
629+
recurse(cx, a, has_bool, must_be_flat);
630+
}
631+
}
632+
_ => (),
603633
}
604634
}
605635
}

tests/ui/byte_addressable_buffer/arr.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use spirv_std::{glam::Vec4, ByteAddressableBuffer};
55
#[spirv(fragment)]
66
pub fn load(
77
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
8-
out: &mut [i32; 4],
8+
#[spirv(flat)] out: &mut [i32; 4],
99
) {
1010
unsafe {
1111
let buf = ByteAddressableBuffer::new(buf);
@@ -16,7 +16,7 @@ pub fn load(
1616
#[spirv(fragment)]
1717
pub fn store(
1818
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
19-
val: [i32; 4],
19+
#[spirv(flat)] val: [i32; 4],
2020
) {
2121
unsafe {
2222
let mut buf = ByteAddressableBuffer::new(buf);

tests/ui/byte_addressable_buffer/big_struct.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ pub struct BigStruct {
1414
#[spirv(fragment)]
1515
pub fn load(
1616
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
17-
out: &mut BigStruct,
17+
#[spirv(flat)] out: &mut BigStruct,
1818
) {
1919
unsafe {
2020
let buf = ByteAddressableBuffer::new(buf);
@@ -25,7 +25,7 @@ pub fn load(
2525
#[spirv(fragment)]
2626
pub fn store(
2727
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
28-
val: BigStruct,
28+
#[spirv(flat)] val: BigStruct,
2929
) {
3030
unsafe {
3131
let mut buf = ByteAddressableBuffer::new(buf);

tests/ui/byte_addressable_buffer/u32.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use spirv_std::ByteAddressableBuffer;
55
#[spirv(fragment)]
66
pub fn load(
77
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
8-
out: &mut u32,
8+
#[spirv(flat)] out: &mut u32,
99
) {
1010
unsafe {
1111
let buf = ByteAddressableBuffer::new(buf);
@@ -14,7 +14,10 @@ pub fn load(
1414
}
1515

1616
#[spirv(fragment)]
17-
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: u32) {
17+
pub fn store(
18+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
19+
#[spirv(flat)] val: u32,
20+
) {
1821
unsafe {
1922
let mut buf = ByteAddressableBuffer::new(buf);
2023
buf.store(5, val);

tests/ui/dis/pass-mode-cast-struct.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ impl Foo {
2424
}
2525

2626
#[spirv(fragment)]
27-
pub fn main(in_packed: u64, out_sum: &mut u32) {
27+
pub fn main(#[spirv(flat)] in_packed: u64, #[spirv(flat)] out_sum: &mut u32) {
2828
let foo = Foo::unpack(in_packed);
2929
*out_sum = foo.a + (foo.b + foo.c) as u32;
3030
}

tests/ui/image/query/query_levels.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
66
#[spirv(fragment)]
77
pub fn main(
88
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled),
9-
output: &mut u32,
9+
#[spirv(flat)] output: &mut u32,
1010
) {
1111
*output = image.query_levels();
1212
}

tests/ui/image/query/query_samples.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
66
#[spirv(fragment)]
77
pub fn main(
88
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled, multisampled),
9-
output: &mut u32,
9+
#[spirv(flat)] output: &mut u32,
1010
) {
1111
*output = image.query_samples();
1212
}

tests/ui/image/query/query_size.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
66
#[spirv(fragment)]
77
pub fn main(
88
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled=false),
9-
output: &mut glam::UVec2,
9+
#[spirv(flat)] output: &mut glam::UVec2,
1010
) {
1111
*output = image.query_size();
1212
}

tests/ui/image/query/query_size_lod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
66
#[spirv(fragment)]
77
pub fn main(
88
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled),
9-
output: &mut glam::UVec2,
9+
#[spirv(flat)] output: &mut glam::UVec2,
1010
) {
1111
*output = image.query_size_lod(0);
1212
}

tests/ui/lang/asm/infer-access-chain-array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use spirv_std as _;
88
use glam::Vec4;
99

1010
#[spirv(fragment)]
11-
pub fn main(#[spirv(push_constant)] array_in: &[Vec4; 16], i: u32, out: &mut Vec4) {
11+
pub fn main(#[spirv(push_constant)] array_in: &[Vec4; 16], #[spirv(flat)] i: u32, out: &mut Vec4) {
1212
unsafe {
1313
asm!(
1414
"%val_ptr = OpAccessChain _ {array_ptr} {index}",

0 commit comments

Comments
 (0)