1
1
use rspirv:: binary:: Disassemble ;
2
2
use rspirv:: dr:: { Instruction , Module , Operand } ;
3
- use rspirv:: spirv:: Op ;
3
+ use rspirv:: spirv:: { Op , StorageClass } ;
4
4
use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
5
5
use rustc_session:: Session ;
6
6
7
+ // bool is if this needs stored
8
+ #[ derive( Debug , Clone , PartialEq ) ]
9
+ struct NormalizedInstructions {
10
+ vars : Vec < Instruction > ,
11
+ insts : Vec < Instruction > ,
12
+ root : u32 ,
13
+ }
14
+
15
+ impl NormalizedInstructions {
16
+ fn new ( id : u32 ) -> Self {
17
+ NormalizedInstructions {
18
+ vars : Vec :: new ( ) ,
19
+ insts : Vec :: new ( ) ,
20
+ root : id,
21
+ }
22
+ }
23
+
24
+ fn extend ( & mut self , o : NormalizedInstructions ) {
25
+ self . vars . extend ( o. vars ) ;
26
+ self . insts . extend ( o. insts ) ;
27
+ }
28
+
29
+ fn is_empty ( & self ) -> bool {
30
+ self . insts . is_empty ( ) && self . vars . is_empty ( )
31
+ }
32
+
33
+ fn fix_ids ( & mut self , bound : & mut u32 , new_root : u32 ) {
34
+ let mut id_map: FxHashMap < u32 , u32 > = FxHashMap :: default ( ) ;
35
+ id_map. insert ( self . root , new_root) ;
36
+ for inst in & mut self . vars {
37
+ Self :: fix_instruction ( self . root , inst, & mut id_map, bound, new_root) ;
38
+ }
39
+ for inst in & mut self . insts {
40
+ Self :: fix_instruction ( self . root , inst, & mut id_map, bound, new_root) ;
41
+ }
42
+ }
43
+
44
+ fn fix_instruction (
45
+ root : u32 ,
46
+ inst : & mut Instruction ,
47
+ id_map : & mut FxHashMap < u32 , u32 > ,
48
+ bound : & mut u32 ,
49
+ new_root : u32 ,
50
+ ) {
51
+ for op in & mut inst. operands {
52
+ match op {
53
+ Operand :: IdRef ( id) => match id_map. get ( id) {
54
+ Some ( new_id) => {
55
+ * id = * new_id;
56
+ }
57
+ _ => { }
58
+ } ,
59
+ _ => { }
60
+ }
61
+ }
62
+ if let Some ( id) = & mut inst. result_id {
63
+ if * id != root {
64
+ id_map. insert ( * id, * bound) ;
65
+ * id = * bound;
66
+ * bound += 1 ;
67
+ } else {
68
+ * id = new_root;
69
+ }
70
+ }
71
+ }
72
+ }
73
+
7
74
#[ derive( Debug , Clone , PartialEq ) ]
8
75
enum FunctionArg {
9
76
Invalid ,
10
- Insts ( Vec < Instruction > ) ,
77
+ Insts ( NormalizedInstructions ) ,
11
78
}
12
79
13
80
pub fn inline_global_varaibles ( sess : & Session , module : & mut Module ) -> super :: Result < ( ) > {
@@ -36,14 +103,30 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
36
103
}
37
104
// then we keep track of which function parameter are always called with the same expression that only uses global variables
38
105
let mut function_args: FxHashMap < ( u32 , u32 ) , FunctionArg > = FxHashMap :: default ( ) ;
106
+ let mut bound = module. header . as_ref ( ) . unwrap ( ) . bound ;
39
107
for caller in & module. functions {
40
108
let mut insts: FxHashMap < u32 , Instruction > = FxHashMap :: default ( ) ;
109
+ // for variables that only stored once and it's stored as a ref
110
+ let mut ref_stores: FxHashMap < u32 , Option < u32 > > = FxHashMap :: default ( ) ;
41
111
for block in & caller. blocks {
42
112
for inst in & block. instructions {
43
113
if inst. result_id . is_some ( ) {
44
114
insts. insert ( inst. result_id . unwrap ( ) , inst. clone ( ) ) ;
45
115
}
46
- if inst. class . opcode == Op :: FunctionCall {
116
+ if inst. class . opcode == Op :: Store {
117
+ if let Operand :: IdRef ( to) = inst. operands [ 0 ] {
118
+ if let Operand :: IdRef ( from) = inst. operands [ 1 ] {
119
+ match ref_stores. get ( & to) {
120
+ None => {
121
+ ref_stores. insert ( to, Some ( from) ) ;
122
+ }
123
+ Some ( _) => {
124
+ ref_stores. insert ( to, None ) ;
125
+ }
126
+ }
127
+ }
128
+ }
129
+ } else if inst. class . opcode == Op :: FunctionCall {
47
130
let function_id = match & inst. operands [ 0 ] {
48
131
& Operand :: IdRef ( w) => w,
49
132
_ => panic ! ( ) ,
@@ -52,16 +135,19 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
52
135
let key = ( function_id, i as u32 - 1 ) ;
53
136
match & inst. operands [ i] {
54
137
& Operand :: IdRef ( w) => match & function_args. get ( & key) {
55
- None => match get_const_arg_insts ( & variables, & insts, w) {
56
- Some ( insts) => {
57
- function_args. insert ( key, FunctionArg :: Insts ( insts) ) ;
58
- }
59
- None => {
60
- function_args. insert ( key, FunctionArg :: Invalid ) ;
138
+ None => {
139
+ match get_const_arg_insts ( bound, & variables, & insts, & ref_stores, w) {
140
+ Some ( insts) => {
141
+ function_args. insert ( key, FunctionArg :: Insts ( insts) ) ;
142
+ }
143
+ None => {
144
+ function_args. insert ( key, FunctionArg :: Invalid ) ;
145
+ }
61
146
}
62
- } ,
147
+ }
63
148
Some ( FunctionArg :: Insts ( w2) ) => {
64
- let new_insts = get_const_arg_insts ( & variables, & insts, w) ;
149
+ let new_insts =
150
+ get_const_arg_insts ( bound, & variables, & insts, & ref_stores, w) ;
65
151
match new_insts {
66
152
Some ( new_insts) => {
67
153
if new_insts != * w2 {
@@ -93,11 +179,10 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
93
179
if function_args. is_empty ( ) {
94
180
return Ok ( false ) ;
95
181
}
96
- let mut bound = module. header . as_ref ( ) . unwrap ( ) . bound ;
97
182
for function in & mut module. functions {
98
183
let def = function. def . as_mut ( ) . unwrap ( ) ;
99
184
let fid = def. result_id . unwrap ( ) ;
100
- let mut insts: Vec < Instruction > = Vec :: new ( ) ;
185
+ let mut insts = NormalizedInstructions :: new ( 0 ) ;
101
186
let mut j: u32 = 0 ;
102
187
let mut i = 0 ;
103
188
let mut removed_indexes: Vec < u32 > = Vec :: new ( ) ;
@@ -108,10 +193,7 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
108
193
Some ( FunctionArg :: Insts ( arg) ) => {
109
194
let parameter = function. parameters . remove ( i) ;
110
195
let mut arg = arg. clone ( ) ;
111
- arg. reverse ( ) ;
112
- insts_replacing_captured_ids ( & mut arg, & mut bound) ;
113
- let index = arg. len ( ) - 1 ;
114
- arg[ index] . result_id = parameter. result_id ;
196
+ arg. fix_ids ( & mut bound, parameter. result_id . unwrap ( ) ) ;
115
197
insts. extend ( arg) ;
116
198
removed_indexes. push ( j) ;
117
199
removed = true ;
@@ -132,15 +214,16 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
132
214
for i in removed_indexes. iter ( ) . rev ( ) {
133
215
let i = * i as usize + 1 ;
134
216
function_type. operands . remove ( i) ;
135
- function_type. result_id = Some ( tid) ;
136
217
}
218
+ function_type. result_id = Some ( tid) ;
137
219
def. operands [ 1 ] = Operand :: IdRef ( tid) ;
138
220
module. types_global_values . push ( function_type) ;
139
221
}
140
222
}
141
223
// callee side. insert initialization instructions, which reuse the ids of the removed parameters
142
224
if !function. blocks . is_empty ( ) {
143
225
let first_block = & mut function. blocks [ 0 ] ;
226
+ first_block. instructions . splice ( 0 ..0 , insts. vars ) ;
144
227
// skip some instructions that must be at top of block
145
228
let mut i = 0 ;
146
229
loop {
@@ -154,7 +237,7 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
154
237
}
155
238
i += 1 ;
156
239
}
157
- first_block. instructions . splice ( i..i, insts) ;
240
+ first_block. instructions . splice ( i..i, insts. insts ) ;
158
241
}
159
242
// caller side, remove parameters from function call
160
243
for block in & mut function. blocks {
@@ -181,76 +264,90 @@ fn inline_global_varaibles_rec(sess: &Session, module: &mut Module) -> super::Re
181
264
Ok ( true )
182
265
}
183
266
184
- fn insts_replacing_captured_ids ( arg : & mut Vec < Instruction > , bound : & mut u32 ) {
185
- let mut id_map: FxHashMap < u32 , u32 > = FxHashMap :: default ( ) ;
186
- for ins in arg {
187
- if let Some ( id) = & mut ins. result_id {
188
- for op in & mut ins. operands {
189
- match op {
190
- Operand :: IdRef ( id) => match id_map. get ( id) {
191
- Some ( new_id) => {
192
- * id = * new_id;
193
- }
194
- _ => { }
195
- } ,
196
- _ => { }
197
- }
198
- }
199
- id_map. insert ( * id, * bound) ;
200
- * id = * bound;
201
- * bound += 1 ;
202
- }
203
- }
204
- }
205
-
206
267
fn get_const_arg_operands (
207
268
variables : & FxHashSet < u32 > ,
208
269
insts : & FxHashMap < u32 , Instruction > ,
270
+ ref_stores : & FxHashMap < u32 , Option < u32 > > ,
209
271
operand : & Operand ,
210
- ) -> Option < Vec < Instruction > > {
272
+ ) -> Option < NormalizedInstructions > {
211
273
match operand {
212
274
Operand :: IdRef ( id) => {
213
- let insts = get_const_arg_insts ( variables, insts, * id) ?;
275
+ let insts = get_const_arg_insts_rec ( variables, insts, ref_stores , * id) ?;
214
276
return Some ( insts) ;
215
277
}
216
- Operand :: LiteralInt32 ( _) => { } ,
217
- Operand :: LiteralInt64 ( _) => { } ,
218
- Operand :: LiteralFloat32 ( _) => { } ,
219
- Operand :: LiteralFloat64 ( _) => { } ,
220
- Operand :: LiteralExtInstInteger ( _) => { } ,
221
- Operand :: LiteralSpecConstantOpInteger ( _) => { } ,
222
- Operand :: LiteralString ( _) => { } ,
278
+ Operand :: LiteralInt32 ( _) => { }
279
+ Operand :: LiteralInt64 ( _) => { }
280
+ Operand :: LiteralFloat32 ( _) => { }
281
+ Operand :: LiteralFloat64 ( _) => { }
282
+ Operand :: LiteralExtInstInteger ( _) => { }
283
+ Operand :: LiteralSpecConstantOpInteger ( _) => { }
284
+ Operand :: LiteralString ( _) => { }
223
285
_ => {
224
286
// TOOD add more cases
225
287
return None ;
226
288
}
227
289
}
228
- return Some ( Vec :: new ( ) ) ;
290
+ return Some ( NormalizedInstructions :: new ( 0 ) ) ;
229
291
}
230
292
231
293
fn get_const_arg_insts (
294
+ mut bound : u32 ,
295
+ variables : & FxHashSet < u32 > ,
296
+ insts : & FxHashMap < u32 , Instruction > ,
297
+ ref_stores : & FxHashMap < u32 , Option < u32 > > ,
298
+ id : u32 ,
299
+ ) -> Option < NormalizedInstructions > {
300
+ let mut res = get_const_arg_insts_rec ( variables, insts, ref_stores, id) ?;
301
+ res. insts . reverse ( ) ;
302
+ // the bound passed in is always the same
303
+ // we need to normalize the ids, so they are the same when compared
304
+ let fake_root = bound;
305
+ bound += 1 ;
306
+ res. fix_ids ( & mut bound, fake_root) ;
307
+ res. root = fake_root;
308
+ Some ( res)
309
+ }
310
+
311
+ fn get_const_arg_insts_rec (
232
312
variables : & FxHashSet < u32 > ,
233
313
insts : & FxHashMap < u32 , Instruction > ,
314
+ ref_stores : & FxHashMap < u32 , Option < u32 > > ,
234
315
id : u32 ,
235
- ) -> Option < Vec < Instruction > > {
236
- let mut result: Vec < Instruction > = Vec :: new ( ) ;
316
+ ) -> Option < NormalizedInstructions > {
317
+ let mut result = NormalizedInstructions :: new ( id ) ;
237
318
if variables. contains ( & id) {
238
319
return Some ( result) ;
239
320
}
240
321
let par: & Instruction = insts. get ( & id) ?;
241
322
if par. class . opcode == Op :: AccessChain {
242
- result. push ( par. clone ( ) ) ;
323
+ result. insts . push ( par. clone ( ) ) ;
243
324
for oprand in & par. operands {
244
- let insts = get_const_arg_operands ( variables, insts, oprand) ?;
325
+ let insts = get_const_arg_operands ( variables, insts, ref_stores , oprand) ?;
245
326
result. extend ( insts) ;
246
327
}
247
328
} else if par. class . opcode == Op :: FunctionCall {
248
- result. push ( par. clone ( ) ) ;
329
+ result. insts . push ( par. clone ( ) ) ;
249
330
// skip first, first is function id
250
331
for oprand in & par. operands [ 1 ..] {
251
- let insts = get_const_arg_operands ( variables, insts, oprand) ?;
332
+ let insts = get_const_arg_operands ( variables, insts, ref_stores , oprand) ?;
252
333
result. extend ( insts) ;
253
334
}
335
+ } else if par. class . opcode == Op :: Variable {
336
+ result. vars . push ( par. clone ( ) ) ;
337
+ let stored = ref_stores. get ( & id) ?;
338
+ let stored = ( * stored) ?;
339
+ result. insts . push ( Instruction :: new (
340
+ Op :: Store ,
341
+ None ,
342
+ None ,
343
+ vec ! [ Operand :: IdRef ( id) , Operand :: IdRef ( stored) ] ,
344
+ ) ) ;
345
+ let new_insts = get_const_arg_insts_rec ( variables, insts, ref_stores, stored) ?;
346
+ result. extend ( new_insts) ;
347
+ } else if par. class . opcode == Op :: ArrayLength {
348
+ result. insts . push ( par. clone ( ) ) ;
349
+ let insts = get_const_arg_operands ( variables, insts, ref_stores, & par. operands [ 0 ] ) ?;
350
+ result. extend ( insts) ;
254
351
} else {
255
352
// TOOD add more cases
256
353
return None ;
0 commit comments