@@ -37,28 +37,17 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
37
37
. iter ( )
38
38
. find ( |inst| inst. class . opcode == Op :: TypeVoid )
39
39
. map_or ( 0 , |inst| inst. result_id . unwrap ( ) ) ;
40
- let ptr_map: FxHashMap < _ , _ > = module
41
- . types_global_values
42
- . iter ( )
43
- . filter_map ( |inst| {
44
- if inst. class . opcode == Op :: TypePointer
45
- && inst. operands [ 0 ] . unwrap_storage_class ( ) == StorageClass :: Function
46
- {
47
- Some ( ( inst. operands [ 1 ] . unwrap_id_ref ( ) , inst. result_id . unwrap ( ) ) )
48
- } else {
49
- None
50
- }
51
- } )
52
- . collect ( ) ;
40
+
41
+ let invalid_args = module. functions . iter ( ) . flat_map ( get_invalid_args) . collect ( ) ;
42
+
53
43
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
54
44
// inlines in functions that will get inlined)
55
45
let mut inliner = Inliner {
56
46
header : module. header . as_mut ( ) . unwrap ( ) ,
57
- types_global_values : & mut module. types_global_values ,
58
47
void,
59
- ptr_map,
60
48
functions : & functions,
61
49
needs_inline : & to_delete,
50
+ invalid_args,
62
51
} ;
63
52
for index in postorder {
64
53
inliner. inline_fn ( & mut module. functions , index) ;
@@ -268,20 +257,21 @@ fn should_inline(
268
257
// This should be more general, but a very common problem is passing an OpAccessChain to an
269
258
// OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect
270
259
// that case and inline the call.
271
- fn args_invalid ( function : & Function , call : & Instruction ) -> bool {
272
- for inst in function. all_inst_iter ( ) {
260
+ fn get_invalid_args < ' a > ( function : & ' a Function ) -> impl Iterator < Item = Word > + ' a {
261
+ function. all_inst_iter ( ) . filter_map ( |inst| {
273
262
if inst. class . opcode == Op :: AccessChain {
274
- let inst_result = inst. result_id . unwrap ( ) ;
275
- if call
276
- . operands
277
- . iter ( )
278
- . any ( |op| * op == Operand :: IdRef ( inst_result) )
279
- {
280
- return true ;
281
- }
263
+ inst. result_id
264
+ } else {
265
+ None
282
266
}
283
- }
284
- false
267
+ } )
268
+ }
269
+
270
+ fn args_invalid ( invalid_args : & FxHashSet < Word > , call : & Instruction ) -> bool {
271
+ call. operands . iter ( ) . skip ( 1 ) . any ( |op| {
272
+ op. id_ref_any ( )
273
+ . map_or ( false , |arg| invalid_args. contains ( & arg) )
274
+ } )
285
275
}
286
276
287
277
// Steps:
@@ -292,11 +282,10 @@ fn args_invalid(function: &Function, call: &Instruction) -> bool {
292
282
293
283
struct Inliner < ' m , ' map > {
294
284
header : & ' m mut ModuleHeader ,
295
- types_global_values : & ' m mut Vec < Instruction > ,
296
285
void : Word ,
297
- ptr_map : FxHashMap < Word , Word > ,
298
286
functions : & ' map FunctionMap ,
299
287
needs_inline : & ' map [ bool ] ,
288
+ invalid_args : FxHashSet < Word > ,
300
289
}
301
290
302
291
impl Inliner < ' _ , ' _ > {
@@ -306,25 +295,6 @@ impl Inliner<'_, '_> {
306
295
result
307
296
}
308
297
309
- fn ptr_ty ( & mut self , pointee : Word ) -> Word {
310
- let existing = self . ptr_map . get ( & pointee) ;
311
- if let Some ( existing) = existing {
312
- return * existing;
313
- }
314
- let inst_id = self . id ( ) ;
315
- self . types_global_values . push ( Instruction :: new (
316
- Op :: TypePointer ,
317
- None ,
318
- Some ( inst_id) ,
319
- vec ! [
320
- Operand :: StorageClass ( StorageClass :: Function ) ,
321
- Operand :: IdRef ( pointee) ,
322
- ] ,
323
- ) ) ;
324
- self . ptr_map . insert ( pointee, inst_id) ;
325
- inst_id
326
- }
327
-
328
298
fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
329
299
let mut function = take ( & mut functions[ index] ) ;
330
300
let mut block_idx = 0 ;
@@ -359,8 +329,8 @@ impl Inliner<'_, '_> {
359
329
self . functions [ & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ] ,
360
330
)
361
331
} )
362
- . find ( |( index , inst, func_idx) | {
363
- self . needs_inline [ * func_idx] || args_invalid ( caller , inst)
332
+ . find ( |( _ , inst, func_idx) | {
333
+ self . needs_inline [ * func_idx] || args_invalid ( & self . invalid_args , inst)
364
334
} ) ;
365
335
let ( call_index, call_inst, callee_idx) = match call {
366
336
None => return false ,
@@ -388,18 +358,23 @@ impl Inliner<'_, '_> {
388
358
} ) ;
389
359
let mut rewrite_rules = callee_parameters. zip ( call_arguments) . collect ( ) ;
390
360
391
- let return_variable = if call_result_type. is_some ( ) {
392
- Some ( self . id ( ) )
393
- } else {
394
- None
395
- } ;
396
361
let return_jump = self . id ( ) ;
397
362
// Rewrite OpReturns of the callee.
398
- let mut inlined_blocks = get_inlined_blocks ( callee, return_variable , return_jump) ;
363
+ let ( mut inlined_blocks, phi_pairs ) = get_inlined_blocks ( callee, return_jump) ;
399
364
// Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
400
365
// fn is inlined multiple times.
401
366
self . add_clone_id_rules ( & mut rewrite_rules, & inlined_blocks) ;
367
+ // If any of the OpReturns were invalid, return will also be invalid.
368
+ for ( value, _) in & phi_pairs {
369
+ if self . invalid_args . contains ( value) {
370
+ self . invalid_args . insert ( call_result_id) ;
371
+ self . invalid_args
372
+ . insert ( * rewrite_rules. get ( value) . unwrap_or ( value) ) ;
373
+ }
374
+ }
402
375
apply_rewrite_rules ( & rewrite_rules, & mut inlined_blocks) ;
376
+ // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
377
+ // as no values from inside the inlined function ever make it directly out.
403
378
404
379
// Split the block containing the OpFunctionCall into two, around the call.
405
380
let mut post_call_block_insts = caller. blocks [ block_idx]
@@ -409,32 +384,27 @@ impl Inliner<'_, '_> {
409
384
let call = caller. blocks [ block_idx] . instructions . pop ( ) . unwrap ( ) ;
410
385
assert ! ( call. class. opcode == Op :: FunctionCall ) ;
411
386
412
- if let Some ( call_result_type) = call_result_type {
413
- // Generate the storage space for the return value: Do this *after* the split above,
414
- // because if block_idx=0, inserting a variable here shifts call_index.
415
- insert_opvariable (
416
- & mut caller. blocks [ 0 ] ,
417
- self . ptr_ty ( call_result_type) ,
418
- return_variable. unwrap ( ) ,
419
- ) ;
420
- }
421
-
422
387
// Move the variables over from the inlined function to here.
423
388
let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
424
389
// TODO: OpLine handling
425
390
let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
426
391
// Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
427
392
// it, and we maintain the invariant that current block has finished processing.
428
- let jump_to = self . id ( ) ;
393
+ let first_block_id = self . id ( ) ;
429
394
inlined_blocks[ 0 ] = Block {
430
- label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
395
+ label : Some ( Instruction :: new (
396
+ Op :: Label ,
397
+ None ,
398
+ Some ( first_block_id) ,
399
+ vec ! [ ] ,
400
+ ) ) ,
431
401
instructions : callee_header. split_off ( num_variables) ,
432
402
} ;
433
403
caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
434
404
Op :: Branch ,
435
405
None ,
436
406
None ,
437
- vec ! [ Operand :: IdRef ( jump_to ) ] ,
407
+ vec ! [ Operand :: IdRef ( first_block_id ) ] ,
438
408
) ) ;
439
409
// Move the OpVariables of the callee to the caller.
440
410
insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
@@ -445,10 +415,17 @@ impl Inliner<'_, '_> {
445
415
post_call_block_insts. insert (
446
416
0 ,
447
417
Instruction :: new (
448
- Op :: Load ,
418
+ Op :: Phi ,
449
419
Some ( call_result_type) ,
450
420
Some ( call_result_id) ,
451
- vec ! [ Operand :: IdRef ( return_variable. unwrap( ) ) ] ,
421
+ phi_pairs
422
+ . into_iter ( )
423
+ . flat_map ( |( value, parent) | {
424
+ use std:: iter;
425
+ iter:: once ( Operand :: IdRef ( * rewrite_rules. get ( & value) . unwrap_or ( & value) ) )
426
+ . chain ( iter:: once ( Operand :: IdRef ( rewrite_rules[ & parent] ) ) )
427
+ } )
428
+ . collect ( ) ,
452
429
) ,
453
430
) ;
454
431
}
@@ -481,51 +458,21 @@ impl Inliner<'_, '_> {
481
458
}
482
459
}
483
460
484
- fn get_inlined_blocks (
485
- function : & Function ,
486
- return_variable : Option < Word > ,
487
- return_jump : Word ,
488
- ) -> Vec < Block > {
461
+ fn get_inlined_blocks ( function : & Function , return_jump : Word ) -> ( Vec < Block > , Vec < ( Word , Word ) > ) {
489
462
let mut blocks = function. blocks . clone ( ) ;
463
+ let mut phipairs = Vec :: new ( ) ;
490
464
for block in & mut blocks {
491
465
let last = block. instructions . last ( ) . unwrap ( ) ;
492
466
if let Op :: Return | Op :: ReturnValue = last. class . opcode {
493
467
if Op :: ReturnValue == last. class . opcode {
494
468
let return_value = last. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ;
495
- block. instructions . insert (
496
- block. instructions . len ( ) - 1 ,
497
- Instruction :: new (
498
- Op :: Store ,
499
- None ,
500
- None ,
501
- vec ! [
502
- Operand :: IdRef ( return_variable. unwrap( ) ) ,
503
- Operand :: IdRef ( return_value) ,
504
- ] ,
505
- ) ,
506
- ) ;
507
- } else {
508
- assert ! ( return_variable. is_none( ) ) ;
469
+ phipairs. push ( ( return_value, block. label_id ( ) . unwrap ( ) ) )
509
470
}
510
471
* block. instructions . last_mut ( ) . unwrap ( ) =
511
472
Instruction :: new ( Op :: Branch , None , None , vec ! [ Operand :: IdRef ( return_jump) ] ) ;
512
473
}
513
474
}
514
- blocks
515
- }
516
-
517
- fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
518
- let index = block
519
- . instructions
520
- . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
521
-
522
- let inst = Instruction :: new (
523
- Op :: Variable ,
524
- Some ( ptr_ty) ,
525
- Some ( result_id) ,
526
- vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
527
- ) ;
528
- block. instructions . insert ( index, inst)
475
+ ( blocks, phipairs)
529
476
}
530
477
531
478
fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
@@ -537,6 +484,7 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
537
484
538
485
fn fuse_trivial_branches ( function : & mut Function ) {
539
486
let all_preds = compute_preds ( & function. blocks ) ;
487
+ let mut rewrite_rules = FxHashMap :: default ( ) ;
540
488
' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
541
489
// if there's two trivial branches in a row, the middle one might get inlined before the
542
490
// last one, so when processing the last one, skip through to the first one.
@@ -553,12 +501,22 @@ fn fuse_trivial_branches(function: &mut Function) {
553
501
let pred_insts = & function. blocks [ pred] . instructions ;
554
502
if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
555
503
let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
504
+ dest_insts. retain ( |inst| {
505
+ if inst. class . opcode == Op :: Phi {
506
+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
507
+ rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
508
+ false
509
+ } else {
510
+ true
511
+ }
512
+ } ) ;
556
513
let pred_insts = & mut function. blocks [ pred] . instructions ;
557
514
pred_insts. pop ( ) ; // pop the branch
558
515
pred_insts. append ( & mut dest_insts) ;
559
516
}
560
517
}
561
518
function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
519
+ apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
562
520
}
563
521
564
522
fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
0 commit comments