-
Notifications
You must be signed in to change notification settings - Fork 123
BitGenerator support #499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
BitGenerator support #499
Changes from 25 commits
06d6ce1
07e2416
b611943
d93a264
f52b2fa
05814d6
37d360e
eed5b19
6c1a89b
d1909d3
bde2553
a0b9ec5
ee32246
1be6838
2aa3d90
0258e6d
876001b
71ce8be
2de7072
016eb7a
1f7f37f
1d01c7a
c90176a
f49d3fa
a16846d
573d890
06bb693
663fa29
c6105c9
3a0aa92
a92861a
6dbb6dc
b102d20
e5e440e
e73e3a2
c6493df
2327f36
e5c6458
e8cd5e8
0868405
8667203
1fd7bb5
3913171
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"[rust]": { | ||
"editor.defaultFormatter": "rust-lang.rust-analyzer", | ||
"editor.formatOnSave": true, | ||
}, | ||
"rust-analyzer.cargo.features": "all", | ||
} | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
use std::ffi::c_void; | ||
|
||
#[repr(C)] | ||
#[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy? | ||
pub struct npy_bitgen { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pub state: *mut c_void, | ||
pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil | ||
pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil | ||
pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil | ||
pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,264 @@ | ||
//! Safe interface for NumPy's random [`BitGenerator`][bg]. | ||
//! | ||
//! Using the patterns described in [“Extending `numpy.random`”][ext], | ||
//! you can generate random numbers without holding the GIL, | ||
//! by [acquiring][`PyBitGeneratorMethods::lock`] a lock [guard][`PyBitGeneratorGuard`] for the [`PyBitGenerator`]: | ||
//! | ||
//! ``` | ||
//! use pyo3::prelude::*; | ||
//! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; | ||
//! | ||
//! fn default_bit_gen<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> { | ||
//! let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; | ||
//! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?; | ||
//! Ok(bit_generator) | ||
//! } | ||
//! | ||
//! let random_number = Python::with_gil(|py| -> PyResult<_> { | ||
//! let mut bitgen = default_bit_gen(py)?.lock()?; | ||
//! // use bitgen without holding the GIL | ||
//! Ok(py.allow_threads(|| bitgen.next_uint64())) | ||
//! })?; | ||
//! # Ok::<(), PyErr>(()) | ||
//! ``` | ||
//! | ||
//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]: | ||
//! | ||
//! ``` | ||
//! # use pyo3::prelude::*; | ||
//! use rand::Rng as _; | ||
//! # use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; | ||
//! # // TODO: reuse function definition from above? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels like there should be a convenient way to get this. I'm thinking about something like impl PyBitGenerator {
fn new(py: Python<'_>) -> PyResult<Bound<..>>;
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are many implementations, we’d have to cover all of them. I’d rather leave this minimal until this PR is mostly done. |
||
//! # fn default_bit_gen<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> { | ||
//! # let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; | ||
//! # let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?; | ||
//! # Ok(bit_generator) | ||
//! # } | ||
//! | ||
//! Python::with_gil(|py| -> PyResult<_> { | ||
//! let mut bitgen = default_bit_gen(py)?.lock()?; | ||
//! if bitgen.random_ratio(1, 1_000_000) { | ||
//! println!("a sure thing"); | ||
//! } | ||
//! Ok(()) | ||
//! })?; | ||
//! # Ok::<(), PyErr>(()) | ||
//! ``` | ||
//! | ||
//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html | ||
//! [ext]: https://numpy.org/doc/stable/reference/random/extending.html | ||
|
||
use std::ptr::NonNull; | ||
|
||
use pyo3::{ | ||
exceptions::PyRuntimeError, | ||
ffi, | ||
prelude::*, | ||
sync::GILOnceCell, | ||
types::{DerefToPyAny, PyCapsule, PyType}, | ||
PyTypeInfo, | ||
}; | ||
|
||
use crate::npyffi::npy_bitgen; | ||
|
||
/// Wrapper for [`np.random.BitGenerator`][bg]. | ||
/// | ||
/// See also [`PyBitGeneratorMethods`]. | ||
/// | ||
/// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html | ||
#[repr(transparent)] | ||
pub struct PyBitGenerator(PyAny); | ||
|
||
impl DerefToPyAny for PyBitGenerator {} | ||
|
||
unsafe impl PyTypeInfo for PyBitGenerator { | ||
const NAME: &'static str = "PyBitGenerator"; | ||
const MODULE: Option<&'static str> = Some("numpy.random"); | ||
|
||
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject { | ||
const CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new(); | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let cls = CLS | ||
.get_or_try_init::<_, PyErr>(py, || { | ||
Ok(py | ||
.import("numpy.random")? | ||
.getattr("BitGenerator")? | ||
.downcast_into::<PyType>()? | ||
.unbind()) | ||
}) | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
.expect("Failed to get BitGenerator type object") | ||
.clone_ref(py) | ||
.into_bound(py); | ||
cls.as_type_ptr() | ||
} | ||
} | ||
|
||
/// Methods for [`PyBitGenerator`]. | ||
pub trait PyBitGeneratorMethods<'py> { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/// Acquire a lock on the BitGenerator to allow calling its methods in. | ||
fn lock(&self) -> PyResult<PyBitGeneratorGuard<'py>>; | ||
} | ||
|
||
impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { | ||
fn lock(&self) -> PyResult<PyBitGeneratorGuard<'py>> { | ||
let capsule = self.getattr("capsule")?.downcast_into::<PyCapsule>()?; | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let lock = self.getattr("lock")?; | ||
if lock.call_method0("locked")?.extract()? { | ||
return Err(PyRuntimeError::new_err("BitGenerator is already locked")); | ||
} | ||
lock.call_method0("acquire")?; | ||
|
||
assert_eq!(capsule.name()?, Some(c"BitGenerator")); | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let ptr = capsule.pointer() as *mut npy_bitgen; | ||
let non_null = match NonNull::new(ptr) { | ||
Some(non_null) => non_null, | ||
None => { | ||
lock.call_method0("release")?; | ||
return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); | ||
} | ||
}; | ||
Ok(PyBitGeneratorGuard { | ||
raw_bitgen: non_null, | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_capsule: capsule.unbind(), | ||
lock: lock.unbind(), | ||
py: self.py(), | ||
}) | ||
} | ||
} | ||
|
||
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> { | ||
type Error = PyErr; | ||
fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result<Self, Self::Error> { | ||
value.lock() | ||
} | ||
} | ||
|
||
/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pub struct PyBitGeneratorGuard<'py> { | ||
raw_bitgen: NonNull<npy_bitgen>, | ||
/// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated. | ||
_capsule: Py<PyCapsule>, | ||
/// This lock makes sure no other threads try to use the BitGenerator while we do. | ||
lock: Py<PyAny>, | ||
/// This should be an unsafe field (https://github.com/rust-lang/rust/issues/132922) | ||
/// | ||
/// SAFETY: only use this in `Drop::drop` (when we are sure the GIL is held). | ||
py: Python<'py>, | ||
} | ||
|
||
// SAFETY: we can’t have public APIs that access the Python objects, | ||
// only the `raw_bitgen` pointer. | ||
unsafe impl Send for PyBitGeneratorGuard<'_> {} | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
impl Drop for PyBitGeneratorGuard<'_> { | ||
fn drop(&mut self) { | ||
// ignore errors. This includes when `try_drop` was called manually | ||
let _ = self.lock.bind(self.py).call_method0("release"); | ||
} | ||
} | ||
|
||
// SAFETY: We hold the `BitGenerator.lock`, | ||
// so nothing apart from us is allowed to change its state. | ||
impl<'py> PyBitGeneratorGuard<'py> { | ||
/// Drop the lock manually before `Drop::drop` tries to do it (used for testing). | ||
#[allow(dead_code)] | ||
fn try_drop(self, py: Python<'py>) -> PyResult<()> { | ||
self.lock.bind(py).call_method0("release")?; | ||
Ok(()) | ||
} | ||
|
||
/// Returns the next random unsigned 64 bit integer. | ||
pub fn next_uint64(&mut self) -> u64 { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
unsafe { | ||
let bitgen = *self.raw_bitgen.as_ptr(); | ||
(bitgen.next_uint64)(bitgen.state) | ||
} | ||
} | ||
/// Returns the next random unsigned 32 bit integer. | ||
pub fn next_uint32(&mut self) -> u32 { | ||
unsafe { | ||
let bitgen = *self.raw_bitgen.as_ptr(); | ||
(bitgen.next_uint32)(bitgen.state) | ||
} | ||
} | ||
/// Returns the next random double. | ||
pub fn next_double(&mut self) -> libc::c_double { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
unsafe { | ||
let bitgen = *self.raw_bitgen.as_ptr(); | ||
(bitgen.next_double)(bitgen.state) | ||
} | ||
} | ||
/// Returns the next raw value (can be used for testing). | ||
pub fn next_raw(&mut self) -> u64 { | ||
unsafe { | ||
let bitgen = *self.raw_bitgen.as_ptr(); | ||
(bitgen.next_raw)(bitgen.state) | ||
} | ||
} | ||
} | ||
|
||
#[cfg(feature = "rand")] | ||
impl rand::RngCore for PyBitGeneratorGuard<'_> { | ||
fn next_u32(&mut self) -> u32 { | ||
self.next_uint32() | ||
} | ||
fn next_u64(&mut self) -> u64 { | ||
self.next_uint64() | ||
} | ||
fn fill_bytes(&mut self, dst: &mut [u8]) { | ||
rand::rand_core::impls::fill_bytes_via_next(self, dst) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
fn get_bit_generator<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> { | ||
let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; | ||
let bit_generator = default_rng | ||
.getattr("bit_generator")? | ||
.downcast_into::<PyBitGenerator>()?; | ||
Ok(bit_generator) | ||
} | ||
|
||
/// Test the primary use case: acquire the lock, release the GIL, then use the lock | ||
#[test] | ||
fn use_outside_gil() -> PyResult<()> { | ||
Python::with_gil(|py| { | ||
let mut bitgen = get_bit_generator(py)?.lock()?; | ||
py.allow_threads(|| { | ||
let _ = bitgen.next_raw(); | ||
}); | ||
assert!(bitgen.try_drop(py).is_ok()); | ||
Ok(()) | ||
}) | ||
} | ||
|
||
/// Test that the `rand::Rng` APIs work | ||
#[cfg(feature = "rand")] | ||
#[test] | ||
fn rand() -> PyResult<()> { | ||
use rand::Rng as _; | ||
|
||
Python::with_gil(|py| { | ||
let mut bitgen = get_bit_generator(py)?.lock()?; | ||
py.allow_threads(|| { | ||
assert!(bitgen.random_ratio(1, 1)); | ||
assert!(!bitgen.random_ratio(0, 1)); | ||
}); | ||
assert!(bitgen.try_drop(py).is_ok()); | ||
Ok(()) | ||
}) | ||
} | ||
|
||
#[test] | ||
fn double_lock_fails() -> PyResult<()> { | ||
Python::with_gil(|py| { | ||
let generator = get_bit_generator(py)?; | ||
let bitgen = generator.lock()?; | ||
assert!(generator.lock().is_err()); | ||
assert!(bitgen.try_drop(py).is_ok()); | ||
Ok(()) | ||
}) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be removed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, will do when I’m done. I like working on multiple machines, and I don’t like re-doing settings for individual projects