@@ -13,20 +13,24 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet};
13
13
use rustc_session:: Session ;
14
14
use std:: mem:: take;
15
15
16
- type FunctionMap = FxHashMap < Word , Function > ;
16
+ type FunctionMap = FxHashMap < Word , usize > ;
17
17
18
18
pub fn inline ( sess : & Session , module : & mut Module ) -> super :: Result < ( ) > {
19
+ let ( disallowed_argument_types, disallowed_return_types) =
20
+ compute_disallowed_argument_and_return_types ( module) ;
21
+ let mut to_delete: Vec < _ > = module
22
+ . functions
23
+ . iter ( )
24
+ . map ( |f| should_inline ( & disallowed_argument_types, & disallowed_return_types, f) )
25
+ . collect ( ) ;
19
26
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
20
- if module_has_recursion ( sess, module) {
21
- return Err ( rustc_errors:: ErrorReported ) ;
22
- }
27
+ let postorder = compute_function_postorder ( sess, module, & mut to_delete) ?;
23
28
let functions = module
24
29
. functions
25
30
. iter ( )
26
- . map ( |f| ( f. def_id ( ) . unwrap ( ) , f. clone ( ) ) )
31
+ . enumerate ( )
32
+ . map ( |( idx, f) | ( f. def_id ( ) . unwrap ( ) , idx) )
27
33
. collect ( ) ;
28
- let ( disallowed_argument_types, disallowed_return_types) =
29
- compute_disallowed_argument_and_return_types ( module) ;
30
34
let void = module
31
35
. types_global_values
32
36
. iter ( )
@@ -35,23 +39,6 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
35
39
. unwrap_or ( 0 ) ;
36
40
// Drop all the functions we'll be inlining. (This also means we won't waste time processing
37
41
// inlines in functions that will get inlined)
38
- let mut dropped_ids = FxHashSet :: default ( ) ;
39
- module. functions . retain ( |f| {
40
- if should_inline ( & disallowed_argument_types, & disallowed_return_types, f) {
41
- // TODO: We should insert all defined IDs in this function.
42
- dropped_ids. insert ( f. def_id ( ) . unwrap ( ) ) ;
43
- false
44
- } else {
45
- true
46
- }
47
- } ) ;
48
- // Drop OpName etc. for inlined functions
49
- module. debug_names . retain ( |inst| {
50
- !inst. operands . iter ( ) . any ( |op| {
51
- op. id_ref_any ( )
52
- . map_or ( false , |id| dropped_ids. contains ( & id) )
53
- } )
54
- } ) ;
55
42
let mut inliner = Inliner {
56
43
header : module. header . as_mut ( ) . unwrap ( ) ,
57
44
types_global_values : & mut module. types_global_values ,
@@ -60,77 +47,122 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
60
47
disallowed_argument_types : & disallowed_argument_types,
61
48
disallowed_return_types : & disallowed_return_types,
62
49
} ;
63
- for function in & mut module . functions {
64
- inliner. inline_fn ( function ) ;
65
- fuse_trivial_branches ( function ) ;
50
+ for index in postorder {
51
+ inliner. inline_fn ( & mut module . functions , index ) ;
52
+ fuse_trivial_branches ( & mut module . functions [ index ] ) ;
66
53
}
54
+ let mut dropped_ids = FxHashSet :: default ( ) ;
55
+ for i in ( 0 ..module. functions . len ( ) ) . rev ( ) {
56
+ if to_delete[ i] {
57
+ dropped_ids. insert ( module. functions . remove ( i) . def_id ( ) . unwrap ( ) ) ;
58
+ }
59
+ }
60
+ // Drop OpName etc. for inlined functions
61
+ module. debug_names . retain ( |inst| {
62
+ !inst. operands . iter ( ) . any ( |op| {
63
+ op. id_ref_any ( )
64
+ . map_or ( false , |id| dropped_ids. contains ( & id) )
65
+ } )
66
+ } ) ;
67
67
Ok ( ( ) )
68
68
}
69
69
70
- // https://stackoverflow.com/a/53995651
71
- fn module_has_recursion ( sess : & Session , module : & Module ) -> bool {
70
+ /// Topological sorting algorithm due to T. Cormen
71
+ /// Starts from module's entry points, so only reachable functions will be returned
72
+ /// in post-traversal order of DFS. For all unvisited functions `module.functions[i]`,
73
+ /// `to_delete[i]` is set to true.
74
+ fn compute_function_postorder (
75
+ sess : & Session ,
76
+ module : & Module ,
77
+ to_delete : & mut [ bool ] ,
78
+ ) -> super :: Result < Vec < usize > > {
72
79
let func_to_index: FxHashMap < Word , usize > = module
73
80
. functions
74
81
. iter ( )
75
82
. enumerate ( )
76
83
. map ( |( index, func) | ( func. def_id ( ) . unwrap ( ) , index) )
77
84
. collect ( ) ;
78
- let mut discovered = vec ! [ false ; module. functions. len( ) ] ;
79
- let mut finished = vec ! [ false ; module. functions. len( ) ] ;
85
/// Per-function state used by the cycle-detecting depth-first search.
///
/// The enum is fieldless, so it is cheap to copy and compare; `Debug` is derived so
/// that states can be printed in assertions and diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum NodeState {
    /// Not reached by the search (yet) and not an entry point.
    NotVisited,
    /// On the current DFS path: the search has entered this node but not yet left it.
    Discovered,
    /// The search has returned from this node.
    Finished,
    /// Not visited yet, but known to be an entry point, i.e. a root of the search.
    Entry,
}
97
+ let mut states = vec ! [ NodeState :: NotVisited ; module. functions. len( ) ] ;
98
+ for opep in module. entry_points . iter ( ) {
99
+ let func_id = opep. operands [ 1 ] . unwrap_id_ref ( ) ;
100
+ states[ func_to_index[ & func_id] ] = NodeState :: Entry ;
101
+ }
80
102
let mut has_recursion = false ;
103
+ let mut postorder = vec ! [ ] ;
81
104
for index in 0 ..module. functions . len ( ) {
82
- if !discovered [ index ] && !finished [ index] {
105
+ if NodeState :: Entry == states [ index] {
83
106
visit (
84
107
sess,
85
108
module,
86
109
index,
87
- & mut discovered,
88
- & mut finished,
110
+ & mut states[ ..] ,
89
111
& mut has_recursion,
112
+ & mut postorder,
90
113
& func_to_index,
91
114
) ;
92
115
}
93
116
}
94
117
118
+ for index in 0 ..module. functions . len ( ) {
119
+ if NodeState :: NotVisited == states[ index] {
120
+ to_delete[ index] = true ;
121
+ }
122
+ }
123
+
95
124
fn visit (
96
125
sess : & Session ,
97
126
module : & Module ,
98
127
current : usize ,
99
- discovered : & mut Vec < bool > ,
100
- finished : & mut Vec < bool > ,
128
+ states : & mut [ NodeState ] ,
101
129
has_recursion : & mut bool ,
130
+ postorder : & mut Vec < usize > ,
102
131
func_to_index : & FxHashMap < Word , usize > ,
103
132
) {
104
- discovered [ current] = true ;
133
+ states [ current] = NodeState :: Discovered ;
105
134
106
135
for next in calls ( & module. functions [ current] , func_to_index) {
107
- if discovered[ next] {
108
- let names = get_names ( module) ;
109
- let current_name = get_name ( & names, module. functions [ current] . def_id ( ) . unwrap ( ) ) ;
110
- let next_name = get_name ( & names, module. functions [ next] . def_id ( ) . unwrap ( ) ) ;
111
- sess. err ( & format ! (
112
- "module has recursion, which is not allowed: `{}` calls `{}`" ,
113
- current_name, next_name
114
- ) ) ;
115
- * has_recursion = true ;
116
- break ;
117
- }
118
-
119
- if !finished[ next] {
120
- visit (
121
- sess,
122
- module,
123
- next,
124
- discovered,
125
- finished,
126
- has_recursion,
127
- func_to_index,
128
- ) ;
136
+ match states[ next] {
137
+ NodeState :: Discovered => {
138
+ let names = get_names ( module) ;
139
+ let current_name =
140
+ get_name ( & names, module. functions [ current] . def_id ( ) . unwrap ( ) ) ;
141
+ let next_name = get_name ( & names, module. functions [ next] . def_id ( ) . unwrap ( ) ) ;
142
+ sess. err ( & format ! (
143
+ "module has recursion, which is not allowed: `{}` calls `{}`" ,
144
+ current_name, next_name
145
+ ) ) ;
146
+ * has_recursion = true ;
147
+ break ;
148
+ }
149
+ NodeState :: NotVisited | NodeState :: Entry => {
150
+ visit (
151
+ sess,
152
+ module,
153
+ next,
154
+ states,
155
+ has_recursion,
156
+ postorder,
157
+ func_to_index,
158
+ ) ;
159
+ }
160
+ NodeState :: Finished => { }
129
161
}
130
162
}
131
163
132
- discovered [ current] = false ;
133
- finished [ current] = true ;
164
+ states [ current] = NodeState :: Finished ;
165
+ postorder . push ( current)
134
166
}
135
167
136
168
fn calls < ' a > (
@@ -146,7 +178,11 @@ fn module_has_recursion(sess: &Session, module: &Module) -> bool {
146
178
} )
147
179
}
148
180
149
- has_recursion
181
+ if has_recursion {
182
+ Err ( rustc_errors:: ErrorReported )
183
+ } else {
184
+ Ok ( postorder)
185
+ }
150
186
}
151
187
152
188
fn compute_disallowed_argument_and_return_types (
@@ -283,33 +319,39 @@ impl Inliner<'_, '_> {
283
319
inst_id
284
320
}
285
321
286
- fn inline_fn ( & mut self , function : & mut Function ) {
322
+ fn inline_fn ( & mut self , functions : & mut [ Function ] , index : usize ) {
323
+ let mut function = take ( & mut functions[ index] ) ;
287
324
let mut block_idx = 0 ;
288
325
while block_idx < function. blocks . len ( ) {
289
- // If we successfully inlined a block, then repeat processing on the same block, in
290
- // case the newly inlined block has more inlined calls.
291
- // TODO: This is quadratic
292
- if !self . inline_block ( function, block_idx) {
293
- block_idx += 1 ;
294
- }
326
+ // If we successfully inlined a block, then continue processing on the next block or its tail.
327
+ // TODO: this is quadratic in cases where [`Op::AccessChain`]s cascade into inner arguments.
328
+ // For the common case of "we knew which functions to inline", it is linear.
329
+ self . inline_block ( & mut function, & functions, block_idx) ;
330
+ block_idx += 1 ;
295
331
}
332
+ functions[ index] = function;
296
333
}
297
334
298
- fn inline_block ( & mut self , caller : & mut Function , block_idx : usize ) -> bool {
335
+ /// Inlines one block and returns whether inlining actually occurred.
336
+ /// After calling this, blocks[block_idx] is finished processing.
337
+ fn inline_block (
338
+ & mut self ,
339
+ caller : & mut Function ,
340
+ functions : & [ Function ] ,
341
+ block_idx : usize ,
342
+ ) -> bool {
299
343
// Find the first inlined OpFunctionCall
300
344
let call = caller. blocks [ block_idx]
301
345
. instructions
302
346
. iter ( )
303
347
. enumerate ( )
304
348
. filter ( |( _, inst) | inst. class . opcode == Op :: FunctionCall )
305
349
. map ( |( index, inst) | {
306
- (
307
- index,
308
- inst,
309
- self . functions
310
- . get ( & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) )
311
- . unwrap ( ) ,
312
- )
350
+ let idx = self
351
+ . functions
352
+ . get ( & inst. operands [ 0 ] . id_ref_any ( ) . unwrap ( ) )
353
+ . unwrap ( ) ;
354
+ ( index, inst, & functions[ * idx] )
313
355
} )
314
356
. find ( |( _, inst, f) | {
315
357
should_inline (
@@ -374,17 +416,23 @@ impl Inliner<'_, '_> {
374
416
) ;
375
417
}
376
418
377
- // Fuse the first block of the callee into the block of the caller. This is okay because
378
- // it's illegal to branch to the first BB in a function.
379
- let mut callee_header = inlined_blocks. remove ( 0 ) . instructions ;
419
+ // Move the variables over from the inlined function to here.
420
+ let mut callee_header = take ( & mut inlined_blocks[ 0 ] ) . instructions ;
380
421
// TODO: OpLine handling
381
- let num_variables = callee_header
382
- . iter ( )
383
- . position ( |inst| inst. class . opcode != Op :: Variable )
384
- . unwrap_or ( callee_header. len ( ) ) ;
385
- caller. blocks [ block_idx]
386
- . instructions
387
- . append ( & mut callee_header. split_off ( num_variables) ) ;
422
+ let num_variables = callee_header. partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
423
+ // Rather than fuse blocks, generate a new jump here. Branch fusing will take care of
424
+ // it, and we maintain the invariant that current block has finished processing.
425
+ let jump_to = self . id ( ) ;
426
+ inlined_blocks[ 0 ] = Block {
427
+ label : Some ( Instruction :: new ( Op :: Label , None , Some ( jump_to) , vec ! [ ] ) ) ,
428
+ instructions : callee_header. split_off ( num_variables) ,
429
+ } ;
430
+ caller. blocks [ block_idx] . instructions . push ( Instruction :: new (
431
+ Op :: Branch ,
432
+ None ,
433
+ None ,
434
+ vec ! [ Operand :: IdRef ( jump_to) ] ,
435
+ ) ) ;
388
436
// Move the OpVariables of the callee to the caller.
389
437
insert_opvariables ( & mut caller. blocks [ 0 ] , callee_header) ;
390
438
@@ -466,45 +514,22 @@ fn get_inlined_blocks(
466
514
fn insert_opvariable ( block : & mut Block , ptr_ty : Word , result_id : Word ) {
467
515
let index = block
468
516
. instructions
469
- . iter ( )
470
- . enumerate ( )
471
- . find_map ( |( index, inst) | {
472
- if inst. class . opcode != Op :: Variable {
473
- Some ( index)
474
- } else {
475
- None
476
- }
477
- } ) ;
517
+ . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
518
+
478
519
let inst = Instruction :: new (
479
520
Op :: Variable ,
480
521
Some ( ptr_ty) ,
481
522
Some ( result_id) ,
482
523
vec ! [ Operand :: StorageClass ( StorageClass :: Function ) ] ,
483
524
) ;
484
- match index {
485
- Some ( index) => block. instructions . insert ( index, inst) ,
486
- None => block. instructions . push ( inst) ,
487
- }
525
+ block. instructions . insert ( index, inst)
488
526
}
489
527
490
- fn insert_opvariables ( block : & mut Block , mut insts : Vec < Instruction > ) {
528
+ fn insert_opvariables ( block : & mut Block , insts : Vec < Instruction > ) {
491
529
let index = block
492
530
. instructions
493
- . iter ( )
494
- . enumerate ( )
495
- . find_map ( |( index, inst) | {
496
- if inst. class . opcode != Op :: Variable {
497
- Some ( index)
498
- } else {
499
- None
500
- }
501
- } ) ;
502
- match index {
503
- Some ( index) => {
504
- block. instructions . splice ( index..index, insts) ;
505
- }
506
- None => block. instructions . append ( & mut insts) ,
507
- }
531
+ . partition_point ( |inst| inst. class . opcode == Op :: Variable ) ;
532
+ block. instructions . splice ( index..index, insts) ;
508
533
}
509
534
510
535
fn fuse_trivial_branches ( function : & mut Function ) {
0 commit comments