Skip to content

Commit c048909

Browse files
committed
linker/inline: fuse trivial branches in less time
Originally, this algorithm walked a linked list by the back-edges, copying and skipping. It is easier to just go with front-edges and gobble up a series of potential blocks at once. The predecessor finding algorithm really just wanted to find 1-to-1 edges (it was split between `compute_all_preds` and `fuse_trivial_branches`), so made it that.
1 parent 41a6089 commit c048909

File tree

1 file changed

+60
-35
lines changed
  • crates/rustc_codegen_spirv/src/linker

1 file changed

+60
-35
lines changed

crates/rustc_codegen_spirv/src/linker/inline.rs

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -483,52 +483,77 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
483483
}
484484

485485
fn fuse_trivial_branches(function: &mut Function) {
486-
let all_preds = compute_preds(&function.blocks);
486+
let mut chain_list = compute_outgoing_1to1_branches(&function.blocks);
487487
let mut rewrite_rules = FxHashMap::default();
488-
'outer: for (dest_block, mut preds) in all_preds.iter().enumerate() {
489-
// if there's two trivial branches in a row, the middle one might get inlined before the
490-
// last one, so when processing the last one, skip through to the first one.
491-
let pred = loop {
492-
if preds.len() != 1 || preds[0] == dest_block {
493-
continue 'outer;
494-
}
495-
let pred = preds[0];
496-
if !function.blocks[pred].instructions.is_empty() {
497-
break pred;
498-
}
499-
preds = &all_preds[pred];
500-
};
501-
let pred_insts = &function.blocks[pred].instructions;
502-
if pred_insts.last().unwrap().class.opcode == Op::Branch {
503-
let mut dest_insts = take(&mut function.blocks[dest_block].instructions);
504-
dest_insts.retain(|inst| {
505-
if inst.class.opcode == Op::Phi {
506-
assert_eq!(inst.operands.len(), 2);
507-
rewrite_rules.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
508-
false
509-
} else {
510-
true
488+
489+
for block_idx in 0..chain_list.len() {
490+
let mut next = chain_list[block_idx].take();
491+
loop {
492+
match next {
493+
None => {
494+
// end of the chain list
495+
break;
511496
}
512-
});
513-
let pred_insts = &mut function.blocks[pred].instructions;
514-
pred_insts.pop(); // pop the branch
515-
pred_insts.append(&mut dest_insts);
497+
Some(x) if x == block_idx => {
498+
// loop detected
499+
break;
500+
}
501+
Some(next_idx) => {
502+
let mut dest_insts = take(&mut function.blocks[next_idx].instructions);
503+
dest_insts.retain(|inst| {
504+
if inst.class.opcode == Op::Phi {
505+
assert_eq!(inst.operands.len(), 2);
506+
rewrite_rules
507+
.insert(inst.result_id.unwrap(), inst.operands[0].unwrap_id_ref());
508+
false
509+
} else {
510+
true
511+
}
512+
});
513+
let self_insts = &mut function.blocks[block_idx].instructions;
514+
self_insts.pop(); // pop the branch
515+
self_insts.append(&mut dest_insts);
516+
next = chain_list[next_idx].take();
517+
}
518+
}
516519
}
517520
}
518521
function.blocks.retain(|b| !b.instructions.is_empty());
519522
apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
520523
}
521524

522-
fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
523-
let mut result = vec![vec![]; blocks.len()];
525+
fn compute_outgoing_1to1_branches(blocks: &[Block]) -> Vec<Option<usize>> {
526+
let block_id_to_idx: FxHashMap<_, _> = blocks
527+
.iter()
528+
.enumerate()
529+
.map(|(idx, block)| (block.label_id().unwrap(), idx))
530+
.collect();
531+
#[derive(Clone)]
532+
enum NumIncoming {
533+
Zero,
534+
One(usize),
535+
TooMany,
536+
}
537+
let mut incoming = vec![NumIncoming::Zero; blocks.len()];
524538
for (source_idx, source) in blocks.iter().enumerate() {
525539
for dest_id in outgoing_edges(source) {
526-
let dest_idx = blocks
527-
.iter()
528-
.position(|b| b.label_id().unwrap() == dest_id)
529-
.unwrap();
530-
result[dest_idx].push(source_idx);
540+
let dest_idx = block_id_to_idx[&dest_id];
541+
incoming[dest_idx] = match incoming[dest_idx] {
542+
NumIncoming::Zero => NumIncoming::One(source_idx),
543+
_ => NumIncoming::TooMany,
544+
}
545+
}
546+
}
547+
548+
let mut result = vec![None; blocks.len()];
549+
550+
for (dest_idx, inc) in incoming.iter().enumerate() {
551+
if let &NumIncoming::One(source_idx) = inc {
552+
if blocks[source_idx].instructions.last().unwrap().class.opcode == Op::Branch {
553+
result[source_idx] = Some(dest_idx);
554+
}
531555
}
532556
}
557+
533558
result
534559
}

0 commit comments

Comments
 (0)