Skip to content

Commit 3e1dbab

Browse files
committed
linker/inline: upgrade cycle counting to topo sort & resolve quadratic inlining.
By inlining in callee -> caller order, we avoid the need to continue inlining the code we just inlined. A simple reachability test starting from the module's entry points helps avoid unnecessary work as well. The algorithm, however, remains quadratic in cases where OpAccessChains repeatedly find their way into function parameters. There are two ways out: either a more complex control flow analysis, or conservatively inlining all function calls which reference FunctionParameters as arguments. I don't think either option is worth the effort.
1 parent 03f89e8 commit 3e1dbab

File tree

1 file changed

+136
-114
lines changed
  • crates/rustc_codegen_spirv/src/linker

1 file changed

+136
-114
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 136 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -14,43 +14,31 @@ use rustc_errors::ErrorGuaranteed;
1414
use rustc_session::Session;
1515
use std::mem::take;
1616

17-
type FunctionMap = FxHashMap<Word, Function>;
17+
type FunctionMap = FxHashMap<Word, usize>;
1818

1919
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
20+
let (disallowed_argument_types, disallowed_return_types) =
21+
compute_disallowed_argument_and_return_types(module);
22+
let mut to_delete: Vec<_> = module
23+
.functions
24+
.iter()
25+
.map(|f| should_inline(&disallowed_argument_types, &disallowed_return_types, f))
26+
.collect();
2027
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
21-
deny_recursion_in_module(sess, module)?;
22-
28+
let postorder = compute_function_postorder(sess, module, &mut to_delete)?;
2329
let functions = module
2430
.functions
2531
.iter()
26-
.map(|f| (f.def_id().unwrap(), f.clone()))
32+
.enumerate()
33+
.map(|(idx, f)| (f.def_id().unwrap(), idx))
2734
.collect();
28-
let (disallowed_argument_types, disallowed_return_types) =
29-
compute_disallowed_argument_and_return_types(module);
3035
let void = module
3136
.types_global_values
3237
.iter()
3338
.find(|inst| inst.class.opcode == Op::TypeVoid)
3439
.map_or(0, |inst| inst.result_id.unwrap());
3540
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
3641
// inlines in functions that will get inlined)
37-
let mut dropped_ids = FxHashSet::default();
38-
module.functions.retain(|f| {
39-
if should_inline(&disallowed_argument_types, &disallowed_return_types, f) {
40-
// TODO: We should insert all defined IDs in this function.
41-
dropped_ids.insert(f.def_id().unwrap());
42-
false
43-
} else {
44-
true
45-
}
46-
});
47-
// Drop OpName etc. for inlined functions
48-
module.debug_names.retain(|inst| {
49-
!inst.operands.iter().any(|op| {
50-
op.id_ref_any()
51-
.map_or(false, |id| dropped_ids.contains(&id))
52-
})
53-
});
5442
let mut inliner = Inliner {
5543
header: module.header.as_mut().unwrap(),
5644
types_global_values: &mut module.types_global_values,
@@ -59,76 +47,121 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
5947
disallowed_argument_types: &disallowed_argument_types,
6048
disallowed_return_types: &disallowed_return_types,
6149
};
62-
for function in &mut module.functions {
63-
inliner.inline_fn(function);
64-
fuse_trivial_branches(function);
50+
for index in postorder {
51+
inliner.inline_fn(&mut module.functions, index);
52+
fuse_trivial_branches(&mut module.functions[index]);
53+
}
54+
let mut dropped_ids = FxHashSet::default();
55+
for i in (0..module.functions.len()).rev() {
56+
if to_delete[i] {
57+
dropped_ids.insert(module.functions.remove(i).def_id().unwrap());
58+
}
6559
}
60+
// Drop OpName etc. for inlined functions
61+
module.debug_names.retain(|inst| {
62+
!inst.operands.iter().any(|op| {
63+
op.id_ref_any()
64+
.map_or(false, |id| dropped_ids.contains(&id))
65+
})
66+
});
6667
Ok(())
6768
}
6869

