Skip to content

Commit 213127b

Browse files
authored
feat: Use type parameters to allow {get,set}regset to use different register set structs (#2373)
1 parent 395906e commit 213127b

File tree

2 files changed

+106
-38
lines changed

2 files changed

+106
-38
lines changed

src/sys/ptrace/linux.rs

Lines changed: 85 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -172,21 +172,21 @@ libc_enum! {
172172
}
173173
}
174174

175+
#[cfg(all(
176+
target_os = "linux",
177+
target_env = "gnu",
178+
any(
179+
target_arch = "x86_64",
180+
target_arch = "x86",
181+
target_arch = "aarch64",
182+
target_arch = "riscv64",
183+
)
184+
))]
175185
libc_enum! {
176-
#[cfg(all(
177-
target_os = "linux",
178-
target_env = "gnu",
179-
any(
180-
target_arch = "x86_64",
181-
target_arch = "x86",
182-
target_arch = "aarch64",
183-
target_arch = "riscv64",
184-
)
185-
))]
186186
#[repr(i32)]
187-
/// Defining a specific register set, as used in [`getregset`] and [`setregset`].
187+
/// Defines a specific register set, as used in `PTRACE_GETREGSET` and `PTRACE_SETREGSET`.
188188
#[non_exhaustive]
189-
pub enum RegisterSet {
189+
pub enum RegisterSetValue {
190190
NT_PRSTATUS,
191191
NT_PRFPREG,
192192
NT_PRPSINFO,
@@ -195,6 +195,69 @@ libc_enum! {
195195
}
196196
}
197197

198+
#[cfg(all(
199+
target_os = "linux",
200+
target_env = "gnu",
201+
any(
202+
target_arch = "x86_64",
203+
target_arch = "x86",
204+
target_arch = "aarch64",
205+
target_arch = "riscv64",
206+
)
207+
))]
208+
/// Represents register set areas, such as general-purpose registers or
209+
/// floating-point registers.
210+
///
211+
/// # Safety
212+
///
213+
/// This trait is marked unsafe, since implementation of the trait must match
214+
/// ptrace's request `VALUE` and return data type `Regs`.
215+
pub unsafe trait RegisterSet {
216+
/// Corresponding type of registers in the kernel.
217+
const VALUE: RegisterSetValue;
218+
219+
/// Struct representing the register space.
220+
type Regs;
221+
}
222+
223+
#[cfg(all(
224+
target_os = "linux",
225+
target_env = "gnu",
226+
any(
227+
target_arch = "x86_64",
228+
target_arch = "x86",
229+
target_arch = "aarch64",
230+
target_arch = "riscv64",
231+
)
232+
))]
233+
/// Register sets used in [`getregset`] and [`setregset`]
234+
pub mod regset {
235+
use super::*;
236+
237+
#[derive(Debug, Clone, Copy)]
238+
/// General-purpose registers.
239+
pub struct NT_PRSTATUS;
240+
241+
unsafe impl RegisterSet for NT_PRSTATUS {
242+
const VALUE: RegisterSetValue = RegisterSetValue::NT_PRSTATUS;
243+
type Regs = user_regs_struct;
244+
}
245+
246+
#[derive(Debug, Clone, Copy)]
247+
/// Floating-point registers.
248+
pub struct NT_PRFPREG;
249+
250+
unsafe impl RegisterSet for NT_PRFPREG {
251+
const VALUE: RegisterSetValue = RegisterSetValue::NT_PRFPREG;
252+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
253+
type Regs = libc::user_fpregs_struct;
254+
#[cfg(target_arch = "aarch64")]
255+
type Regs = libc::user_fpsimd_struct;
256+
#[cfg(target_arch = "riscv64")]
257+
type Regs = libc::__riscv_mc_d_ext_state;
258+
}
259+
}
260+
198261
libc_bitflags! {
199262
/// Ptrace options used in conjunction with the PTRACE_SETOPTIONS request.
200263
/// See `man ptrace` for more details.
@@ -275,7 +338,7 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
275338
any(target_arch = "aarch64", target_arch = "riscv64")
276339
))]
277340
pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
278-
getregset(pid, RegisterSet::NT_PRSTATUS)
341+
getregset::<regset::NT_PRSTATUS>(pid)
279342
}
280343

281344
/// Get a particular set of user registers, as with `ptrace(PTRACE_GETREGSET, ...)`
@@ -289,18 +352,18 @@ pub fn getregs(pid: Pid) -> Result<user_regs_struct> {
289352
target_arch = "riscv64",
290353
)
291354
))]
292-
pub fn getregset(pid: Pid, set: RegisterSet) -> Result<user_regs_struct> {
355+
pub fn getregset<S: RegisterSet>(pid: Pid) -> Result<S::Regs> {
293356
let request = Request::PTRACE_GETREGSET;
294-
let mut data = mem::MaybeUninit::<user_regs_struct>::uninit();
357+
let mut data = mem::MaybeUninit::<S::Regs>::uninit();
295358
let mut iov = libc::iovec {
296359
iov_base: data.as_mut_ptr().cast(),
297-
iov_len: mem::size_of::<user_regs_struct>(),
360+
iov_len: mem::size_of::<S::Regs>(),
298361
};
299362
unsafe {
300363
ptrace_other(
301364
request,
302365
pid,
303-
set as i32 as AddressType,
366+
S::VALUE as i32 as AddressType,
304367
(&mut iov as *mut libc::iovec).cast(),
305368
)?;
306369
};
@@ -349,7 +412,7 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
349412
any(target_arch = "aarch64", target_arch = "riscv64")
350413
))]
351414
pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
352-
setregset(pid, RegisterSet::NT_PRSTATUS, regs)
415+
setregset::<regset::NT_PRSTATUS>(pid, regs)
353416
}
354417

