@@ -483,52 +483,77 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
483
483
}
484
484
485
485
fn fuse_trivial_branches ( function : & mut Function ) {
486
- let all_preds = compute_preds ( & function. blocks ) ;
486
+ let mut chain_list = compute_outgoing_1to1_branches ( & function. blocks ) ;
487
487
let mut rewrite_rules = FxHashMap :: default ( ) ;
488
- ' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
489
- // if there's two trivial branches in a row, the middle one might get inlined before the
490
- // last one, so when processing the last one, skip through to the first one.
491
- let pred = loop {
492
- if preds. len ( ) != 1 || preds[ 0 ] == dest_block {
493
- continue ' outer;
494
- }
495
- let pred = preds[ 0 ] ;
496
- if !function. blocks [ pred] . instructions . is_empty ( ) {
497
- break pred;
498
- }
499
- preds = & all_preds[ pred] ;
500
- } ;
501
- let pred_insts = & function. blocks [ pred] . instructions ;
502
- if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
503
- let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
504
- dest_insts. retain ( |inst| {
505
- if inst. class . opcode == Op :: Phi {
506
- assert_eq ! ( inst. operands. len( ) , 2 ) ;
507
- rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
508
- false
509
- } else {
510
- true
488
+
489
+ for block_idx in 0 ..chain_list. len ( ) {
490
+ let mut next = chain_list[ block_idx] . take ( ) ;
491
+ loop {
492
+ match next {
493
+ None => {
494
+ // end of the chain list
495
+ break ;
511
496
}
512
- } ) ;
513
- let pred_insts = & mut function. blocks [ pred] . instructions ;
514
- pred_insts. pop ( ) ; // pop the branch
515
- pred_insts. append ( & mut dest_insts) ;
497
+ Some ( x) if x == block_idx => {
498
+ // loop detected
499
+ break ;
500
+ }
501
+ Some ( next_idx) => {
502
+ let mut dest_insts = take ( & mut function. blocks [ next_idx] . instructions ) ;
503
+ dest_insts. retain ( |inst| {
504
+ if inst. class . opcode == Op :: Phi {
505
+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
506
+ rewrite_rules
507
+ . insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
508
+ false
509
+ } else {
510
+ true
511
+ }
512
+ } ) ;
513
+ let self_insts = & mut function. blocks [ block_idx] . instructions ;
514
+ self_insts. pop ( ) ; // pop the branch
515
+ self_insts. append ( & mut dest_insts) ;
516
+ next = chain_list[ next_idx] . take ( ) ;
517
+ }
518
+ }
516
519
}
517
520
}
518
521
function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
519
522
apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
520
523
}
521
524
522
- fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
523
- let mut result = vec ! [ vec![ ] ; blocks. len( ) ] ;
525
+ fn compute_outgoing_1to1_branches ( blocks : & [ Block ] ) -> Vec < Option < usize > > {
526
+ let block_id_to_idx: FxHashMap < _ , _ > = blocks
527
+ . iter ( )
528
+ . enumerate ( )
529
+ . map ( |( idx, block) | ( block. label_id ( ) . unwrap ( ) , idx) )
530
+ . collect ( ) ;
531
+ #[ derive( Clone ) ]
532
+ enum NumIncoming {
533
+ Zero ,
534
+ One ( usize ) ,
535
+ TooMany ,
536
+ }
537
+ let mut incoming = vec ! [ NumIncoming :: Zero ; blocks. len( ) ] ;
524
538
for ( source_idx, source) in blocks. iter ( ) . enumerate ( ) {
525
539
for dest_id in outgoing_edges ( source) {
526
- let dest_idx = blocks
527
- . iter ( )
528
- . position ( |b| b. label_id ( ) . unwrap ( ) == dest_id)
529
- . unwrap ( ) ;
530
- result[ dest_idx] . push ( source_idx) ;
540
+ let dest_idx = block_id_to_idx[ & dest_id] ;
541
+ incoming[ dest_idx] = match incoming[ dest_idx] {
542
+ NumIncoming :: Zero => NumIncoming :: One ( source_idx) ,
543
+ _ => NumIncoming :: TooMany ,
544
+ }
545
+ }
546
+ }
547
+
548
+ let mut result = vec ! [ None ; blocks. len( ) ] ;
549
+
550
+ for ( dest_idx, inc) in incoming. iter ( ) . enumerate ( ) {
551
+ if let & NumIncoming :: One ( source_idx) = inc {
552
+ if blocks[ source_idx] . instructions . last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
553
+ result[ source_idx] = Some ( dest_idx) ;
554
+ }
531
555
}
532
556
}
557
+
533
558
result
534
559
}
0 commit comments