69-
// https://stackoverflow.com/a/53995651
70-
fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()> {
70+
/// Topological sorting algorithm due to T. Cormen
71+
/// Starts from module's entry points, so only reachable functions will be returned
72+
/// in post-traversal order of DFS. For all unvisited functions `module.functions[i]`,
73+
/// `to_delete[i]` is set to true.
74+
fn compute_function_postorder(
75+
sess: &Session,
76+
module: &Module,
77+
to_delete: &mut [bool],
78+
) -> super::Result<Vec<usize>> {
7179
let func_to_index: FxHashMap<Word, usize> = module
7280
.functions
7381
.iter()
7482
.enumerate()
7583
.map(|(index, func)| (func.def_id().unwrap(), index))
7684
.collect();
77-
let mut discovered = vec![false; module.functions.len()];
78-
let mut finished = vec![false; module.functions.len()];
85+
/// Possible node states for cycle-discovering DFS.
86+
#[derive(Clone, PartialEq)]
87+
enum NodeState {
88+
/// Normal, not visited.
89+
NotVisited,
90+
/// Currently being visited.
91+
Discovered,
92+
/// DFS returned.
93+
Finished,
94+
/// Not visited, entry point.
95+
Entry,
96+
}
97+
let mut states = vec![NodeState::NotVisited; module.functions.len()];
98+
for opep in module.entry_points.iter() {
99+
let func_id = opep.operands[1].unwrap_id_ref();
100+
states[func_to_index[&func_id]] = NodeState::Entry;
101+
}
79102
let mut has_recursion = None;
103+
let mut postorder = vec![];
80104
for index in 0..module.functions.len() {
81-
if !discovered[index] && !finished[index] {
105+
if NodeState::Entry == states[index] {
82106
visit(
83107
sess,
84108
module,
85109
index,
86-
&mut discovered,
87-
&mut finished,
110+
&mut states[..],
88111
&mut has_recursion,
112+
&mut postorder,
89113
&func_to_index,
90114
);
91115
}
92116
}
93117

118+
for index in 0..module.functions.len() {
119+
if NodeState::NotVisited == states[index] {
120+
to_delete[index] = true;
121+
}
122+
}
123+
94124
fn visit(
95125
sess: &Session,
96126
module: &Module,
97127
current: usize,
98-
discovered: &mut Vec<bool>,
99-
finished: &mut Vec<bool>,
128+
states: &mut [NodeState],
100129
has_recursion: &mut Option<ErrorGuaranteed>,
130+
postorder: &mut Vec<usize>,
101131
func_to_index: &FxHashMap<Word, usize>,
102132
) {
103-
discovered[current] = true;
133+
states[current] = NodeState::Discovered;
104134

105135
for next in calls(&module.functions[current], func_to_index) {
106-
if discovered[next] {
107-
let names = get_names(module);
108-
let current_name = get_name(&names, module.functions[current].def_id().unwrap());
109-
let next_name = get_name(&names, module.functions[next].def_id().unwrap());
110-
*has_recursion = Some(sess.err(&format!(
111-
"module has recursion, which is not allowed: `{}` calls `{}`",
112-
current_name, next_name
113-
)));
114-
break;
115-
}
116-
117-
if !finished[next] {
118-
visit(
119-
sess,
120-
module,
121-
next,
122-
discovered,
123-
finished,
124-
has_recursion,
125-
func_to_index,
126-
);
136+
match states[next] {
137+
NodeState::Discovered => {
138+
let names = get_names(module);
139+
let current_name =
140+
get_name(&names, module.functions[current].def_id().unwrap());
141+
let next_name = get_name(&names, module.functions[next].def_id().unwrap());
142+
*has_recursion = Some(sess.err(&format!(
143+
"module has recursion, which is not allowed: `{}` calls `{}`",
144+
current_name, next_name
145+
)));
146+
break;
147+
}
148+
NodeState::NotVisited | NodeState::Entry => {
149+
visit(
150+
sess,
151+
module,
152+
next,
153+
states,
154+
has_recursion,
155+
postorder,
156+
func_to_index,
157+
);
158+
}
159+
NodeState::Finished => {}
127160
}
128161
}
129162

130-
discovered[current] = false;
131-
finished[current] = true;
163+
states[current] = NodeState::Finished;
164+
postorder.push(current)
132165
}
133166

134167
fn calls<'a>(
@@ -146,7 +179,7 @@ fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()
146179

147180
match has_recursion {
148181
Some(err) => Err(err),
149-
None => Ok(()),
182+
None => Ok(postorder),
150183
}
151184
}
152185

@@ -284,33 +317,39 @@ impl Inliner<'_, '_> {
284317
inst_id
285318
}
286319

287-
fn inline_fn(&mut self, function: &mut Function) {
320+
fn inline_fn(&mut self, functions: &mut [Function], index: usize) {
321+
let mut function = take(&mut functions[index]);
288322
let mut block_idx = 0;
289323
while block_idx < function.blocks.len() {
290-
// If we successfully inlined a block, then repeat processing on the same block, in
291-
// case the newly inlined block has more inlined calls.
292-
// TODO: This is quadratic
293-
if !self.inline_block(function, block_idx) {
294-
block_idx += 1;
295-
}
324+
// If we successfully inlined a block, then continue processing on the next block or its tail.
325+
// TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments.
326+
// For the common case of "we knew which functions to inline", it is linear.
327+
self.inline_block(&mut function, &functions, block_idx);
328+
block_idx += 1;
296329
}
330+
functions[index] = function;
297331
}
298332

299-
fn inline_block(&mut self, caller: &mut Function, block_idx: usize) -> bool {
333+
/// Inlines one block and returns whether inlining actually occurred.
334+
/// After calling this, blocks[block_idx] is finished processing.
335+
fn inline_block(
336+
&mut self,
337+
caller: &mut Function,
338+
functions: &[Function],
339+
block_idx: usize,
340+
) -> bool {
300341
// Find the first inlined OpFunctionCall
301342
let call = caller.blocks[block_idx]
302343
.instructions
303344
.iter()
304345
.enumerate()
305346
.filter(|(_, inst)| inst.class.opcode == Op::FunctionCall)
306347
.map(|(index, inst)| {
307-
(
308-
index,
309-
inst,
310-
self.functions
311-
.get(&inst.operands[0].id_ref_any().unwrap())
312-
.unwrap(),
313-
)
348+
let idx = self
349+
.functions
350+
.get(&inst.operands[0].id_ref_any().unwrap())
351+
.unwrap();
352+
(index, inst, &functions[*idx])
314353
})
315354
.find(|(_, inst, f)| {
316355
should_inline(
@@ -375,17 +414,23 @@ impl Inliner<'_, '_> {
375414
);
376415
}
377416

378-
// Fuse the first block of the callee into the block of the caller. This is okay because
379-
// it's illegal to branch to the first BB in a function.
380-
let mut callee_header = inlined_blocks.remove(0).instructions;
417+
// Move the variables over from the inlined function to here.
418+
let mut callee_header = take(&mut inlined_blocks[0]).instructions;
381419
// TODO: OpLine handling
382-
let num_variables = callee_header
383-
.iter()
384-
.position(|inst| inst.class.opcode != Op::Variable)
385-
.unwrap_or(callee_header.len());
386-
caller.blocks[block_idx]
387-
.instructions
388-
.append(&mut callee_header.split_off(num_variables));
420+
let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable);
421+
// Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
422+
// it, and we maintain the invariant that current block has finished processing.
423+
let jump_to = self.id();
424+
inlined_blocks[0] = Block {
425+
label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])),
426+
instructions: callee_header.split_off(num_variables),
427+
};
428+
caller.blocks[block_idx].instructions.push(Instruction::new(
429+
Op::Branch,
430+
None,
431+
None,
432+
vec![Operand::IdRef(jump_to)],
433+
));
389434
// Move the OpVariables of the callee to the caller.
390435
insert_opvariables(&mut caller.blocks[0], callee_header);
391436

@@ -467,45 +512,22 @@ fn get_inlined_blocks(
467512
fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) {
468513
let index = block
469514
.instructions
470-
.iter()
471-
.enumerate()
472-
.find_map(|(index, inst)| {
473-
if inst.class.opcode != Op::Variable {
474-
Some(index)
475-
} else {
476-
None
477-
}
478-
});
515+
.partition_point(|inst| inst.class.opcode == Op::Variable);
516+
479517
let inst = Instruction::new(
480518
Op::Variable,
481519
Some(ptr_ty),
482520
Some(result_id),
483521
vec![Operand::StorageClass(StorageClass::Function)],
484522
);
485-
match index {
486-
Some(index) => block.instructions.insert(index, inst),
487-
None => block.instructions.push(inst),
488-
}
523+
block.instructions.insert(index, inst)
489524
}
490525

491-
fn insert_opvariables(block: &mut Block, mut insts: Vec<Instruction>) {
526+
fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
492527
let index = block
493528
.instructions
494-
.iter()
495-
.enumerate()
496-
.find_map(|(index, inst)| {
497-
if inst.class.opcode != Op::Variable {
498-
Some(index)
499-
} else {
500-
None
501-
}
502-
});
503-
match index {
504-
Some(index) => {
505-
block.instructions.splice(index..index, insts);
506-
}
507-
None => block.instructions.append(&mut insts),
508-
}
529+
.partition_point(|inst| inst.class.opcode == Op::Variable);
530+
block.instructions.splice(index..index, insts);
509531
}
510532

511533
fn fuse_trivial_branches(function: &mut Function) {

0 commit comments

Comments
 (0)