355418
/// Set a particular set of user registers, as with `ptrace(PTRACE_SETREGSET, ...)`
@@ -363,20 +426,16 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> {
363426
target_arch = "riscv64",
364427
)
365428
))]
366-
pub fn setregset(
367-
pid: Pid,
368-
set: RegisterSet,
369-
mut regs: user_regs_struct,
370-
) -> Result<()> {
429+
pub fn setregset<S: RegisterSet>(pid: Pid, mut regs: S::Regs) -> Result<()> {
371430
let mut iov = libc::iovec {
372-
iov_base: (&mut regs as *mut user_regs_struct).cast(),
373-
iov_len: mem::size_of::<user_regs_struct>(),
431+
iov_base: (&mut regs as *mut S::Regs).cast(),
432+
iov_len: mem::size_of::<S::Regs>(),
374433
};
375434
unsafe {
376435
ptrace_other(
377436
Request::PTRACE_SETREGSET,
378437
pid,
379-
set as i32 as AddressType,
438+
S::VALUE as i32 as AddressType,
380439
(&mut iov as *mut libc::iovec).cast(),
381440
)?;
382441
}

test/sys/test_ptrace.rs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ fn test_ptrace_syscall() {
302302
))]
303303
#[test]
304304
fn test_ptrace_regsets() {
305-
use nix::sys::ptrace::{self, getregset, setregset, RegisterSet};
305+
use nix::sys::ptrace::{self, getregset, regset, setregset};
306306
use nix::sys::signal::*;
307307
use nix::sys::wait::{waitpid, WaitStatus};
308308
use nix::unistd::fork;
@@ -328,30 +328,39 @@ fn test_ptrace_regsets() {
328328
Ok(WaitStatus::Stopped(child, Signal::SIGTRAP))
329329
);
330330
let mut regstruct =
331-
getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
331+
getregset::<regset::NT_PRSTATUS>(child).unwrap();
332+
let mut fpregstruct =
333+
getregset::<regset::NT_PRFPREG>(child).unwrap();
332334

333335
#[cfg(target_arch = "x86_64")]
334-
let reg = &mut regstruct.r15;
336+
let (reg, fpreg) =
337+
(&mut regstruct.r15, &mut fpregstruct.st_space[5]);
335338
#[cfg(target_arch = "x86")]
336-
let reg = &mut regstruct.edx;
339+
let (reg, fpreg) =
340+
(&mut regstruct.edx, &mut fpregstruct.st_space[5]);
337341
#[cfg(target_arch = "aarch64")]
338-
let reg = &mut regstruct.regs[16];
342+
let (reg, fpreg) =
343+
(&mut regstruct.regs[16], &mut fpregstruct.vregs[5]);
339344
#[cfg(target_arch = "riscv64")]
340-
let reg = &mut regstruct.regs[16];
345+
let (reg, fpreg) = (&mut regstruct.t1, &mut fpregstruct.__f[5]);
341346

342347
*reg = 0xdeadbeefu32 as _;
343-
let _ = setregset(child, RegisterSet::NT_PRSTATUS, regstruct);
344-
regstruct = getregset(child, RegisterSet::NT_PRSTATUS).unwrap();
348+
*fpreg = 0xfeedfaceu32 as _;
349+
let _ = setregset::<regset::NT_PRSTATUS>(child, regstruct);
350+
regstruct = getregset::<regset::NT_PRSTATUS>(child).unwrap();
351+
let _ = setregset::<regset::NT_PRFPREG>(child, fpregstruct);
352+
fpregstruct = getregset::<regset::NT_PRFPREG>(child).unwrap();
345353

346354
#[cfg(target_arch = "x86_64")]
347-
let reg = regstruct.r15;
355+
let (reg, fpreg) = (regstruct.r15, fpregstruct.st_space[5]);
348356
#[cfg(target_arch = "x86")]
349-
let reg = regstruct.edx;
357+
let (reg, fpreg) = (regstruct.edx, fpregstruct.st_space[5]);
350358
#[cfg(target_arch = "aarch64")]
351-
let reg = regstruct.regs[16];
359+
let (reg, fpreg) = (regstruct.regs[16], fpregstruct.vregs[5]);
352360
#[cfg(target_arch = "riscv64")]
353-
let reg = regstruct.regs[16];
361+
let (reg, fpreg) = (regstruct.t1, fpregstruct.__f[5]);
354362
assert_eq!(reg, 0xdeadbeefu32 as _);
363+
assert_eq!(fpreg, 0xfeedfaceu32 as _);
355364

356365
ptrace::cont(child, Some(Signal::SIGKILL)).unwrap();
357366
match waitpid(child, None) {

0 commit comments

Comments
 (0)