Skip to content

Commit bca7656

Browse files
authored
Added a transformation that gets rid of temporary composites. (#690)
* Added an optimization that gets rid of temporary composites. Those temporary composites result from the inlining of multi-argument closures. Not only are they rather useless, they are also sometimes invalid, e.g. when an argument to said closure is a pointer.
* Correctness fixes to transitive unused removal:
  - delay only if the instruction is in the reference set
  - properly mark composites being inserted into composites as used
* cargo fmt
* clippy
* Make the transformation per-function & rely on DCE for eliminating dead constructs.
* Forgot to mark CompositeInsert as pure & additional line cleaning
* Rustfmt
* Remove duplicate lines only once
1 parent d548268 commit bca7656

File tree

7 files changed

+122
-25
lines changed

7 files changed

+122
-25
lines changed

crates/rustc_codegen_spirv/src/link.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ fn do_link(
541541
dce: env::var("NO_DCE").is_err(),
542542
compact_ids: env::var("NO_COMPACT_IDS").is_err(),
543543
inline: legalize,
544+
destructure: legalize,
544545
mem2reg: legalize,
545546
structurize: env::var("NO_STRUCTURIZE").is_err(),
546547
emit_multiple_modules: cg_args.module_output_type == ModuleOutputType::Multiple,

crates/rustc_codegen_spirv/src/linker/dce.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ fn instruction_is_pure(inst: &Instruction) -> bool {
162162
| InBoundsPtrAccessChain
163163
| CompositeConstruct
164164
| CompositeExtract
165+
| CompositeInsert
165166
| CopyObject
166167
| Transpose
167168
| ConvertFToU
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//! Simplify `OpCompositeExtract` pointing to `OpCompositeConstruct`s / `OpCompositeInsert`s.
2+
//! Such constructions arise after inlining, when using multi-argument closures
3+
//! (and other `Fn*` trait implementations). These composites can frequently be invalid,
4+
//! containing pointers, `OpFunctionArgument`s, etc. After simplification, components
5+
//! will become valid targets for `OpLoad`/`OpStore`.
6+
use super::apply_rewrite_rules;
7+
use rspirv::dr::{Function, Instruction};
8+
use rspirv::spirv::Op;
9+
use rustc_data_structures::fx::FxHashMap;
10+
11+
pub fn destructure_composites(function: &mut Function) {
12+
let mut rewrite_rules = FxHashMap::default();
13+
let reference: FxHashMap<_, _> = function
14+
.all_inst_iter()
15+
.filter_map(|inst| match inst.class.opcode {
16+
Op::CompositeConstruct => Some((inst.result_id.unwrap(), inst.clone())),
17+
Op::CompositeInsert if inst.operands.len() == 3 => {
18+
Some((inst.result_id.unwrap(), inst.clone()))
19+
}
20+
_ => None,
21+
})
22+
.collect();
23+
for inst in function.all_inst_iter_mut() {
24+
if inst.class.opcode == Op::CompositeExtract && inst.operands.len() == 2 {
25+
let mut composite = inst.operands[0].unwrap_id_ref();
26+
let index = inst.operands[1].unwrap_literal_int32();
27+
28+
let origin = loop {
29+
if let Some(inst) = reference.get(&composite) {
30+
match inst.class.opcode {
31+
Op::CompositeInsert => {
32+
let insert_index = inst.operands[2].unwrap_literal_int32();
33+
if insert_index == index {
34+
break Some(inst.operands[0].unwrap_id_ref());
35+
}
36+
composite = inst.operands[1].unwrap_id_ref();
37+
}
38+
Op::CompositeConstruct => {
39+
break inst.operands.get(index as usize).map(|o| o.unwrap_id_ref());
40+
}
41+
_ => unreachable!(),
42+
}
43+
} else {
44+
break None;
45+
}
46+
};
47+
48+
if let Some(origin_id) = origin {
49+
rewrite_rules.insert(
50+
inst.result_id.unwrap(),
51+
rewrite_rules.get(&origin_id).map_or(origin_id, |id| *id),
52+
);
53+
*inst = Instruction::new(Op::Nop, None, None, vec![]);
54+
continue;
55+
}
56+
}
57+
}
58+
59+
// Transitive closure computation
60+
let mut closed_rewrite_rules = rewrite_rules.clone();
61+
for (_, value) in closed_rewrite_rules.iter_mut() {
62+
while let Some(next) = rewrite_rules.get(value) {
63+
*value = *next;
64+
}
65+
}
66+
67+
// Remove instructions replaced by NOPs, as well as unused composite values.
68+
for block in function.blocks.iter_mut() {
69+
block
70+
.instructions
71+
.retain(|inst| inst.class.opcode != Op::Nop);
72+
}
73+
apply_rewrite_rules(&closed_rewrite_rules, &mut function.blocks);
74+
}

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
mod test;
33

44
mod dce;
5+
mod destructure_composites;
56
mod duplicates;
67
mod import_export_link;
78
mod inline;
@@ -27,6 +28,7 @@ pub struct Options {
2728
pub dce: bool,
2829
pub inline: bool,
2930
pub mem2reg: bool,
31+
pub destructure: bool,
3032
pub structurize: bool,
3133
pub emit_multiple_modules: bool,
3234
pub name_variables: bool,
@@ -228,6 +230,10 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
228230
// mem2reg produces minimal SSA form, not pruned, so DCE the dead ones
229231
dce::dce_phi(func);
230232
}
233+
if opts.destructure {
234+
let _timer = sess.timer("link_destructure");
235+
destructure_composites::destructure_composites(func);
236+
}
231237
}
232238
}
233239

@@ -240,11 +246,6 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
240246
}
241247
}
242248

