@@ -14,43 +14,31 @@ use rustc_errors::ErrorGuaranteed;
14
14
use rustc_session:: Session ;
15
15
use std:: mem:: take;
16
16
17
- type FunctionMap = FxHashMap < Word , Function > ;
17
+ type FunctionMap = FxHashMap < Word , usize > ;
18
18
19
19
pub fn inline ( sess : & Session , module : & mut Module ) -> super :: Result < ( ) > {
20
+ let ( disallowed_argument_types, disallowed_return_types) =
21
+ compute_disallowed_argument_and_return_types ( module) ;
22
+ let mut to_delete: Vec < _ > = module
23
+ . functions
24
+ . iter ( )
25
+ . map ( |f| should_inline ( & disallowed_argument_types, & disallowed_return_types, f) )
26
+ . collect ( ) ;
20
27
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
21
- deny_recursion_in_module ( sess, module) ?;
22
-
28
+ let postorder = compute_function_postorder ( sess, module, & mut to_delete) ?;
23
29
let functions = module
24
30
. functions
25
31
. iter ( )
26
- . map ( |f| ( f. def_id ( ) . unwrap ( ) , f. clone ( ) ) )
32
+ . enumerate ( )
33
+ . map ( |( idx, f) | ( f. def_id ( ) . unwrap ( ) , idx) )
27
34
. collect ( ) ;
28
- let ( disallowed_argument_types, disallowed_return_types) =
29
- compute_disallowed_argument_and_return_types ( module) ;
30
35
let void = module
31
36
. types_global_values
32
37
. iter ( )
33
38
. find ( |inst| inst. class . opcode == Op :: TypeVoid )
34
39
. map_or ( 0 , |inst| inst. result_id . unwrap ( ) ) ;
35
40
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
36
41
// inlines in functions that will get inlined)
37
- let mut dropped_ids = FxHashSet :: default ( ) ;
38
- module. functions . retain ( |f| {
39
- if should_inline ( & disallowed_argument_types, & disallowed_return_types, f) {
40
- // TODO: We should insert all defined IDs in this function.
41
- dropped_ids. insert ( f. def_id ( ) . unwrap ( ) ) ;
42
- false
43
- } else {
44
- true
45
- }
46
- } ) ;
47
- // Drop OpName etc. for inlined functions
48
- module. debug_names . retain ( |inst| {
49
- !inst. operands . iter ( ) . any ( |op| {
50
- op. id_ref_any ( )
51
- . map_or ( false , |id| dropped_ids. contains ( & id) )
52
- } )
53
- } ) ;
54
42
let mut inliner = Inliner {
55
43
header : module. header . as_mut ( ) . unwrap ( ) ,
56
44
types_global_values : & mut module. types_global_values ,
@@ -59,76 +47,121 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
59
47
disallowed_argument_types : & disallowed_argument_types,
60
48
disallowed_return_types : & disallowed_return_types,
61
49
} ;
62
- for function in & mut module. functions {
63
- inliner. inline_fn ( function) ;
64
- fuse_trivial_branches ( function) ;
50
+ for index in postorder {
51
+ inliner. inline_fn ( & mut module. functions , index) ;
52
+ fuse_trivial_branches ( & mut module. functions [ index] ) ;
53
+ }
54
+ let mut dropped_ids = FxHashSet :: default ( ) ;
55
+ for i in ( 0 ..module. functions . len ( ) ) . rev ( ) {
56
+ if to_delete[ i] {
57
+ dropped_ids. insert ( module. functions . remove ( i) . def_id ( ) . unwrap ( ) ) ;
58
+ }
65
59
}
60
+ // Drop OpName etc. for inlined functions
61
+ module. debug_names . retain ( |inst| {
62
+ !inst. operands . iter ( ) . any ( |op| {
63
+ op. id_ref_any ( )
64
+ . map_or ( false , |id| dropped_ids. contains ( & id) )
65
+ } )
66
+ } ) ;
66
67
Ok ( ( ) )
67
68
}
68
69
69
- // https://stackoverflow.com/a/53995651
70
- fn deny_recursion_in_module ( sess : & Session , module : & Module ) -> super :: Result < ( ) > {
70
+ /// Topological sorting algorithm due to T. Cormen
71
+ /// Starts from module's entry points, so only reachable functions will be returned
72
+ /// in post-traversal order of DFS. For all unvisited functions `module.functions[i]`,
73
+ /// `to_delete[i]` is set to true.
74
+ fn compute_function_postorder (
75
+ sess : & Session ,
76
+ module : & Module ,
77
+ to_delete : & mut [ bool ] ,
78
+ ) -> super :: Result < Vec < usize > > {
71
79
let func_to_index: FxHashMap < Word , usize > = module
72
80
. functions
73
81
. iter ( )
74
82
. enumerate ( )
75
83
. map ( |( index, func) | ( func. def_id ( ) . unwrap ( ) , index) )
76
84
. collect ( ) ;
77
- let mut discovered = vec ! [ false ; module. functions. len( ) ] ;
78
- let mut finished = vec ! [ false ; module. functions. len( ) ] ;
85
/// Possible node states for the cycle-discovering DFS over the call graph.
///
/// The enum carries no data, so it is `Copy`; `Debug` and `Eq` are derived so
/// states can be compared with `assert_eq!` and printed while debugging.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum NodeState {
    /// Not visited yet and not an entry point.
    NotVisited,
    /// Currently being visited, i.e. the DFS has entered but not yet left this node.
    Discovered,
    /// The DFS has returned from this node.
    Finished,
    /// Not visited yet, but marked as an entry point to start the DFS from.
    Entry,
}
97
+ let mut states = vec ! [ NodeState :: NotVisited ; module. functions. len( ) ] ;
98
+ for opep in module. entry_points . iter ( ) {
99
+ let func_id = opep. operands [ 1 ] . unwrap_id_ref ( ) ;
100
+ states[ func_to_index[ & func_id] ] = NodeState :: Entry ;
101
+ }
79
102
let mut has_recursion = None ;
103
+ let mut postorder = vec ! [ ] ;
80
104
for index in 0 ..module. functions . len ( ) {
81
- if !discovered [ index ] && !finished [ index] {
105
+ if NodeState :: Entry == states [ index] {
82
106
visit (
83
107
sess,
84
108
module,
85
109
index,
86
- & mut discovered,
87
- & mut finished,
110
+ & mut states[ ..] ,
88
111
& mut has_recursion,
112
+ & mut postorder,
89
113
& func_to_index,
90
114
) ;
91
115
}
92
116
}
93
117
118
+ for index in 0 ..module. functions . len ( ) {
119
+ if NodeState :: NotVisited == states[ index] {
120
+ to_delete[ index] = true ;
121
+ }
122
+ }
123
+
94
124
fn visit (
95
125
sess : & Session ,
96
126
module : & Module ,
97
127
current : usize ,
98
- discovered : & mut Vec < bool > ,
99
- finished : & mut Vec < bool > ,
128
+ states : & mut [ NodeState ] ,
100
129
has_recursion : & mut Option < ErrorGuaranteed > ,
130
+ postorder : & mut Vec < usize > ,
101
131
func_to_index : & FxHashMap < Word , usize > ,
102
132
) {
103
- discovered [ current] = true ;
133
+ states [ current] = NodeState :: Discovered ;
104
134
105
135
for next in calls ( & module. functions [ current] , func_to_index) {
106
- if discovered[ next] {
107
- let names = get_names ( module) ;
108
- let current_name = get_name ( & names, module. functions [ current] . def_id ( ) . unwrap ( ) ) ;
109
- let next_name = get_name ( & names, module. functions [ next] . def_id ( ) . unwrap ( ) ) ;
110
- * has_recursion = Some ( sess. err ( & format ! (
111
- "module has recursion, which is not allowed: `{}` calls `{}`" ,
112
- current_name, next_name
113
- ) ) ) ;
114
- break ;
115
- }
116
-
117
- if !finished[ next] {
118
- visit (
119
- sess,
120
- module,
121
- next,
122
- discovered,
123
- finished,
124
- has_recursion,
125
- func_to_index,
126
- ) ;
136
+ match states[ next] {
137
+ NodeState :: Discovered => {
138
+ let names = get_names ( module) ;
139
+ let current_name =
140
+ get_name ( & names, module. functions [ current] . def_id ( ) . unwrap ( ) ) ;
141
+ let next_name = get_name ( & names, module. functions [ next] . def_id ( ) . unwrap ( ) ) ;
142
+ * has_recursion = Some ( sess. err ( & format ! (
143
+ "module has recursion, which is not allowed: `{}` calls `{}`" ,
144
+ current_name, next_name
145
+ ) ) ) ;
146
+ break ;
147
+ }
148
+ NodeState :: NotVisited | NodeState :: Entry => {
149
+ visit (
150
+ sess,
151
+ module,
152
+ next,
153
+ states,
154
+ has_recursion,
155
+ postorder,
156
+ func_to_index,
157
+ ) ;
158
+ }
159
+ NodeState :: Finished => { }
127
160
}
128
161
}
129
162
130
- discovered [ current] = false ;
131
- finished [ current] = true ;
163
+ states [ current] = NodeState :: Finished ;
164
+ postorder . push ( current)
132
165
}
133
166
134
167
fn calls < ' a > (
@@ -146,7 +179,7 @@ fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()
146
179
147
180
match has_recursion {
148
181
Some ( err) => Err ( err) ,
149
- None => Ok ( ( ) ) ,
182
+ None => Ok ( postorder ) ,
150
183
}
151
184
}
152
185
@@ -284,33 +317,39 @@ impl Inliner<'_, '_> {
284
317
inst_id
285
318
}
286
319
287
- fn inline_fn ( & mut self , function : & mut Function ) {
320
+ fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
321
+ let mut function = take ( & mut functions[ index] ) ;
288
322
let mut block_idx = 0 ;
289
323
while block_idx < function. blocks . len ( ) {
290
- // If we successfully inlined a block, then repeat processing on the same block, in
291
- // case the newly inlined block has more inlined calls.
292
- // TODO: This is quadratic
293
- if !self . inline_block ( function, block_idx) {
294
- block_idx += 1 ;
295
- }
324
+ // If we successfully inlined a block, then continue processing on the next block or its tail.
325
+ // TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments.
326
+ // For the common case of "we knew which functions to inline", it is linear.
327
+ self . inline_block ( & mut function, & functions, block_idx) ;
328
+ block_idx += 1 ;
296
329
}
330
+ functions[ index] = function;
297
331
}
298
332
299
- fn inline_block ( & mut self , caller : & mut Function , block_idx : usize ) -> bool {
333
+ /// Inlines one block and returns whether inlining actually occurred.
334
+ /// After calling this, blocks[block_idx] is finished processing.
335
+ fn inline_block (
336
+ & mut self ,
337
+ caller : & mut Function ,
338
+ functions : & [ Function ] ,
339
+ block_idx : usize ,
340
+ ) -> bool {
300
341
// Find the first inlined OpFunctionCall
301
342
let call = caller. blocks [ block_idx]
302
343
. instructions
303
344
. iter ( )
304
345
. enumerate ( )
305
346
. filter ( |( _, inst) | inst. class . opcode == Op :: FunctionCall )
306
347
. map ( |( index, inst) | {
307
- (
308
- index,
309
- inst,
310
- self . functions
311
- . get ( & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) )
312
- . unwrap ( ) ,
313
- )
348
+ let idx = self
349
+ . functions
350
+ . get ( & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) )
351
+ . unwrap ( ) ;
352
+ ( index, inst, & functions[ * idx] )
314
353
} )
315
354
. find ( |( _, inst, f) | {
316
355
should_inline (
@@ -375,17 +414,23 @@ impl Inliner<'_, '_> {
375
414
) ;
376
415
}
377
416
378
- // Fuse the first block of the callee into the block of the caller. This is okay because
379
- // it's illegal to branch to the first BB in a function.
380
- let mut callee_header = inlined_blocks. remove ( 0 ) . instructions ;
417
+ // Move the variables over from the inlined function to here.
418
+ let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
381
419
// TODO: OpLine handling
382
- let num_variables = callee_header
383
- . iter ( )
384
- . position ( |inst| inst. class . opcode != Op :: Variable )
385
- . unwrap_or ( callee_header. len ( ) ) ;
386
- caller. blocks [ block_idx]
387
- . instructions
388
- . append ( & mut callee_header. split_off ( num_variables) ) ;
420
+ let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
421
+ // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
422
+ // it, and we maintain the invariant that current block has finished processing.
423
+ let jump_to = self . id ( ) ;
424
+ inlined_blocks[ 0 ] = Block {
425
+ label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
426
+ instructions : callee_header. split_off ( num_variables) ,
427
+ } ;
428
+ caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
429
+ Op :: Branch ,
430
+ None ,
431
+ None ,
432
+ vec ! [ Operand :: IdRef ( jump_to) ] ,
433
+ ) ) ;
389
434
// Move the OpVariables of the callee to the caller.
390
435
insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
391
436
@@ -467,45 +512,22 @@ fn get_inlined_blocks(
467
512
fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
468
513
let index = block
469
514
. instructions
470
- . iter ( )
471
- . enumerate ( )
472
- . find_map ( |( index, inst) | {
473
- if inst. class . opcode != Op :: Variable {
474
- Some ( index)
475
- } else {
476
- None
477
- }
478
- } ) ;
515
+ . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
516
+
479
517
let inst = Instruction :: new (
480
518
Op :: Variable ,
481
519
Some ( ptr_ty) ,
482
520
Some ( result_id) ,
483
521
vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
484
522
) ;
485
- match index {
486
- Some ( index) => block. instructions . insert ( index, inst) ,
487
- None => block. instructions . push ( inst) ,
488
- }
523
+ block. instructions . insert ( index, inst)
489
524
}
490
525
491
- fn insert_opvariables ( block : & mut Block , mut insts : Vec < Instruction > ) {
526
+ fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
492
527
let index = block
493
528
. instructions
494
- . iter ( )
495
- . enumerate ( )
496
- . find_map ( |( index, inst) | {
497
- if inst. class . opcode != Op :: Variable {
498
- Some ( index)
499
- } else {
500
- None
501
- }
502
- } ) ;
503
- match index {
504
- Some ( index) => {
505
- block. instructions . splice ( index..index, insts) ;
506
- }
507
- None => block. instructions . append ( & mut insts) ,
508
- }
529
+ . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
530
+ block. instructions . splice ( index..index, insts) ;
509
531
}
510
532
511
533
fn fuse_trivial_branches ( function : & mut Function ) {
0 commit comments