Skip to content

Commit e161317

Browse files
committed
linker/inline: upgrade cycle counting to topo sort & resolve quadratic inlining.
By inlining in callee -> caller order, we avoid the need to continue inlining the code we just inlined. A simple reachability test from one of the entry points helps avoid unnecessary work as well. The algorithm however remains quadratic in case where OpAccessChains repeatedly find their way into function parameters. There are two ways out: either a more complex control flow analysis, or conservatively inlining all function calls which reference FunctionParameters as arguments. I don't think either case is very worth it.
1 parent 6232d95 commit e161317

File tree

1 file changed

+141
-116
lines changed
  • crates/rustc_codegen_spirv/src/linker

1 file changed

+141
-116
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 141 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,24 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
1313
use rustc_session::Session;
1414
use std::mem::take;
1515

16-
type FunctionMap = FxHashMap<Word, Function>;
16+
type FunctionMap = FxHashMap<Word, usize>;
1717

1818
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
19+
let (disallowed_argument_types, disallowed_return_types) =
20+
compute_disallowed_argument_and_return_types(module);
21+
let mut to_delete: Vec<_> = module
22+
.functions
23+
.iter()
24+
.map(|f| should_inline(&disallowed_argument_types, &disallowed_return_types, f))
25+
.collect();
1926
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
20-
if module_has_recursion(sess, module) {
21-
return Err(rustc_errors::ErrorReported);
22-
}
27+
let postorder = compute_function_postorder(sess, module, &mut to_delete)?;
2328
let functions = module
2429
.functions
2530
.iter()
26-
.map(|f| (f.def_id().unwrap(), f.clone()))
31+
.enumerate()
32+
.map(|(idx, f)| (f.def_id().unwrap(), idx))
2733
.collect();
28-
let (disallowed_argument_types, disallowed_return_types) =
29-
compute_disallowed_argument_and_return_types(module);
3034
let void = module
3135
.types_global_values
3236
.iter()
@@ -35,23 +39,6 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
3539
.unwrap_or(0);
3640
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
3741
// inlines in functions that will get inlined)
38-
let mut dropped_ids = FxHashSet::default();
39-
module.functions.retain(|f| {
40-
if should_inline(&disallowed_argument_types, &disallowed_return_types, f) {
41-
// TODO: We should insert all defined IDs in this function.
42-
dropped_ids.insert(f.def_id().unwrap());
43-
false
44-
} else {
45-
true
46-
}
47-
});
48-
// Drop OpName etc. for inlined functions
49-
module.debug_names.retain(|inst| {
50-
!inst.operands.iter().any(|op| {
51-
op.id_ref_any()
52-
.map_or(false, |id| dropped_ids.contains(&id))
53-
})
54-
});
5542
let mut inliner = Inliner {
5643
header: module.header.as_mut().unwrap(),
5744
types_global_values: &mut module.types_global_values,
@@ -60,77 +47,122 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
6047
disallowed_argument_types: &disallowed_argument_types,
6148
disallowed_return_types: &disallowed_return_types,
6249
};
63-
for function in &mut module.functions {
64-
inliner.inline_fn(function);
65-
fuse_trivial_branches(function);
50+
for index in postorder {
51+
inliner.inline_fn(&mut module.functions, index);
52+
fuse_trivial_branches(&mut module.functions[index]);
6653
}
54+
let mut dropped_ids = FxHashSet::default();
55+
for i in (0..module.functions.len()).rev() {
56+
if to_delete[i] {
57+
dropped_ids.insert(module.functions.remove(i).def_id().unwrap());
58+
}
59+
}
60+
// Drop OpName etc. for inlined functions
61+
module.debug_names.retain(|inst| {
62+
!inst.operands.iter().any(|op| {
63+
op.id_ref_any()
64+
.map_or(false, |id| dropped_ids.contains(&id))
65+
})
66+
});
6767
Ok(())
6868
}
6969

