@@ -133,18 +133,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
133
133
134
134
let mut patch = MirPatch::new(body);
135
135
136
- // create temp to store second discriminant in, `_s` in example above
137
- let second_discriminant_temp =
138
- patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
136
+ let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
137
+ // create temp to store second discriminant in, `_s` in example above
138
+ let second_discriminant_temp =
139
+ patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
139
140
140
- patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
141
+ patch.add_statement(
142
+ parent_end,
143
+ StatementKind::StorageLive(second_discriminant_temp),
144
+ );
141
145
142
- // create assignment of discriminant
143
- patch.add_assign(
144
- parent_end,
145
- Place::from(second_discriminant_temp),
146
- Rvalue::Discriminant(opt_data.child_place),
147
- );
146
+ // create assignment of discriminant
147
+ patch.add_assign(
148
+ parent_end,
149
+ Place::from(second_discriminant_temp),
150
+ Rvalue::Discriminant(opt_data.child_place),
151
+ );
152
+ (
153
+ Some(second_discriminant_temp),
154
+ Operand::Move(Place::from(second_discriminant_temp)),
155
+ )
156
+ } else {
157
+ (None, Operand::Copy(opt_data.child_place))
158
+ };
148
159
149
160
// create temp to store inequality comparison between the two discriminants, `_t` in
150
161
// example above
@@ -153,11 +164,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
153
164
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
154
165
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
155
166
156
- // create inequality comparison between the two discriminants
157
- let comp_rvalue = Rvalue::BinaryOp(
158
- nequal,
159
- Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
160
- );
167
+ // create inequality comparison
168
+ let comp_rvalue =
169
+ Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
161
170
patch.add_statement(
162
171
parent_end,
163
172
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
@@ -193,8 +202,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
193
202
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
194
203
);
195
204
196
- // generate StorageDead for the second_discriminant_temp not in use anymore
197
- patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
205
+ if let Some(second_discriminant_temp) = second_discriminant_temp {
206
+ // generate StorageDead for the second_discriminant_temp not in use anymore
207
+ patch.add_statement(
208
+ parent_end,
209
+ StatementKind::StorageDead(second_discriminant_temp),
210
+ );
211
+ }
198
212
199
213
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
200
214
// the switch
@@ -222,6 +236,7 @@ struct OptimizationData<'tcx> {
222
236
child_place: Place<'tcx>,
223
237
child_ty: Ty<'tcx>,
224
238
child_source: SourceInfo,
239
+ need_hoist_discriminant: bool,
225
240
}
226
241
227
242
fn evaluate_candidate<'tcx>(
@@ -235,70 +250,128 @@ fn evaluate_candidate<'tcx>(
235
250
return None;
236
251
};
237
252
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
238
- if !bbs[targets.otherwise()].is_empty_unreachable() {
239
- // Someone could write code like this:
240
- // ```rust
241
- // let Q = val;
242
- // if discriminant(P) == otherwise {
243
- // let ptr = &mut Q as *mut _ as *mut u8;
244
- // // It may be difficult for us to effectively determine whether values are valid.
245
- // // Invalid values can come from all sorts of corners.
246
- // unsafe { *ptr = 10; }
247
- // }
248
- //
249
- // match P {
250
- // A => match Q {
251
- // A => {
252
- // // code
253
- // }
254
- // _ => {
255
- // // don't use Q
256
- // }
257
- // }
258
- // _ => {
259
- // // don't use Q
260
- // }
261
- // };
262
- // ```
263
- //
264
- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
265
- // of an invalid value, which is UB.
266
- // In order to fix this, **we would either need to show that the discriminant computation of
267
- // `place` is computed in all branches**.
268
- // FIXME(#95162) For the moment, we adopt a conservative approach and
269
- // consider only the `otherwise` branch has no statements and an unreachable terminator.
270
- return None;
271
- }
272
253
let (_, child) = targets.iter().next()?;
273
- let child_terminator = &bbs[child].terminator();
274
- let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
275
- &child_terminator.kind
254
+
255
+ let Terminator {
256
+ kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
257
+ source_info,
258
+ } = bbs[child].terminator()
276
259
else {
277
260
return None;
278
261
};
279
262
let child_ty = child_discr.ty(body.local_decls(), tcx);
280
263
if child_ty != parent_ty {
281
264
return None;
282
265
}
283
- let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else {
266
+
267
+ // We only handle:
268
+ // ```
269
+ // bb4: {
270
+ // _8 = discriminant((_3.1: Enum1));
271
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
272
+ // }
273
+ // ```
274
+ // and
275
+ // ```
276
+ // bb2: {
277
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
278
+ // }
279
+ // ```
280
+ if bbs[child].statements.len() > 1 {
284
281
return None;
282
+ }
283
+
284
+ // When thie BB has exactly one statement, this statement should be discriminant.
285
+ let need_hoist_discriminant = bbs[child].statements.len() == 1;
286
+ let child_place = if need_hoist_discriminant {
287
+ if !bbs[targets.otherwise()].is_empty_unreachable() {
288
+ // Someone could write code like this:
289
+ // ```rust
290
+ // let Q = val;
291
+ // if discriminant(P) == otherwise {
292
+ // let ptr = &mut Q as *mut _ as *mut u8;
293
+ // // It may be difficult for us to effectively determine whether values are valid.
294
+ // // Invalid values can come from all sorts of corners.
295
+ // unsafe { *ptr = 10; }
296
+ // }
297
+ //
298
+ // match P {
299
+ // A => match Q {
300
+ // A => {
301
+ // // code
302
+ // }
303
+ // _ => {
304
+ // // don't use Q
305
+ // }
306
+ // }
307
+ // _ => {
308
+ // // don't use Q
309
+ // }
310
+ // };
311
+ // ```
312
+ //
313
+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
314
+ // invalid value, which is UB.
315
+ // In order to fix this, **we would either need to show that the discriminant computation of
316
+ // `place` is computed in all branches**.
317
+ // FIXME(#95162) For the moment, we adopt a conservative approach and
318
+ // consider only the `otherwise` branch has no statements and an unreachable terminator.
319
+ return None;
320
+ }
321
+ // Handle:
322
+ // ```
323
+ // bb4: {
324
+ // _8 = discriminant((_3.1: Enum1));
325
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
326
+ // }
327
+ // ```
328
+ let [
329
+ Statement {
330
+ kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
331
+ ..
332
+ },
333
+ ] = bbs[child].statements.as_slice()
334
+ else {
335
+ return None;
336
+ };
337
+ *child_place
338
+ } else {
339
+ // Handle:
340
+ // ```
341
+ // bb2: {
342
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
343
+ // }
344
+ // ```
345
+ let Operand::Copy(child_place) = child_discr else {
346
+ return None;
347
+ };
348
+ *child_place
285
349
};
286
- let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
287
- return None;
350
+ let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
351
+ {
352
+ child_targets.otherwise()
353
+ } else {
354
+ targets.otherwise()
288
355
};
289
- let destination = child_targets.otherwise();
290
356
291
357
// Verify that the optimization is legal for each branch
292
358
for (value, child) in targets.iter() {
293
- if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
359
+ if !verify_candidate_branch(
360
+ &bbs[child],
361
+ value,
362
+ child_place,
363
+ destination,
364
+ need_hoist_discriminant,
365
+ ) {
294
366
return None;
295
367
}
296
368
}
297
369
Some(OptimizationData {
298
370
destination,
299
- child_place: *child_place ,
371
+ child_place,
300
372
child_ty,
301
- child_source: child_terminator.source_info,
373
+ child_source: *source_info,
374
+ need_hoist_discriminant,
302
375
})
303
376
}
304
377
@@ -307,31 +380,48 @@ fn verify_candidate_branch<'tcx>(
307
380
value: u128,
308
381
place: Place<'tcx>,
309
382
destination: BasicBlock,
383
+ need_hoist_discriminant: bool,
310
384
) -> bool {
311
- // In order for the optimization to be correct, the branch must...
312
- // ...have exactly one statement
313
- if let [statement] = branch.statements.as_slice()
314
- // ...assign the discriminant of `place` in that statement
315
- && let StatementKind::Assign(boxed) = &statement.kind
316
- && let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed
317
- && *from_place == place
318
- // ...make that assignment to a local
319
- && discr_place.projection.is_empty()
320
- // ...terminate on a `SwitchInt` that invalidates that local
321
- && let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } =
322
- &branch.terminator().kind
323
- && *switch_op == Operand::Move(*discr_place)
324
- // ...fall through to `destination` if the switch misses
325
- && destination == targets.otherwise()
326
- // ...have a branch for value `value`
327
- && let mut iter = targets.iter()
328
- && let Some((target_value, _)) = iter.next()
329
- && target_value == value
330
- // ...and have no more branches
331
- && iter.next().is_none()
332
- {
333
- true
385
+ // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
386
+ let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
387
+ return false;
388
+ };
389
+ if need_hoist_discriminant {
390
+ // If we need hoist discriminant, the branch must have exactly one statement.
391
+ let [statement] = branch.statements.as_slice() else {
392
+ return false;
393
+ };
394
+ // The statement must assign the discriminant of `place`.
395
+ let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
396
+ statement.kind
397
+ else {
398
+ return false;
399
+ };
400
+ if from_place != place {
401
+ return false;
402
+ }
403
+ // The assignment must invalidate a local that terminate on a `SwitchInt`.
404
+ if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
405
+ return false;
406
+ }
334
407
} else {
335
- false
408
+ // If we don't need hoist discriminant, the branch must not have any statements.
409
+ if !branch.statements.is_empty() {
410
+ return false;
411
+ }
412
+ // The place on `SwitchInt` must be the same.
413
+ if *switch_op != Operand::Copy(place) {
414
+ return false;
415
+ }
336
416
}
417
+ // It must fall through to `destination` if the switch misses.
418
+ if destination != targets.otherwise() {
419
+ return false;
420
+ }
421
+ // It must have exactly one branch for value `value` and have no more branches.
422
+ let mut iter = targets.iter();
423
+ let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
424
+ return false;
425
+ };
426
+ target_value == value
337
427
}
0 commit comments