Skip to content

Commit fe5c771

Browse files
authored
Fix ByteAddressableBuffer PassMode::Pair (#837)
1 parent b99fc51 commit fe5c771

File tree

11 files changed

+181
-47
lines changed

11 files changed

+181
-47
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,7 +2188,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
21882188
for (argument, argument_type) in args.iter().zip(argument_types) {
21892189
assert_ty_eq!(self, argument.ty, argument_type);
21902190
}
2191-
let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).cloned();
2191+
let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).copied();
2192+
let buffer_load_intrinsic = self
2193+
.buffer_load_intrinsic_fn_id
2194+
.borrow()
2195+
.get(&callee_val)
2196+
.copied();
2197+
let buffer_store_intrinsic = self
2198+
.buffer_store_intrinsic_fn_id
2199+
.borrow()
2200+
.get(&callee_val)
2201+
.copied();
21922202
if let Some(libm_intrinsic) = libm_intrinsic {
21932203
let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args);
21942204
if result_type != result.ty {
@@ -2207,18 +2217,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
22072217
// needing to materialize `&core::panic::Location` or `format_args!`.
22082218
self.abort();
22092219
self.undef(result_type)
2210-
} else if self
2211-
.buffer_load_intrinsic_fn_id
2212-
.borrow()
2213-
.contains(&callee_val)
2214-
{
2215-
self.codegen_buffer_load_intrinsic(result_type, args)
2216-
} else if self
2217-
.buffer_store_intrinsic_fn_id
2218-
.borrow()
2219-
.contains(&callee_val)
2220-
{
2221-
self.codegen_buffer_store_intrinsic(args);
2220+
} else if let Some(mode) = buffer_load_intrinsic {
2221+
self.codegen_buffer_load_intrinsic(result_type, args, mode)
2222+
} else if let Some(mode) = buffer_store_intrinsic {
2223+
self.codegen_buffer_store_intrinsic(args, mode);
22222224

22232225
let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
22242226
SpirvValue {

crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use super::Builder;
2-
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
2+
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
33
use crate::spirv_type::SpirvType;
44
use rspirv::spirv::Word;
55
use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods};
6+
use rustc_errors::ErrorReported;
67
use rustc_span::DUMMY_SP;
7-
use rustc_target::abi::Align;
8+
use rustc_target::abi::call::PassMode;
9+
use rustc_target::abi::{Align, Size};
810

911
impl<'a, 'tcx> Builder<'a, 'tcx> {
1012
fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue {
@@ -168,7 +170,25 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
168170
&mut self,
169171
result_type: Word,
170172
args: &[SpirvValue],
173+
pass_mode: PassMode,
171174
) -> SpirvValue {
175+
match pass_mode {
176+
PassMode::Ignore => {
177+
return SpirvValue {
178+
kind: SpirvValueKind::IllegalTypeUsed(result_type),
179+
ty: result_type,
180+
}
181+
}
182+
// PassMode::Pair is identical to PassMode::Direct - it's returned as a struct
183+
PassMode::Direct(_) | PassMode::Pair(_, _) => (),
184+
PassMode::Cast(_) => {
185+
self.fatal("PassMode::Cast not supported in codegen_buffer_load_intrinsic")
186+
}
187+
PassMode::Indirect { .. } => {
188+
self.fatal("PassMode::Indirect not supported in codegen_buffer_load_intrinsic")
189+
}
190+
}
191+
172192
// Signature: fn load<T>(array: &[u32], index: u32) -> T;
173193
if args.len() != 3 {
174194
self.fatal(&format!(
@@ -184,15 +204,16 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
184204
self.recurse_load_type(result_type, result_type, array, word_index, 0)
185205
}
186206

187-
fn store_err(&mut self, original_type: Word, value: SpirvValue) {
207+
fn store_err(&mut self, original_type: Word, value: SpirvValue) -> Result<(), ErrorReported> {
188208
let mut err = self.struct_err(&format!(
189-
"Cannot load type {} in an untyped buffer store",
209+
"Cannot store type {} in an untyped buffer store",
190210
self.debug_type(original_type)
191211
));
192212
if original_type != value.ty {
193213
err.note(&format!("due to containing type {}", value.ty));
194214
}
195215
err.emit();
216+
Err(ErrorReported)
196217
}
197218

198219
fn store_u32(
@@ -201,7 +222,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
201222
dynamic_index: SpirvValue,
202223
constant_offset: u32,
203224
value: SpirvValue,
204-
) {
225+
) -> Result<(), ErrorReported> {
205226
let actual_index = if constant_offset != 0 {
206227
let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset);
207228
self.add(dynamic_index, const_offset_val)
@@ -216,6 +237,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
216237
.unwrap()
217238
.with_type(u32_ptr);
218239
self.store(value, ptr, Align::ONE);
240+
Ok(())
219241
}
220242

221243
#[allow(clippy::too_many_arguments)]
@@ -228,7 +250,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
228250
constant_word_offset: u32,
229251
element: Word,
230252
count: u32,
231-
) {
253+
) -> Result<(), ErrorReported> {
232254
let element_size_bytes = match self.lookup_type(element).sizeof(self) {
233255
Some(size) => size,
234256
None => return self.store_err(original_type, value),
@@ -245,8 +267,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
245267
array,
246268
dynamic_word_index,
247269
constant_word_offset + element_size_words * index,
248-
);
270+
)?;
249271
}
272+
Ok(())
250273
}
251274

252275
fn recurse_store_type(
@@ -256,17 +279,17 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
256279
array: SpirvValue,
257280
dynamic_word_index: SpirvValue,
258281
constant_word_offset: u32,
259-
) {
282+
) -> Result<(), ErrorReported> {
260283
match self.lookup_type(value.ty) {
261284
SpirvType::Integer(32, signed) => {
262285
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
263286
let value_u32 = self.intcast(value, u32_ty, signed);
264-
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32);
287+
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
265288
}
266289
SpirvType::Float(32) => {
267290
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
268291
let value_u32 = self.bitcast(value, u32_ty);
269-
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32);
292+
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
270293
}
271294
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
272295
.store_vec_mat_arr(
@@ -291,7 +314,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
291314
constant_word_offset,
292315
element,
293316
count,
294-
);
317+
)
295318
}
296319
SpirvType::Adt {
297320
size: Some(_),
@@ -310,20 +333,35 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
310333
array,
311334
dynamic_word_index,
312335
constant_word_offset + word_offset,
313-
);
336+
)?;
314337
}
338+
Ok(())
315339
}
316340