70-
// https://stackoverflow.com/a/53995651
71-
fn module_has_recursion(sess: &Session, module: &Module) -> bool {
70+
/// Topological sorting algorithm due to T. Cormen
71+
/// Starts from module's entry points, so only reachable functions will be returned
72+
/// in post-traversal order of DFS. For all unvisited functions `module.functions[i]`,
73+
/// `to_delete[i]` is set to true.
74+
fn compute_function_postorder(
75+
sess: &Session,
76+
module: &Module,
77+
to_delete: &mut [bool],
78+
) -> super::Result<Vec<usize>> {
7279
let func_to_index: FxHashMap<Word, usize> = module
7380
.functions
7481
.iter()
7582
.enumerate()
7683
.map(|(index, func)| (func.def_id().unwrap(), index))
7784
.collect();
78-
let mut discovered = vec![false; module.functions.len()];
79-
let mut finished = vec![false; module.functions.len()];
85+
/// Possible node states for cycle-discovering DFS.
86+
#[derive(Clone, PartialEq)]
87+
enum NodeState {
88+
/// Normal, not visited.
89+
NotVisited,
90+
/// Currently being visited.
91+
Discovered,
92+
/// DFS returned.
93+
Finished,
94+
/// Not visited, entry point.
95+
Entry,
96+
}
97+
let mut states = vec![NodeState::NotVisited; module.functions.len()];
98+
for opep in module.entry_points.iter() {
99+
let func_id = opep.operands[1].unwrap_id_ref();
100+
states[func_to_index[&func_id]] = NodeState::Entry;
101+
}
80102
let mut has_recursion = false;
103+
let mut postorder = vec![];
81104
for index in 0..module.functions.len() {
82-
if !discovered[index] && !finished[index] {
105+
if NodeState::Entry == states[index] {
83106
visit(
84107
sess,
85108
module,
86109
index,
87-
&mut discovered,
88-
&mut finished,
110+
&mut states[..],
89111
&mut has_recursion,
112+
&mut postorder,
90113
&func_to_index,
91114
);
92115
}
93116
}
94117

118+
for index in 0..module.functions.len() {
119+
if NodeState::NotVisited == states[index] {
120+
to_delete[index] = true;
121+
}
122+
}
123+
95124
fn visit(
96125
sess: &Session,
97126
module: &Module,
98127
current: usize,
99-
discovered: &mut Vec<bool>,
100-
finished: &mut Vec<bool>,
128+
states: &mut [NodeState],
101129
has_recursion: &mut bool,
130+
postorder: &mut Vec<usize>,
102131
func_to_index: &FxHashMap<Word, usize>,
103132
) {
104-
discovered[current] = true;
133+
states[current] = NodeState::Discovered;
105134

106135
for next in calls(&module.functions[current], func_to_index) {
107-
if discovered[next] {
108-
let names = get_names(module);
109-
let current_name = get_name(&names, module.functions[current].def_id().unwrap());
110-
let next_name = get_name(&names, module.functions[next].def_id().unwrap());
111-
sess.err(&format!(
112-
"module has recursion, which is not allowed: `{}` calls `{}`",
113-
current_name, next_name
114-
));
115-
*has_recursion = true;
116-
break;
117-
}
118-
119-
if !finished[next] {
120-
visit(
121-
sess,
122-
module,
123-
next,
124-
discovered,
125-
finished,
126-
has_recursion,
127-
func_to_index,
128-
);
136+
match states[next] {
137+
NodeState::Discovered => {
138+
let names = get_names(module);
139+
let current_name =
140+
get_name(&names, module.functions[current].def_id().unwrap());
141+
let next_name = get_name(&names, module.functions[next].def_id().unwrap());
142+
sess.err(&format!(
143+
"module has recursion, which is not allowed: `{}` calls `{}`",
144+
current_name, next_name
145+
));
146+
*has_recursion = true;
147+
break;
148+
}
149+
NodeState::NotVisited | NodeState::Entry => {
150+
visit(
151+
sess,
152+
module,
153+
next,
154+
states,
155+
has_recursion,
156+
postorder,
157+
func_to_index,
158+
);
159+
}
160+
NodeState::Finished => {}
129161
}
130162
}
131163

132-
discovered[current] = false;
133-
finished[current] = true;
164+
states[current] = NodeState::Finished;
165+
postorder.push(current)
134166
}
135167

136168
fn calls<'a>(
@@ -146,7 +178,11 @@ fn module_has_recursion(sess: &Session, module: &Module) -> bool {
146178
})
147179
}
148180

149-
has_recursion
181+
if has_recursion {
182+
Err(rustc_errors::ErrorReported)
183+
} else {
184+
Ok(postorder)
185+
}
150186
}
151187

152188
fn compute_disallowed_argument_and_return_types(
@@ -283,33 +319,39 @@ impl Inliner<'_, '_> {
283319
inst_id
284320
}
285321

286-
fn inline_fn(&mut self, function: &mut Function) {
322+
fn inline_fn(&mut self, functions: &mut [Function], index: usize) {
323+
let mut function = take(&mut functions[index]);
287324
let mut block_idx = 0;
288325
while block_idx < function.blocks.len() {
289-
// If we successfully inlined a block, then repeat processing on the same block, in
290-
// case the newly inlined block has more inlined calls.
291-
// TODO: This is quadratic
292-
if !self.inline_block(function, block_idx) {
293-
block_idx += 1;
294-
}
326+
// If we successfully inlined a block, then continue processing on the next block or its tail.
327+
// TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments.
328+
// For the common case of "we knew which functions to inline", it is linear.
329+
self.inline_block(&mut function, &functions, block_idx);
330+
block_idx += 1;
295331
}
332+
functions[index] = function;
296333
}
297334

298-
fn inline_block(&mut self, caller: &mut Function, block_idx: usize) -> bool {
335+
/// Inlines one block and returns whether inlining actually occurred.
336+
/// After calling this, blocks[block_idx] is finished processing.
337+
fn inline_block(
338+
&mut self,
339+
caller: &mut Function,
340+
functions: &[Function],
341+
block_idx: usize,
342+
) -> bool {
299343
// Find the first inlined OpFunctionCall
300344
let call = caller.blocks[block_idx]
301345
.instructions
302346
.iter()
303347
.enumerate()
304348
.filter(|(_, inst)| inst.class.opcode == Op::FunctionCall)
305349
.map(|(index, inst)| {
306-
(
307-
index,
308-
inst,
309-
self.functions
310-
.get(&inst.operands[0].id_ref_any().unwrap())
311-
.unwrap(),
312-
)
350+
let idx = self
351+
.functions
352+
.get(&inst.operands[0].id_ref_any().unwrap())
353+
.unwrap();
354+
(index, inst, &functions[*idx])
313355
})
314356
.find(|(_, inst, f)| {
315357
should_inline(
@@ -374,17 +416,23 @@ impl Inliner<'_, '_> {
374416
);
375417
}
376418

377-
// Fuse the first block of the callee into the block of the caller. This is okay because
378-
// it's illegal to branch to the first BB in a function.
379-
let mut callee_header = inlined_blocks.remove(0).instructions;
419+
// Move the variables over from the inlined function to here.
420+
let mut callee_header = take(&mut inlined_blocks[0]).instructions;
380421
// TODO: OpLine handling
381-
let num_variables = callee_header
382-
.iter()
383-
.position(|inst| inst.class.opcode != Op::Variable)
384-
.unwrap_or(callee_header.len());
385-
caller.blocks[block_idx]
386-
.instructions
387-
.append(&mut callee_header.split_off(num_variables));
422+
let num_variables = callee_header.partition_point(|inst| inst.class.opcode == Op::Variable);
423+
// Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
424+
// it, and we maintain the invariant that current block has finished processing.
425+
let jump_to = self.id();
426+
inlined_blocks[0] = Block {
427+
label: Some(Instruction::new(Op::Label, None, Some(jump_to), vec![])),
428+
instructions: callee_header.split_off(num_variables),
429+
};
430+
caller.blocks[block_idx].instructions.push(Instruction::new(
431+
Op::Branch,
432+
None,
433+
None,
434+
vec![Operand::IdRef(jump_to)],
435+
));
388436
// Move the OpVariables of the callee to the caller.
389437
insert_opvariables(&mut caller.blocks[0], callee_header);
390438

@@ -466,45 +514,22 @@ fn get_inlined_blocks(
466514
fn insert_opvariable(block: &mut Block, ptr_ty: Word, result_id: Word) {
467515
let index = block
468516
.instructions
469-
.iter()
470-
.enumerate()
471-
.find_map(|(index, inst)| {
472-
if inst.class.opcode != Op::Variable {
473-
Some(index)
474-
} else {
475-
None
476-
}
477-
});
517+
.partition_point(|inst| inst.class.opcode == Op::Variable);
518+
478519
let inst = Instruction::new(
479520
Op::Variable,
480521
Some(ptr_ty),
481522
Some(result_id),
482523
vec![Operand::StorageClass(StorageClass::Function)],
483524
);
484-
match index {
485-
Some(index) => block.instructions.insert(index, inst),
486-
None => block.instructions.push(inst),
487-
}
525+
block.instructions.insert(index, inst)
488526
}
489527

490-
fn insert_opvariables(block: &mut Block, mut insts: Vec<Instruction>) {
528+
fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
491529
let index = block
492530
.instructions
493-
.iter()
494-
.enumerate()
495-
.find_map(|(index, inst)| {
496-
if inst.class.opcode != Op::Variable {
497-
Some(index)
498-
} else {
499-
None
500-
}
501-
});
502-
match index {
503-
Some(index) => {
504-
block.instructions.splice(index..index, insts);
505-
}
506-
None => block.instructions.append(&mut insts),
507-
}
531+
.partition_point(|inst| inst.class.opcode == Op::Variable);
532+
block.instructions.splice(index..index, insts);
508533
}
509534

510535
fn fuse_trivial_branches(function: &mut Function) {

0 commit comments

Comments
 (0)