Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit 9b56e20

Browse files
committed
Refactor the check_pointers interface and respect ZSTs in nullcheck
1 parent 8b8ffa1 commit 9b56e20

File tree

4 files changed

+132
-75
lines changed

4 files changed

+132
-75
lines changed

compiler/rustc_mir_transform/src/check_alignment.rs

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_middle::mir::*;
44
use rustc_middle::ty::{Ty, TyCtxt};
55
use rustc_session::Session;
66

7-
use crate::check_pointers::check_pointers;
7+
use crate::check_pointers::{PointerCheck, check_pointers};
88

99
pub(super) struct CheckAlignment;
1010

@@ -24,35 +24,33 @@ impl<'tcx> crate::MirPass<'tcx> for CheckAlignment {
2424
}
2525
}
2626

27+
/// Inserts the actual alignment check's logic. Returns a [PointerCheck] that triggers an
28+
/// [AssertKind::MisalignedPointerDereference] on failure.
2729
fn insert_alignment_check<'tcx>(
2830
tcx: TyCtxt<'tcx>,
29-
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
30-
block_data: &mut BasicBlockData<'tcx>,
3131
pointer: Place<'tcx>,
3232
pointee_ty: Ty<'tcx>,
33+
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
34+
stmts: &mut Vec<Statement<'tcx>>,
3335
source_info: SourceInfo,
34-
new_block: BasicBlock,
35-
) {
36-
// Cast the pointer to a *const ()
36+
) -> PointerCheck<'tcx> {
37+
// Cast the pointer to a *const ().
3738
let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit);
3839
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
3940
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
40-
block_data
41-
.statements
41+
stmts
4242
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });
4343

44-
// Transmute the pointer to a usize (equivalent to `ptr.addr()`)
44+
// Transmute the pointer to a usize (equivalent to `ptr.addr()`).
4545
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
4646
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
47-
block_data
48-
.statements
49-
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
47+
stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
5048