317341
_ => self.store_err(original_type, value),
318342
}
319343
}
320344

321345
/// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
322-
pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue]) {
346+
pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue], pass_mode: PassMode) {
323347
// Signature: fn store<T>(array: &[u32], index: u32, value: T);
324-
if args.len() != 4 {
348+
let is_pair = match pass_mode {
349+
// haha shrug
350+
PassMode::Ignore => return,
351+
PassMode::Direct(_) => false,
352+
PassMode::Pair(_, _) => true,
353+
PassMode::Cast(_) => {
354+
self.fatal("PassMode::Cast not supported in codegen_buffer_store_intrinsic")
355+
}
356+
PassMode::Indirect { .. } => {
357+
self.fatal("PassMode::Indirect not supported in codegen_buffer_store_intrinsic")
358+
}
359+
};
360+
let expected_args = if is_pair { 5 } else { 4 };
361+
if args.len() != expected_args {
325362
self.fatal(&format!(
326-
"buffer_store_intrinsic should have 4 args, it has {}",
363+
"buffer_store_intrinsic should have {} args, it has {}",
364+
expected_args,
327365
args.len()
328366
));
329367
}
@@ -332,7 +370,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
332370
let byte_index = args[2];
333371
let two = self.constant_u32(DUMMY_SP, 2);
334372
let word_index = self.lshr(byte_index, two);
335-
let value = args[3];
336-
self.recurse_store_type(value.ty, value, array, word_index, 0);
373+
if is_pair {
374+
let value_one = args[3];
375+
let value_two = args[4];
376+
let one_result = self.recurse_store_type(value_one.ty, value_one, array, word_index, 0);
377+
378+
let size_of_one = self.lookup_type(value_one.ty).sizeof(self);
379+
if one_result.is_ok() && size_of_one != Some(Size::from_bytes(4)) {
380+
self.fatal("Expected PassMode::Pair first element to have size 4");
381+
}
382+
383+
let _ = self.recurse_store_type(value_two.ty, value_two, array, word_index, 1);
384+
} else {
385+
let value = args[3];
386+
let _ = self.recurse_store_type(value.ty, value, array, word_index, 0);
387+
}
337388
}
338389
}

