@@ -485,52 +485,77 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
485
485
}
486
486
487
487
fn fuse_trivial_branches ( function : & mut Function ) {
488
- let all_preds = compute_preds ( & function. blocks ) ;
488
+ let mut chain_list = compute_outgoing_1to1_branches ( & function. blocks ) ;
489
489
let mut rewrite_rules = FxHashMap :: default ( ) ;
490
- ' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
491
- // if there's two trivial branches in a row, the middle one might get inlined before the
492
- // last one, so when processing the last one, skip through to the first one.
493
- let pred = loop {
494
- if preds. len ( ) != 1 || preds[ 0 ] == dest_block {
495
- continue ' outer;
496
- }
497
- let pred = preds[ 0 ] ;
498
- if !function. blocks [ pred] . instructions . is_empty ( ) {
499
- break pred;
500
- }
501
- preds = & all_preds[ pred] ;
502
- } ;
503
- let pred_insts = & function. blocks [ pred] . instructions ;
504
- if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
505
- let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
506
- dest_insts. retain ( |inst| {
507
- if inst. class . opcode == Op :: Phi {
508
- assert_eq ! ( inst. operands. len( ) , 2 ) ;
509
- rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
510
- false
511
- } else {
512
- true
490
+
491
+ for block_idx in 0 ..chain_list. len ( ) {
492
+ let mut next = chain_list[ block_idx] . take ( ) ;
493
+ loop {
494
+ match next {
495
+ None => {
496
+ // end of the chain list
497
+ break ;
513
498
}
514
- } ) ;
515
- let pred_insts = & mut function. blocks [ pred] . instructions ;
516
- pred_insts. pop ( ) ; // pop the branch
517
- pred_insts. append ( & mut dest_insts) ;
499
+ Some ( x) if x == block_idx => {
500
+ // loop detected
501
+ break ;
502
+ }
503
+ Some ( next_idx) => {
504
+ let mut dest_insts = take ( & mut function. blocks [ next_idx] . instructions ) ;
505
+ dest_insts. retain ( |inst| {
506
+ if inst. class . opcode == Op :: Phi {
507
+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
508
+ rewrite_rules
509
+ . insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
510
+ false
511
+ } else {
512
+ true
513
+ }
514
+ } ) ;
515
+ let self_insts = & mut function. blocks [ block_idx] . instructions ;
516
+ self_insts. pop ( ) ; // pop the branch
517
+ self_insts. append ( & mut dest_insts) ;
518
+ next = chain_list[ next_idx] . take ( ) ;
519
+ }
520
+ }
518
521
}
519
522
}
520
523
function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
521
524
apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
522
525
}
523
526
524
- fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
525
- let mut result = vec ! [ vec![ ] ; blocks. len( ) ] ;
527
+ fn compute_outgoing_1to1_branches ( blocks : & [ Block ] ) -> Vec < Option < usize > > {
528
+ let block_id_to_idx: FxHashMap < _ , _ > = blocks
529
+ . iter ( )
530
+ . enumerate ( )
531
+ . map ( |( idx, block) | ( block. label_id ( ) . unwrap ( ) , idx) )
532
+ . collect ( ) ;
533
+ #[ derive( Clone ) ]
534
+ enum NumIncoming {
535
+ Zero ,
536
+ One ( usize ) ,
537
+ TooMany ,
538
+ }
539
+ let mut incoming = vec ! [ NumIncoming :: Zero ; blocks. len( ) ] ;
526
540
for ( source_idx, source) in blocks. iter ( ) . enumerate ( ) {
527
541
for dest_id in outgoing_edges ( source) {
528
- let dest_idx = blocks
529
- . iter ( )
530
- . position ( |b| b. label_id ( ) . unwrap ( ) == dest_id)
531
- . unwrap ( ) ;
532
- result[ dest_idx] . push ( source_idx) ;
542
+ let dest_idx = block_id_to_idx[ & dest_id] ;
543
+ incoming[ dest_idx] = match incoming[ dest_idx] {
544
+ NumIncoming :: Zero => NumIncoming :: One ( source_idx) ,
545
+ _ => NumIncoming :: TooMany ,
546
+ }
547
+ }
548
+ }
549
+
550
+ let mut result = vec ! [ None ; blocks. len( ) ] ;
551
+
552
+ for ( dest_idx, inc) in incoming. iter ( ) . enumerate ( ) {
553
+ if let & NumIncoming :: One ( source_idx) = inc {
554
+ if blocks[ source_idx] . instructions . last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
555
+ result[ source_idx] = Some ( dest_idx) ;
556
+ }
533
557
}
534
558
}
559
+
535
560
result
536
561
}
0 commit comments