243-
{
244-
let _timer = sess.timer("link_remove_duplicate_lines");
245-
duplicates::remove_duplicate_lines(&mut output);
246-
}
247-
248249
if opts.name_variables {
249250
let _timer = sess.timer("link_name_variables");
250251
simple_passes::name_variables_pass(&mut output);
@@ -289,6 +290,11 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
289290
dce::dce(output);
290291
}
291292

293+
{
294+
let _timer = sess.timer("link_remove_duplicate_lines");
295+
duplicates::remove_duplicate_lines(output);
296+
}
297+
292298
if opts.compact_ids {
293299
let _timer = sess.timer("link_compact_ids");
294300
// compact the ids https://github.com/KhronosGroup/SPIRV-Tools/blob/e02f178a716b0c3c803ce31b9df4088596537872/source/opt/compact_ids_pass.cpp#L43

crates/rustc_codegen_spirv/src/linker/test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ fn assemble_and_link(binaries: &[&[u8]]) -> Result<Module, String> {
9292
compact_ids: true,
9393
dce: false,
9494
inline: false,
95+
destructure: false,
9596
mem2reg: false,
9697
structurize: false,
9798
emit_multiple_modules: false,

tests/ui/dis/index_user_dst.stderr

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,32 @@
33
OpLine %5 7 12
44
%6 = OpAccessChain %7 %8 %9
55
%10 = OpArrayLength %11 %8 0
6-
OpLine %5 7 0
7-
%12 = OpCompositeInsert %13 %6 %14 0
86
OpLine %5 8 21
9-
%15 = OpULessThan %16 %9 %10
7+
%12 = OpULessThan %13 %9 %10
108
OpLine %5 8 21
11-
OpSelectionMerge %17 None
12-
OpBranchConditional %15 %18 %19
13-
%18 = OpLabel
9+
OpSelectionMerge %14 None
10+
OpBranchConditional %12 %15 %16
11+
%15 = OpLabel
1412
OpLine %5 8 21
15-
%20 = OpInBoundsAccessChain %21 %6 %9
16-
%22 = OpLoad %23 %20
13+
%17 = OpInBoundsAccessChain %18 %6 %9
14+
%19 = OpLoad %20 %17
1715
OpLine %5 10 1
1816
OpReturn
19-
%19 = OpLabel
17+
%16 = OpLabel
2018
OpLine %5 8 21
21-
OpBranch %24
22-
%24 = OpLabel
19+
OpBranch %21
20+
%21 = OpLabel
21+
OpBranch %22
22+
%22 = OpLabel
23+
%23 = OpPhi %13 %24 %21 %24 %25
24+
OpLoopMerge %26 %25 None
25+
OpBranchConditional %23 %27 %26
26+
%27 = OpLabel
2327
OpBranch %25
2428
%25 = OpLabel
25-
%26 = OpPhi %16 %27 %24 %27 %28
26-
OpLoopMerge %29 %28 None
27-
OpBranchConditional %26 %30 %29
28-
%30 = OpLabel
29-
OpBranch %28
30-
%28 = OpLabel
31-
OpBranch %25
32-
%29 = OpLabel
29+
OpBranch %22
30+
%26 = OpLabel
3331
OpUnreachable
34-
%17 = OpLabel
32+
%14 = OpLabel
3533
OpUnreachable
3634
OpFunctionEnd
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// build-pass
2+
3+
use spirv_std;
4+
5+
// Invokes `callback(ptr, i)` once for every `i` in `0..xmax`.
// The callback takes two arguments, one of them a reference; presumably this is
// the multi-argument-closure shape the composite destructuring pass targets.
fn closure_user<F: FnMut(&u32, u32)>(ptr: &u32, xmax: u32, mut callback: F) {
6+
for i in 0..xmax {
7+
callback(ptr, i);
8+
}
9+
}
10+
11+
// Fragment entry point: drives `closure_user` with a two-argument closure that
// calls `spirv_std::arch::kill()` whenever the value behind `ptr` equals the
// current loop index. The file is marked `build-pass`, so it only has to compile.
#[spirv(fragment)]
12+
pub fn main(ptr: &mut u32) {
13+
closure_user(ptr, 10, |ptr, i| {
14+
if *ptr == i { spirv_std::arch::kill(); }
15+
});
16+
}

0 commit comments

Comments
 (0)