Skip to content

Commit acda771

Browse files
authored
asm: add support for noreturn option (#717)
* asm: add support for noreturn option OpUnreachable will be appended as terminator at the end of the asm block. * asm: implicit label after return or abort terminator * rework handling * fix tests and add few comments * fix tests
1 parent df5b411 commit acda771

File tree

8 files changed

+167
-21
lines changed

8 files changed

+167
-21
lines changed

crates/rustc_codegen_spirv/src/builder/spirv_asm.rs

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::builder_spirv::{BuilderCursor, SpirvValue};
33
use crate::codegen_cx::CodegenCx;
44
use crate::spirv_type::SpirvType;
55
use rspirv::dr;
6-
use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier};
6+
use rspirv::grammar::{reflect, LogicalOperand, OperandKind, OperandQuantifier};
77
use rspirv::spirv::{
88
FPFastMathMode, FragmentShadingRate, FunctionControl, ImageOperands, KernelProfilingInfo,
99
LoopControl, MemoryAccess, MemorySemantics, Op, RayFlags, SelectionControl, StorageClass, Word,
@@ -70,8 +70,13 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
7070
options: InlineAsmOptions,
7171
_line_spans: &[Span],
7272
) {
73-
if !options.is_empty() {
74-
self.err(&format!("asm flags not supported: {:?}", options));
73+
const SUPPORTED_OPTIONS: InlineAsmOptions = InlineAsmOptions::NORETURN;
74+
let unsupported_options = options & !SUPPORTED_OPTIONS;
75+
if !unsupported_options.is_empty() {
76+
self.err(&format!(
77+
"asm flags not supported: {:?}",
78+
unsupported_options
79+
));
7580
}
7681
// vec of lines, and each line is vec of tokens
7782
let mut tokens = vec![vec![]];
@@ -141,14 +146,41 @@ impl<'a, 'tcx> AsmBuilderMethods<'tcx> for Builder<'a, 'tcx> {
141146
id_to_type_map.insert(value.def(self), value.ty);
142147
}
143148
}
149+
150+
let mut asm_block = AsmBlock::Open;
144151
for line in tokens {
145152
self.codegen_asm(
146153
&mut id_map,
147154
&mut defined_ids,
148155
&mut id_to_type_map,
156+
&mut asm_block,
149157
line.into_iter(),
150158
);
151159
}
160+
161+
match (options.contains(InlineAsmOptions::NORETURN), asm_block) {
162+
(true, AsmBlock::Open) => {
163+
self.err("`noreturn` requires a terminator at the end");
164+
}
165+
(true, AsmBlock::End(_)) => {
166+
// `noreturn` appends an `OpUnreachable` after the asm block.
167+
// This requires starting a new block for this.
168+
let label = self.emit().id();
169+
self.emit()
170+
.insert_into_block(
171+
dr::InsertPoint::End,
172+
dr::Instruction::new(Op::Label, None, Some(label), vec![]),
173+
)
174+
.unwrap();
175+
}
176+
(false, AsmBlock::Open) => (),
177+
(false, AsmBlock::End(terminator)) => {
178+
self.err(&format!(
179+
"trailing terminator {:?} requires `options(noreturn)`",
180+
terminator
181+
));
182+
}
183+
}
152184
for (id, num) in id_map {
153185
if !defined_ids.contains(&num) {
154186
self.err(&format!("%{} is used but not defined", id));
@@ -178,6 +210,11 @@ enum OutRegister<'a> {
178210
Place(PlaceRef<'a, SpirvValue>),
179211
}
180212

213+
enum AsmBlock {
214+
Open,
215+
End(Op),
216+
}
217+
181218
impl<'cx, 'tcx> Builder<'cx, 'tcx> {
182219
fn lex_word<'a>(&self, line: &mut std::str::Chars<'a>) -> Option<Token<'a, 'cx, 'tcx>> {
183220
loop {
@@ -242,6 +279,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
242279
&mut self,
243280
id_map: &mut FxHashMap<&str, Word>,
244281
defined_ids: &mut FxHashSet<Word>,
282+
asm_block: &mut AsmBlock,
245283
inst: dr::Instruction,
246284
) {
247285
// Types declared must be registered in our type system.
@@ -328,10 +366,32 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
328366
}
329367
return;
330368
}
331-
_ => {
369+
370+
op => {
332371
self.emit()
333372
.insert_into_block(dr::InsertPoint::End, inst)
334373
.unwrap();
374+
375+
*asm_block = match *asm_block {
376+
AsmBlock::Open => {
377+
if reflect::is_block_terminator(op) {
378+
AsmBlock::End(op)
379+
} else {
380+
AsmBlock::Open
381+
}
382+
}
383+
AsmBlock::End(terminator) => {
384+
if op != Op::Label {
385+
self.err(&format!(
386+
"expected OpLabel after terminator {:?}",
387+
terminator
388+
));
389+
}
390+
391+
AsmBlock::Open
392+
}
393+
};
394+
335395
return;
336396
}
337397
};
@@ -351,6 +411,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
351411
id_map: &mut FxHashMap<&'a str, Word>,
352412
defined_ids: &mut FxHashSet<Word>,
353413
id_to_type_map: &mut FxHashMap<Word, Word>,
414+
asm_block: &mut AsmBlock,
354415
mut tokens: impl Iterator<Item = Token<'a, 'cx, 'tcx>>,
355416
) where
356417
'cx: 'a,
@@ -427,7 +488,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
427488
if let Some(result_type) = instruction.result_type {
428489
id_to_type_map.insert(instruction.result_id.unwrap(), result_type);
429490
}
430-
self.insert_inst(id_map, defined_ids, instruction);
491+
self.insert_inst(id_map, defined_ids, asm_block, instruction);
431492
if let Some(OutRegister::Place(place)) = out_register {
432493
self.emit()
433494
.store(

crates/spirv-std/src/arch.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,5 @@ pub unsafe fn vector_insert_dynamic<T: Scalar, V: Vector<T, N>, const N: usize>(
148148
#[doc(alias = "OpKill", alias = "discard")]
149149
#[allow(clippy::empty_loop)]
150150
pub fn kill() -> ! {
151-
unsafe {
152-
asm!("OpKill", "%unused = OpLabel");
153-
}
154-
loop {}
151+
unsafe { asm!("OpKill", options(noreturn)) }
155152
}

crates/spirv-std/src/arch/ray_tracing.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ pub unsafe fn report_intersection(hit: f32, hit_kind: u32) -> bool {
4444
#[inline]
4545
#[allow(clippy::empty_loop)]
4646
pub unsafe fn ignore_intersection() -> ! {
47-
asm!("OpIgnoreIntersectionKHR", "%unused = OpLabel");
48-
loop {}
47+
asm!("OpIgnoreIntersectionKHR", options(noreturn));
4948
}
5049

5150
/// Terminates the invocation that executes it, stops the ray traversal, accepts
@@ -57,8 +56,7 @@ pub unsafe fn ignore_intersection() -> ! {
5756
#[inline]
5857
#[allow(clippy::empty_loop)]
5958
pub unsafe fn terminate_ray() -> ! {
60-
asm!("OpTerminateRayKHR", "%unused = OpLabel");
61-
loop {}
59+
asm!("OpTerminateRayKHR", options(noreturn));
6260
}
6361

6462
/// Invoke a callable shader.

crates/spirv-std/src/ray_tracing.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ impl AccelerationStructure {
2525
"%ret = OpTypeAccelerationStructureKHR",
2626
"%result = OpConvertUToAccelerationStructureKHR %ret {id}",
2727
"OpReturnValue %result",
28-
"%blah = OpLabel",
2928
id = in(reg) id,
29+
options(noreturn)
3030
}
31-
loop {}
3231
}
3332

3433
/// Converts a vector of two 32 bit integers into an [`AccelerationStructure`].
@@ -47,10 +46,9 @@ impl AccelerationStructure {
4746
"%id = OpLoad _ {id}",
4847
"%result = OpConvertUToAccelerationStructureKHR %ret %id",
4948
"OpReturnValue %result",
50-
"%blah = OpLabel",
5149
id = in(reg) &id,
50+
options(noreturn),
5251
}
53-
loop {}
5452
}
5553

5654
#[spirv_std_macros::gpu_only]

crates/spirv-std/src/runtime_array.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@ impl<T> RuntimeArray<T> {
1717
asm! {
1818
"%result = OpAccessChain _ {arr} {index}",
1919
"OpReturnValue %result",
20-
"%unused = OpLabel",
2120
arr = in(reg) self,
2221
index = in(reg) index,
22+
options(noreturn),
2323
}
24-
loop {}
2524
}
2625

2726
#[spirv_std_macros::gpu_only]
@@ -30,10 +29,9 @@ impl<T> RuntimeArray<T> {
3029
asm! {
3130
"%result = OpAccessChain _ {arr} {index}",
3231
"OpReturnValue %result",
33-
"%unused = OpLabel",
3432
arr = in(reg) self,
3533
index = in(reg) index,
34+
options(noreturn),
3635
}
37-
loop {}
3836
}
3937
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Tests validating tracking of basic blocks
2+
// within the `asm!` macro.
3+
// build-fail
4+
5+
use spirv_std as _;
6+
7+
// Active basic block with `noreturn`.
8+
fn asm_noreturn_open() {
9+
unsafe {
10+
asm!("", options(noreturn));
11+
}
12+
}
13+
14+
// No active basic block without `noreturn`.
15+
fn asm_closed() {
16+
unsafe {
17+
asm!(
18+
"OpUnreachable",
19+
);
20+
}
21+
}
22+
23+
// Invalid op after terminator
24+
fn asm_invalid_op_terminator(x: f32) {
25+
unsafe {
26+
asm!(
27+
"OpKill",
28+
"%sum = OpFAdd _ {x} {x}",
29+
x = in(reg) x,
30+
);
31+
}
32+
}
33+
34+
#[spirv(fragment)]
35+
pub fn main() {
36+
asm_closed();
37+
asm_noreturn_open();
38+
asm_invalid_op_terminator(1.0);
39+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
error: `noreturn` requires a terminator at the end
2+
--> $DIR/block_tracking_fail.rs:10:9
3+
|
4+
10 | asm!("", options(noreturn));
5+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
6+
7+
error: trailing terminator Unreachable requires `options(noreturn)`
8+
--> $DIR/block_tracking_fail.rs:17:9
9+
|
10+
17 | / asm!(
11+
18 | | "OpUnreachable",
12+
19 | | );
13+
| |__________^
14+
15+
error: expected OpLabel after terminator Kill
16+
--> $DIR/block_tracking_fail.rs:26:9
17+
|
18+
26 | / asm!(
19+
27 | | "OpKill",
20+
28 | | "%sum = OpFAdd _ {x} {x}",
21+
29 | | x = in(reg) x,
22+
30 | | );
23+
| |__________^
24+
25+
error: aborting due to 3 previous errors
26+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Tests validating tracking of basic blocks
2+
// within the `asm!` macro.
3+
// build-pass
4+
5+
use spirv_std as _;
6+
7+
fn asm_label() {
8+
unsafe {
9+
asm!(
10+
"OpReturn", // close active block
11+
"%unused = OpLabel", // open new block
12+
);
13+
}
14+
}
15+
16+
fn asm_noreturn_single() -> ! {
17+
unsafe {
18+
asm!(
19+
"OpKill", // close active block
20+
options(noreturn),
21+
);
22+
}
23+
}
24+
25+
#[spirv(fragment)]
26+
pub fn main() {
27+
asm_label();
28+
asm_noreturn_single();
29+
}

0 commit comments

Comments
 (0)