5149
// Get the alignment of the pointee
5250
let alignment =
5351
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
5452
let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty);
55-
block_data.statements.push(Statement {
53+
stmts.push(Statement {
5654
source_info,
5755
kind: StatementKind::Assign(Box::new((alignment, rvalue))),
5856
});
@@ -65,7 +63,7 @@ fn insert_alignment_check<'tcx>(
6563
user_ty: None,
6664
const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), tcx.types.usize),
6765
}));
68-
block_data.statements.push(Statement {
66+
stmts.push(Statement {
6967
source_info,
7068
kind: StatementKind::Assign(Box::new((
7169
alignment_mask,
@@ -76,7 +74,7 @@ fn insert_alignment_check<'tcx>(
7674
// BitAnd the alignment mask with the pointer
7775
let alignment_bits =
7876
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
79-
block_data.statements.push(Statement {
77+
stmts.push(Statement {
8078
source_info,
8179
kind: StatementKind::Assign(Box::new((
8280
alignment_bits,
@@ -94,29 +92,21 @@ fn insert_alignment_check<'tcx>(
9492
user_ty: None,
9593
const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
9694
}));
97-
block_data.statements.push(Statement {
95+
stmts.push(Statement {
9896
source_info,
9997
kind: StatementKind::Assign(Box::new((
10098
is_ok,
10199
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))),
102100
))),
103101
});
104102

105-
// Set this block's terminator to our assert, continuing to new_block if we pass
106-
block_data.terminator = Some(Terminator {
107-
source_info,
108-
kind: TerminatorKind::Assert {
109-
cond: Operand::Copy(is_ok),
110-
expected: true,
111-
target: new_block,
112-
msg: Box::new(AssertKind::MisalignedPointerDereference {
113-
required: Operand::Copy(alignment),
114-
found: Operand::Copy(addr),
115-
}),
116-
// This calls panic_misaligned_pointer_dereference, which is #[rustc_nounwind].
117-
// We never want to insert an unwind into unsafe code, because unwinding could
118-
// make a failing UB check turn into much worse UB when we start unwinding.
119-
unwind: UnwindAction::Unreachable,
120-
},
121-
});
103+
// Emit a check that asserts on the alignment and otherwise triggers an
104+
// AssertKind::MisalignedPointerDereference.
105+
PointerCheck {
106+
cond: Operand::Copy(is_ok),
107+
assert_kind: Box::new(AssertKind::MisalignedPointerDereference {
108+
required: Operand::Copy(alignment),
109+
found: Operand::Copy(addr),
110+
}),
111+
}
122112
}

compiler/rustc_mir_transform/src/check_null.rs

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_middle::mir::*;
44
use rustc_middle::ty::{Ty, TyCtxt};
55
use rustc_session::Session;
66

7-
use crate::check_pointers::check_pointers;
7+
use crate::check_pointers::{PointerCheck, check_pointers};
88

99
pub(super) struct CheckNull;
1010

@@ -20,54 +20,77 @@ impl<'tcx> crate::MirPass<'tcx> for CheckNull {
2020

2121
fn insert_null_check<'tcx>(
2222
tcx: TyCtxt<'tcx>,
23-
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
24-
block_data: &mut BasicBlockData<'tcx>,
2523
pointer: Place<'tcx>,
26-
_: Ty<'tcx>,
24+
pointee_ty: Ty<'tcx>,
25+
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
26+
stmts: &mut Vec<Statement<'tcx>>,
2727
source_info: SourceInfo,
28-
new_block: BasicBlock,
29-
) {
28+
) -> PointerCheck<'tcx> {
29+
// Cast the pointer to a *const ().
3030
let const_raw_ptr = Ty::new_imm_ptr(tcx, tcx.types.unit);
3131
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
3232
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
33-
block_data
34-
.statements
33+
stmts
3534
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });
3635

37-
// Transmute the pointer to a usize (equivalent to `ptr.addr()`)
36+
// Transmute the pointer to a usize (equivalent to `ptr.addr()`).
3837
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
3938
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
40-
block_data
41-
.statements
42-
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
39+
stmts.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });
4340

44-
// Check if the pointer is null.
45-
let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
41+
// Get size of the pointee (zero-sized reads and writes are allowed).
42+
let rvalue = Rvalue::NullaryOp(NullOp::SizeOf, pointee_ty);
43+
let sizeof_pointee =
44+
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
45+
stmts.push(Statement {
46+
source_info,
47+
kind: StatementKind::Assign(Box::new((sizeof_pointee, rvalue))),
48+
});
49+
50+
// Check that the pointee is a ZST.
4651
let zero = Operand::Constant(Box::new(ConstOperand {
4752
span: source_info.span,
4853
user_ty: None,
4954
const_: Const::Val(ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), tcx.types.usize),
5055
}));
51-
block_data.statements.push(Statement {
56+
let is_pointee_zst =
57+
local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
58+
stmts.push(Statement {
59+
source_info,
60+
kind: StatementKind::Assign(Box::new((
61+
is_pointee_zst,
62+
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(sizeof_pointee), zero.clone()))),
63+
))),
64+
});
65+
66+
// Check if the pointer is null.
67+
let is_null = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
68+
stmts.push(Statement {
5269
source_info,
5370
kind: StatementKind::Assign(Box::new((
54-
is_ok,
71+
is_null,
5572
Rvalue::BinaryOp(BinOp::Ne, Box::new((Operand::Copy(addr), zero))),
5673
))),
5774
});
5875

59-
// Set this block's terminator to our assert, continuing to new_block if we pass
60-
block_data.terminator = Some(Terminator {
76+
// Check if the pointer is null or points to a ZST.
77+
let should_throw_exception =
78+
local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
79+
stmts.push(Statement {
6180
source_info,
62-
kind: TerminatorKind::Assert {
63-
cond: Operand::Copy(is_ok),
64-
expected: true,
65-
target: new_block,
66-
msg: Box::new(AssertKind::NullPointerDereference),
67-
// This calls panic_misaligned_pointer_dereference, which is #[rustc_nounwind].
68-
// We never want to insert an unwind into unsafe code, because unwinding could
69-
// make a failing UB check turn into much worse UB when we start unwinding.
70-
unwind: UnwindAction::Unreachable,
71-
},
81+
kind: StatementKind::Assign(Box::new((
82+
is_null,
83+
Rvalue::BinaryOp(
84+
BinOp::BitOr,
85+
Box::new((Operand::Copy(is_null), Operand::Copy(is_pointee_zst))),
86+
),
87+
))),
7288
});
89+
90+
// Emit a PointerCheck that asserts on the condition and otherwise triggers
91+
// a AssertKind::NullPointerDereference.
92+
PointerCheck {
93+
cond: Operand::Copy(should_throw_exception),
94+
assert_kind: Box::new(AssertKind::NullPointerDereference),
95+
}
7396
}

compiler/rustc_mir_transform/src/check_pointers.rs

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,40 @@ use rustc_middle::mir::*;
55
use rustc_middle::ty::{self, Ty, TyCtxt};
66
use tracing::{debug, trace};
77

8+
/// Details of a pointer check, the condition on which we decide whether to
9+
/// fail the assert and an [AssertKind] that defines the behavior on failure.
10+
pub(crate) struct PointerCheck<'tcx> {
11+
pub(crate) cond: Operand<'tcx>,
12+
pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>,
13+
}
14+
15+
/// Utility for adding a check for read/write on every sized, unsafe pointer.
16+
///
17+
/// Visits every read/write access to a [Sized], unsafe pointer and inserts a
18+
/// new basic block directly before the pointer access. Then calls `on_finding`
19+
/// to insert the actual logic for a pointer check (e.g. check for alignment).
20+
/// This utility takes care of the right order of blocks, the only thing a
21+
/// caller must do in `on_finding` is:
22+
/// - Append [Statement]s to `stmts`.
23+
/// - Append [LocalDecl]s to `local_decls`.
24+
/// - Return a [PointerCheck] that contains the condition and an [AssertKind].
25+
/// The AssertKind must be a panic with `#[rustc_nounwind]`.
26+
/// This utility will insert a terminator block that asserts on the condition
27+
/// and panics on failure.
828
pub(crate) fn check_pointers<'a, 'tcx, F>(
929
tcx: TyCtxt<'tcx>,
1030
body: &mut Body<'tcx>,
1131
excluded_pointees: &'a [Ty<'tcx>],
1232
on_finding: F,
1333
) where
1434
F: Fn(
15-
TyCtxt<'tcx>,
16-
&mut IndexVec<Local, LocalDecl<'tcx>>,
17-
&mut BasicBlockData<'tcx>,
18-
Place<'tcx>,
19-
Ty<'tcx>,
20-
SourceInfo,
21-
BasicBlock,
22-
),
35+
/* tcx: */ TyCtxt<'tcx>,
36+
/* pointer: */ Place<'tcx>,
37+
/* pointee_ty: */ Ty<'tcx>,
38+
/* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>,
39+
/* stmts: */ &mut Vec<Statement<'tcx>>,
40+
/* source_info: */ SourceInfo,
41+
) -> PointerCheck<'tcx>,
2342
{
2443
// This pass emits new panics. If for whatever reason we do not have a panic
2544
// implementation, running this pass may cause otherwise-valid code to not compile.
@@ -48,15 +67,33 @@ pub(crate) fn check_pointers<'a, 'tcx, F>(
4867
for (local, ty) in finder.into_found_pointers() {
4968
debug!("Inserting check for {:?}", ty);
5069
let new_block = split_block(basic_blocks, location);
51-
on_finding(
70+
71+
// Invoke `on_finding` which appends to `local_decls` and the
72+
// block's statements. It returns information about the assert
73+
// we're performing in the Terminator.
74+
let block_data = &mut basic_blocks[block];
75+
let pointer_check = on_finding(
5276
tcx,
53-
local_decls,
54-
&mut basic_blocks[block],
5577
local,
5678
ty,
79+
local_decls,
80+
&mut block_data.statements,
5781
source_info,
58-
new_block,
5982
);
83+
block_data.terminator = Some(Terminator {
84+
source_info,
85+
kind: TerminatorKind::Assert {
86+
cond: pointer_check.cond,
87+
expected: true,
88+
target: new_block,
89+
msg: pointer_check.assert_kind,
90+
// This calls a panic function associated with the pointer check, which
91+
// is #[rustc_nounwind]. We never want to insert an unwind into unsafe
92+
// code, because unwinding could make a failing UB check turn into much
93+
// worse UB when we start unwinding.
94+
unwind: UnwindAction::Unreachable,
95+
},
96+
});
6097
}
6198
}
6299
}
@@ -128,7 +165,7 @@ impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
128165
return;
129166
}
130167

131-
// We don't need to look for str and slices, we already rejected unsized types above
168+
// We don't need to look for slices, we already rejected unsized types above.
132169
let element_ty = match pointee_ty.kind() {
133170
ty::Array(ty, _) => *ty,
134171
_ => pointee_ty,
@@ -150,7 +187,7 @@ fn split_block(
150187
) -> BasicBlock {
151188
let block_data = &mut basic_blocks[location.block];
152189

153-
// Drain every statement after this one and move the current terminator to a new basic block
190+
// Drain every statement after this one and move the current terminator to a new basic block.
154191
let new_block = BasicBlockData {
155192
statements: block_data.statements.split_off(location.statement_index),
156193
terminator: block_data.terminator.take(),
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
// Reads a zero-sized value through a null pointer. This file is marked
// run-pass with debug assertions enabled, so the read is expected to succeed.
fn main() {
    let null_unit = std::ptr::null::<()>();
    let _value: () = unsafe { *null_unit };
}

0 commit comments

Comments
 (0)