@@ -37,28 +37,17 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
37
37
. find ( |inst| inst. class . opcode == Op :: TypeVoid )
38
38
. map ( |inst| inst. result_id . unwrap ( ) )
39
39
. unwrap_or ( 0 ) ;
40
- let ptr_map: FxHashMap < _ , _ > = module
41
- . types_global_values
42
- . iter ( )
43
- . filter_map ( |inst| {
44
- if inst. class . opcode == Op :: TypePointer
45
- && inst. operands [ 0 ] . unwrap_storage_class ( ) == StorageClass :: Function
46
- {
47
- Some ( ( inst. operands [ 1 ] . unwrap_id_ref ( ) , inst. result_id . unwrap ( ) ) )
48
- } else {
49
- None
50
- }
51
- } )
52
- . collect ( ) ;
40
+
41
+ let invalid_args = module. functions . iter ( ) . flat_map ( get_invalid_args) . collect ( ) ;
42
+
53
43
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
54
44
// inlines in functions that will get inlined)
55
45
let mut inliner = Inliner {
56
46
header : module. header . as_mut ( ) . unwrap ( ) ,
57
- types_global_values : & mut module. types_global_values ,
58
47
void,
59
- ptr_map,
60
48
functions : & functions,
61
49
needs_inline : & to_delete,
50
+ invalid_args,
62
51
} ;
63
52
for index in postorder {
64
53
inliner. inline_fn ( & mut module. functions , index) ;
@@ -270,20 +259,21 @@ fn should_inline(
270
259
// This should be more general, but a very common problem is passing an OpAccessChain to an
271
260
// OpFunctionCall (i.e. `f(&s.x)`, or more commonly, `s.x.f()` where `f` takes `&self`), so detect
272
261
// that case and inline the call.
273
- fn args_invalid ( function : & Function , call : & Instruction ) -> bool {
274
- for inst in function. all_inst_iter ( ) {
262
+ fn get_invalid_args < ' a > ( function : & ' a Function ) -> impl Iterator < Item = Word > + ' a {
263
+ function. all_inst_iter ( ) . filter_map ( |inst| {
275
264
if inst. class . opcode == Op :: AccessChain {
276
- let inst_result = inst. result_id . unwrap ( ) ;
277
- if call
278
- . operands
279
- . iter ( )
280
- . any ( |op| * op == Operand :: IdRef ( inst_result) )
281
- {
282
- return true ;
283
- }
265
+ inst. result_id
266
+ } else {
267
+ None
284
268
}
285
- }
286
- false
269
+ } )
270
+ }
271
+
272
+ fn args_invalid ( invalid_args : & FxHashSet < Word > , call : & Instruction ) -> bool {
273
+ call. operands . iter ( ) . skip ( 1 ) . any ( |op| {
274
+ op. id_ref_any ( )
275
+ . map_or ( false , |arg| invalid_args. contains ( & arg) )
276
+ } )
287
277
}
288
278
289
279
// Steps:
@@ -294,11 +284,10 @@ fn args_invalid(function: &Function, call: &Instruction) -> bool {
294
284
295
285
struct Inliner < ' m , ' map > {
296
286
header : & ' m mut ModuleHeader ,
297
- types_global_values : & ' m mut Vec < Instruction > ,
298
287
void : Word ,
299
- ptr_map : FxHashMap < Word , Word > ,
300
288
functions : & ' map FunctionMap ,
301
289
needs_inline : & ' map [ bool ] ,
290
+ invalid_args : FxHashSet < Word > ,
302
291
}
303
292
304
293
impl Inliner < ' _ , ' _ > {
@@ -308,25 +297,6 @@ impl Inliner<'_, '_> {
308
297
result
309
298
}
310
299
311
- fn ptr_ty ( & mut self , pointee : Word ) -> Word {
312
- let existing = self . ptr_map . get ( & pointee) ;
313
- if let Some ( existing) = existing {
314
- return * existing;
315
- }
316
- let inst_id = self . id ( ) ;
317
- self . types_global_values . push ( Instruction :: new (
318
- Op :: TypePointer ,
319
- None ,
320
- Some ( inst_id) ,
321
- vec ! [
322
- Operand :: StorageClass ( StorageClass :: Function ) ,
323
- Operand :: IdRef ( pointee) ,
324
- ] ,
325
- ) ) ;
326
- self . ptr_map . insert ( pointee, inst_id) ;
327
- inst_id
328
- }
329
-
330
300
fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
331
301
let mut function = take ( & mut functions[ index] ) ;
332
302
let mut block_idx = 0 ;
@@ -361,8 +331,8 @@ impl Inliner<'_, '_> {
361
331
self . functions [ & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ] ,
362
332
)
363
333
} )
364
- . find ( |( index , inst, func_idx) | {
365
- self . needs_inline [ * func_idx] || args_invalid ( caller , inst)
334
+ . find ( |( _ , inst, func_idx) | {
335
+ self . needs_inline [ * func_idx] || args_invalid ( & self . invalid_args , inst)
366
336
} ) ;
367
337
let ( call_index, call_inst, callee_idx) = match call {
368
338
None => return false ,
@@ -390,18 +360,23 @@ impl Inliner<'_, '_> {
390
360
} ) ;
391
361
let mut rewrite_rules = callee_parameters. zip ( call_arguments) . collect ( ) ;
392
362
393
- let return_variable = if call_result_type. is_some ( ) {
394
- Some ( self . id ( ) )
395
- } else {
396
- None
397
- } ;
398
363
let return_jump = self . id ( ) ;
399
364
// Rewrite OpReturns of the callee.
400
- let mut inlined_blocks = get_inlined_blocks ( callee, return_variable , return_jump) ;
365
+ let ( mut inlined_blocks, phi_pairs ) = get_inlined_blocks ( callee, return_jump) ;
401
366
// Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
402
367
// fn is inlined multiple times.
403
368
self . add_clone_id_rules ( & mut rewrite_rules, & inlined_blocks) ;
369
+ // If any of the OpReturns were invalid, return will also be invalid.
370
+ for ( value, _) in & phi_pairs {
371
+ if self . invalid_args . contains ( value) {
372
+ self . invalid_args . insert ( call_result_id) ;
373
+ self . invalid_args
374
+ . insert ( * rewrite_rules. get ( value) . unwrap_or ( value) ) ;
375
+ }
376
+ }
404
377
apply_rewrite_rules ( & rewrite_rules, & mut inlined_blocks) ;
378
+ // unnecessary: invalidate_more_args(&rewrite_rules, &mut self.invalid_args);
379
+ // as no values from inside the inlined function ever make it directly out.
405
380
406
381
// Split the block containing the OpFunctionCall into two, around the call.
407
382
let mut post_call_block_insts = caller. blocks [ block_idx]
@@ -411,32 +386,27 @@ impl Inliner<'_, '_> {
411
386
let call = caller. blocks [ block_idx] . instructions . pop ( ) . unwrap ( ) ;
412
387
assert ! ( call. class. opcode == Op :: FunctionCall ) ;
413
388
414
- if let Some ( call_result_type) = call_result_type {
415
- // Generate the storage space for the return value: Do this *after* the split above,
416
- // because if block_idx=0, inserting a variable here shifts call_index.
417
- insert_opvariable (
418
- & mut caller. blocks [ 0 ] ,
419
- self . ptr_ty ( call_result_type) ,
420
- return_variable. unwrap ( ) ,
421
- ) ;
422
- }
423
-
424
389
// Move the variables over from the inlined function to here.
425
390
let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
426
391
// TODO: OpLine handling
427
392
let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
428
393
// Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
429
394
// it, and we maintain the invariant that current block has finished processing.
430
- let jump_to = self . id ( ) ;
395
+ let first_block_id = self . id ( ) ;
431
396
inlined_blocks[ 0 ] = Block {
432
- label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
397
+ label : Some ( Instruction :: new (
398
+ Op :: Label ,
399
+ None ,
400
+ Some ( first_block_id) ,
401
+ vec ! [ ] ,
402
+ ) ) ,
433
403
instructions : callee_header. split_off ( num_variables) ,
434
404
} ;
435
405
caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
436
406
Op :: Branch ,
437
407
None ,
438
408
None ,
439
- vec ! [ Operand :: IdRef ( jump_to ) ] ,
409
+ vec ! [ Operand :: IdRef ( first_block_id ) ] ,
440
410
) ) ;
441
411
// Move the OpVariables of the callee to the caller.
442
412
insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
@@ -447,10 +417,17 @@ impl Inliner<'_, '_> {
447
417
post_call_block_insts. insert (
448
418
0 ,
449
419
Instruction :: new (
450
- Op :: Load ,
420
+ Op :: Phi ,
451
421
Some ( call_result_type) ,
452
422
Some ( call_result_id) ,
453
- vec ! [ Operand :: IdRef ( return_variable. unwrap( ) ) ] ,
423
+ phi_pairs
424
+ . into_iter ( )
425
+ . flat_map ( |( value, parent) | {
426
+ use std:: iter;
427
+ iter:: once ( Operand :: IdRef ( * rewrite_rules. get ( & value) . unwrap_or ( & value) ) )
428
+ . chain ( iter:: once ( Operand :: IdRef ( rewrite_rules[ & parent] ) ) )
429
+ } )
430
+ . collect ( ) ,
454
431
) ,
455
432
) ;
456
433
}
@@ -483,51 +460,21 @@ impl Inliner<'_, '_> {
483
460
}
484
461
}
485
462
486
- fn get_inlined_blocks (
487
- function : & Function ,
488
- return_variable : Option < Word > ,
489
- return_jump : Word ,
490
- ) -> Vec < Block > {
463
+ fn get_inlined_blocks ( function : & Function , return_jump : Word ) -> ( Vec < Block > , Vec < ( Word , Word ) > ) {
491
464
let mut blocks = function. blocks . clone ( ) ;
465
+ let mut phipairs = Vec :: new ( ) ;
492
466
for block in & mut blocks {
493
467
let last = block. instructions . last ( ) . unwrap ( ) ;
494
468
if let Op :: Return | Op :: ReturnValue = last. class . opcode {
495
469
if Op :: ReturnValue == last. class . opcode {
496
470
let return_value = last. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) ;
497
- block. instructions . insert (
498
- block. instructions . len ( ) - 1 ,
499
- Instruction :: new (
500
- Op :: Store ,
501
- None ,
502
- None ,
503
- vec ! [
504
- Operand :: IdRef ( return_variable. unwrap( ) ) ,
505
- Operand :: IdRef ( return_value) ,
506
- ] ,
507
- ) ,
508
- ) ;
509
- } else {
510
- assert ! ( return_variable. is_none( ) ) ;
471
+ phipairs. push ( ( return_value, block. label_id ( ) . unwrap ( ) ) )
511
472
}
512
473
* block. instructions . last_mut ( ) . unwrap ( ) =
513
474
Instruction :: new ( Op :: Branch , None , None , vec ! [ Operand :: IdRef ( return_jump) ] ) ;
514
475
}
515
476
}
516
- blocks
517
- }
518
-
519
- fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
520
- let index = block
521
- . instructions
522
- . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
523
-
524
- let inst = Instruction :: new (
525
- Op :: Variable ,
526
- Some ( ptr_ty) ,
527
- Some ( result_id) ,
528
- vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
529
- ) ;
530
- block. instructions . insert ( index, inst)
477
+ ( blocks, phipairs)
531
478
}
532
479
533
480
fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
@@ -539,6 +486,7 @@ fn insert_opvariables(block: &mut Block, insts: Vec<Instruction>) {
539
486
540
487
fn fuse_trivial_branches ( function : & mut Function ) {
541
488
let all_preds = compute_preds ( & function. blocks ) ;
489
+ let mut rewrite_rules = FxHashMap :: default ( ) ;
542
490
' outer: for ( dest_block, mut preds) in all_preds. iter ( ) . enumerate ( ) {
543
491
// if there's two trivial branches in a row, the middle one might get inlined before the
544
492
// last one, so when processing the last one, skip through to the first one.
@@ -555,12 +503,22 @@ fn fuse_trivial_branches(function: &mut Function) {
555
503
let pred_insts = & function. blocks [ pred] . instructions ;
556
504
if pred_insts. last ( ) . unwrap ( ) . class . opcode == Op :: Branch {
557
505
let mut dest_insts = take ( & mut function. blocks [ dest_block] . instructions ) ;
506
+ dest_insts. retain ( |inst| {
507
+ if inst. class . opcode == Op :: Phi {
508
+ assert_eq ! ( inst. operands. len( ) , 2 ) ;
509
+ rewrite_rules. insert ( inst. result_id . unwrap ( ) , inst. operands [ 0 ] . unwrap_id_ref ( ) ) ;
510
+ false
511
+ } else {
512
+ true
513
+ }
514
+ } ) ;
558
515
let pred_insts = & mut function. blocks [ pred] . instructions ;
559
516
pred_insts. pop ( ) ; // pop the branch
560
517
pred_insts. append ( & mut dest_insts) ;
561
518
}
562
519
}
563
520
function. blocks . retain ( |b| !b. instructions . is_empty ( ) ) ;
521
+ apply_rewrite_rules ( & rewrite_rules, & mut function. blocks ) ;
564
522
}
565
523
566
524
fn compute_preds ( blocks : & [ Block ] ) -> Vec < Vec < usize > > {
0 commit comments