Skip to content

Commit 381ad53

Browse files
committed
linker/inline: fuse trivial branches in less time
Originally, this algorithm walked a linked list by the back-edges, copying and skipping. It is easier to just go with front-edges and gobble up a series of potential blocks at once. The predecessor finding algorithm really just wanted to find 1-to-1 edges (it was split between `compute_all_preds` and `fuse_trivial_branches`), so made it that.
1 parent 990425b commit 381ad53

File tree

1 file changed

+60
-35
lines changed
  • crates/rustc_codegen_spirv/src/linker

1 file changed

+60
-35
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -485,52 +485,77 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
485485
}
486486

487487
fn fuse_trivial_branches(function: &mut Function) {
488-
let all_preds = compute_preds(&function.blocks);
488+
let mut chain_list = compute_outgoing_1to1_branches(&function.blocks);
489489
let mut rewrite_rules = FxHashMap::default();
490-
'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
491-
// if there's two trivial branches in a row, the middle one might get inlined before the
492-
// last one, so when processing the last one, skip through to the first one.
493-
let pred = loop {
494-
if preds.len() != 1 || preds[0] == dest_block {
495-
continue 'outer;
496-
}
497-
let pred = preds[0];
498-
if !function.blocks[pred].instructions.is_empty() {
499-
break pred;
500-
}
501-
preds = &all_preds[pred];
502-
};
503-
let pred_insts = &function.blocks[pred].instructions;
504-
if pred_insts.last().unwrap().class.opcode == Op::Branch {
505-
let mut dest_insts = take(&mut function.blocks[dest_block].instructions);
506-
dest_insts.retain(|inst| {
507-
if inst.class.opcode == Op::Phi {
508-
assert_eq!(inst.operands.len(), 2);
509-
rewrite_rules.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
510-
false
511-
} else {
512-
true
490+
491+
for block_idx in 0..chain_list.len() {
492+
let mut next = chain_list[block_idx].take();
493+
loop {
494+
match next {
495+
None => {
496+
// end of the chain list
497+
break;
513498
}
514-
});
515-
let pred_insts = &mut function.blocks[pred].instructions;
516-
pred_insts.pop(); // pop the branch
517-
pred_insts.append(&mut dest_insts);
499+
Some(x) if x == block_idx => {
500+
// loop detected
501+
break;
502+
}
503+
Some(next_idx) => {
504+
let mut dest_insts = take(&mut function.blocks[next_idx].instructions);
505+
dest_insts.retain(|inst| {
506+
if inst.class.opcode == Op::Phi {
507+
assert_eq!(inst.operands.len(), 2);
508+
rewrite_rules
509+
.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
510+
false
511+
} else {
512+
true
513+
}
514+
});
515+
let self_insts = &mut function.blocks[block_idx].instructions;
516+
self_insts.pop(); // pop the branch
517+
self_insts.append(&mut dest_insts);
518+
next = chain_list[next_idx].take();
519+
}
520+
}
518521
}
519522
}
520523
function.blocks.retain(|b| !b.instructions.is_empty());
521524
apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
522525
}
523526

524-
fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
525-
let mut result = vec![vec![]; blocks.len()];
527+
fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec<Option<usize>> {
528+
let block_id_to_idx: FxHashMap<_, _> = blocks
529+
.iter()
530+
.enumerate()
531+
.map(|(idx, block)| (block.label_id().unwrap(), idx))
532+
.collect();
533+
#[derive(Clone)]
534+
enum NumIncoming {
535+
Zero,
536+
One(usize),
537+
TooMany,
538+
}
539+
let mut incoming = vec![NumIncoming::Zero; blocks.len()];
526540
for (source_idx, source) in blocks.iter().enumerate() {
527541
for dest_id in outgoing_edges(source) {
528-
let dest_idx = blocks
529-
.iter()
530-
.position(|b| b.label_id().unwrap() == dest_id)
531-
.unwrap();
532-
result[dest_idx].push(source_idx);
542+
let dest_idx = block_id_to_idx[&dest_id];
543+
incoming[dest_idx] = match incoming[dest_idx] {
544+
NumIncoming::Zero => NumIncoming::One(source_idx),
545+
_ => NumIncoming::TooMany,
546+
}
547+
}
548+
}
549+
550+
let mut result = vec![None; blocks.len()];
551+
552+
for (dest_idx, inc) in incoming.iter().enumerate() {
553+
if let &NumIncoming::One(source_idx) = inc {
554+
if blocks[source_idx].instructions.last().unwrap().class.opcode == Op::Branch {
555+
result[source_idx] = Some(dest_idx);
556+
}
533557
}
534558
}
559+
535560
result
536561
}

0 commit comments

Comments
 (0)