crates/rustc_codegen_spirv/src/codegen_cx/declare.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,16 @@ impl<'tcx> CodegenCx<'tcx> {
120120
self.unroll_loops_decorations.borrow_mut().insert(fn_id);
121121
}
122122
if attrs.buffer_load_intrinsic.is_some() {
123-
self.buffer_load_intrinsic_fn_id.borrow_mut().insert(fn_id);
123+
let mode = fn_abi.ret.mode;
124+
self.buffer_load_intrinsic_fn_id
125+
.borrow_mut()
126+
.insert(fn_id, mode);
124127
}
125128
if attrs.buffer_store_intrinsic.is_some() {
126-
self.buffer_store_intrinsic_fn_id.borrow_mut().insert(fn_id);
129+
let mode = fn_abi.args.last().unwrap().mode;
130+
self.buffer_store_intrinsic_fn_id
131+
.borrow_mut()
132+
.insert(fn_id, mode);
127133
}
128134

129135
let instance_def_id = instance.def_id();

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ impl<'tcx> CodegenCx<'tcx> {
6666
}
6767
// FIXME(eddyb) support these (by just ignoring them) - if there
6868
// is any validation concern, it should be done on the types.
69-
PassMode::Ignore => self.tcx.sess.span_err(
69+
PassMode::Ignore => self.tcx.sess.span_fatal(
7070
hir_param.ty_span,
7171
&format!(
7272
"entry point parameter type not yet supported \

crates/rustc_codegen_spirv/src/codegen_cx/mod.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use rustc_session::Session;
2929
use rustc_span::def_id::{DefId, LOCAL_CRATE};
3030
use rustc_span::symbol::{sym, Symbol};
3131
use rustc_span::{SourceFile, Span, DUMMY_SP};
32-
use rustc_target::abi::call::FnAbi;
32+
use rustc_target::abi::call::{FnAbi, PassMode};
3333
use rustc_target::abi::{HasDataLayout, TargetDataLayout};
3434
use rustc_target::spec::{HasTargetSpec, Target};
3535
use std::cell::{Cell, RefCell};
@@ -66,10 +66,10 @@ pub struct CodegenCx<'tcx> {
6666

6767
/// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`.
6868
pub panic_fn_id: Cell<Option<Word>>,
69-
/// Intrinsic for loading a <T> from a &[u32]
70-
pub buffer_load_intrinsic_fn_id: RefCell<FxHashSet<Word>>,
71-
/// Intrinsic for storing a <T> into a &[u32]
72-
pub buffer_store_intrinsic_fn_id: RefCell<FxHashSet<Word>>,
69+
/// Intrinsic for loading a <T> from a &[u32]. The PassMode is the mode of the <T>.
70+
pub buffer_load_intrinsic_fn_id: RefCell<FxHashMap<Word, PassMode>>,
71+
/// Intrinsic for storing a <T> into a &[u32]. The PassMode is the mode of the <T>.
72+
pub buffer_store_intrinsic_fn_id: RefCell<FxHashMap<Word, PassMode>>,
7373
/// Builtin bounds-checking panics (from MIR `Assert`s) call `#[lang = "panic_bounds_check"]`.
7474
pub panic_bounds_check_fn_id: Cell<Option<Word>>,
7575

crates/spirv-std/src/byte_addressable_buffer.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,14 @@ use core::mem;
55
#[spirv(buffer_load_intrinsic)]
66
#[spirv_std_macros::gpu_only]
77
#[allow(improper_ctypes_definitions)]
8-
unsafe extern "unadjusted" fn buffer_load_intrinsic<T>(_buffer: &[u32], _offset: u32) -> T {
8+
unsafe fn buffer_load_intrinsic<T>(_buffer: &[u32], _offset: u32) -> T {
99
unimplemented!()
1010
} // actually implemented in the compiler
1111

1212
#[spirv(buffer_store_intrinsic)]
1313
#[spirv_std_macros::gpu_only]
1414
#[allow(improper_ctypes_definitions)]
15-
unsafe extern "unadjusted" fn buffer_store_intrinsic<T>(
16-
_buffer: &mut [u32],
17-
_offset: u32,
18-
_value: T,
19-
) {
15+
unsafe fn buffer_store_intrinsic<T>(_buffer: &mut [u32], _offset: u32, _value: T) {
2016
unimplemented!()
2117
} // actually implemented in the compiler
2218

crates/spirv-std/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#![cfg_attr(
33
target_arch = "spirv",
44
feature(
5-
abi_unadjusted,
65
asm,
76
asm_const,
87
asm_experimental_arch,

tests/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Compiletests
2+
3+
This folder contains tests known as "compiletests". Each file in the `ui` folder corresponds to a
4+
single compiletest. The way they work is a tool iterates over every file, and tries to compile it.
5+
At the start of the file, there's some meta-comments about the expected result of the compile:
6+
whether it should succeed compilation, or fail. If it is expected to fail, there's a corresponding
7+
.stderr file next to the file that contains the expected compiler error message.
8+
9+
The `src` folder here is the tool that iterates over every file in the `ui` folder. It uses the
10+
`compiletests` library, taken from rustc's own compiletest framework.
11+
12+
You can run compiletests via `cargo compiletests`. This is an alias set up in `.cargo/config` for
13+
`cargo run --release -p compiletests --`. You can filter to run specific tests by passing the
14+
(partial) filenames to `cargo compiletests some_file_name`, and update the `.stderr` files to
15+
contain new output via the `--bless` flag (with `--bless`, make sure you're actually supposed to be
16+
changing the .stderr files due to an intentional change, and hand-validate the output is correct
17+
afterwards).
18+
19+
Keep in mind that tests here here are not executed, merely checked for errors (including validating
20+
the resulting binary with spirv-val). Because of this, there might be some strange code in here -
21+
the point isn't to make a fully functional shader every time (that would take an annoying amount of
22+
effort), but rather validate that specific parts of the compiler are doing their job correctly
23+
(either succeeding as they should, or erroring as they should).

tests/ui/arch/debug_printf_type_checking.stderr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ error[E0277]: the trait bound `{float}: Vector<f32, 2_usize>` is not satisfied
9696
<DVec2 as Vector<f64, 2_usize>>
9797
and 13 others
9898
note: required by a bound in `debug_printf_assert_is_vector`
99-
--> $SPIRV_STD_SRC/lib.rs:146:8
99+
--> $SPIRV_STD_SRC/lib.rs:145:8
100100
|
101-
146 | V: crate::vector::Vector<TY, SIZE>,
101+
145 | V: crate::vector::Vector<TY, SIZE>,
102102
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector`
103103

104104
error[E0308]: mismatched types
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// build-pass
2+
3+
use spirv_std::ByteAddressableBuffer;
4+
5+
pub struct EmptyStruct {}
6+
7+
#[spirv(fragment)]
8+
pub fn load(
9+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
10+
#[spirv(flat)] out: &mut EmptyStruct,
11+
) {
12+
unsafe {
13+
let buf = ByteAddressableBuffer::new(buf);
14+
*out = buf.load(5);
15+
}
16+
}
17+
18+
#[spirv(fragment)]
19+
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32]) {
20+
let val = EmptyStruct {};
21+
unsafe {
22+
let mut buf = ByteAddressableBuffer::new(buf);
23+
buf.store(5, val);
24+
}
25+
}

0 commit comments

Comments
 (0)