-
Notifications
You must be signed in to change notification settings - Fork 123
BitGenerator support #499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
flying-sheep
wants to merge
43
commits into
PyO3:main
Choose a base branch
from
flying-sheep:pa/bitgen
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
BitGenerator support #499
Changes from 15 commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
06d6ce1
WIP bitgen
flying-sheep 07e2416
nonnull
flying-sheep b611943
fix and test
flying-sheep d93a264
cmt
flying-sheep f52b2fa
safer: don’t allow trying to get `BitGen` from any PyAny
flying-sheep 05814d6
less indirection
flying-sheep 37d360e
add tryfrom
flying-sheep eed5b19
implement rand
flying-sheep 6c1a89b
fmt
flying-sheep d1909d3
rename and deref
flying-sheep bde2553
order
flying-sheep a0b9ec5
make into lock
flying-sheep ee32246
docs
flying-sheep 1be6838
more docs
flying-sheep 2aa3d90
guard
flying-sheep 0258e6d
call_method0
flying-sheep 876001b
reaname test
flying-sheep 71ce8be
manually drop and capsule
flying-sheep 2de7072
remove useless test
flying-sheep 016eb7a
doctests
flying-sheep 1f7f37f
smaller
flying-sheep 1d01c7a
clarify where to release the GIL
flying-sheep c90176a
safety
flying-sheep f49d3fa
oops
flying-sheep a16846d
less unsafe
flying-sheep 573d890
add thread test
flying-sheep 06bb693
back to lock acquiring
flying-sheep 663fa29
docs
flying-sheep c6105c9
no copy/clone
flying-sheep 3a0aa92
rename to release
flying-sheep a92861a
remove lifetime
flying-sheep 6dbb6dc
static
flying-sheep b102d20
no mut ref conversion
flying-sheep e5e440e
disambiguate
flying-sheep e73e3a2
rand_core only
flying-sheep c6493df
rename bitgen type
flying-sheep 2327f36
c_str macro
flying-sheep e5c6458
intern strings
flying-sheep e8cd5e8
docs
flying-sheep 0868405
more doc
flying-sheep 8667203
clean up tests
flying-sheep 1fd7bb5
no let-else
flying-sheep 3913171
use GILOnceCell::import
flying-sheep File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"[rust]": { | ||
"editor.defaultFormatter": "rust-lang.rust-analyzer", | ||
"editor.formatOnSave": true, | ||
}, | ||
"rust-analyzer.cargo.features": "all", | ||
} | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
use std::ffi::c_void; | ||
|
||
#[repr(C)] | ||
#[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy? | ||
pub struct npy_bitgen { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pub state: *mut c_void, | ||
pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil | ||
pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil | ||
pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil | ||
pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
//! Safe interface for NumPy's random [`BitGenerator`][bg]. | ||
//! | ||
//! Using the patterns described in [“Extending `numpy.random`”][ext], | ||
//! you can generate random numbers without holding the GIL, | ||
//! by [acquiring][`PyBitGeneratorMethods::lock`] a lock [guard][`PyBitGeneratorGuard`] for the [`PyBitGenerator`]: | ||
//! | ||
//! ```rust | ||
//! use pyo3::prelude::*; | ||
//! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; | ||
//! | ||
//! let mut bitgen = Python::with_gil(|py| -> PyResult<_> { | ||
//! let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?; | ||
//! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into::<PyBitGenerator>()?; | ||
//! bit_generator.lock() | ||
//! })?; | ||
//! let random_number = bitgen.next_u64(); | ||
//! ``` | ||
//! | ||
//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]: | ||
//! | ||
//! ```rust | ||
//! use rand::Rng as _; | ||
//! | ||
//! if bitgen.random_ratio(1, 1_000_000) { | ||
//! println!("a sure thing"); | ||
//! } | ||
//! ``` | ||
//! | ||
//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html | ||
//! [ext]: https://numpy.org/doc/stable/reference/random/extending.html | ||
|
||
use std::ptr::NonNull; | ||
|
||
use pyo3::{ | ||
exceptions::PyRuntimeError, | ||
ffi, | ||
prelude::*, | ||
sync::GILOnceCell, | ||
types::{DerefToPyAny, PyCapsule, PyType}, | ||
PyTypeInfo, | ||
}; | ||
|
||
use crate::npyffi::npy_bitgen; | ||
|
||
/// Wrapper for [`np.random.BitGenerator`][bg]. | ||
/// | ||
/// See also [`PyBitGeneratorMethods`]. | ||
/// | ||
/// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html | ||
#[repr(transparent)] | ||
pub struct PyBitGenerator(PyAny); | ||
|
||
impl DerefToPyAny for PyBitGenerator {} | ||
|
||
unsafe impl PyTypeInfo for PyBitGenerator { | ||
const NAME: &'static str = "PyBitGenerator"; | ||
const MODULE: Option<&'static str> = Some("numpy.random"); | ||
|
||
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject { | ||
const CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new(); | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let cls = CLS | ||
.get_or_try_init::<_, PyErr>(py, || { | ||
Ok(py | ||
.import("numpy.random")? | ||
.getattr("BitGenerator")? | ||
.downcast_into::<PyType>()? | ||
.unbind()) | ||
}) | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
.expect("Failed to get BitGenerator type object") | ||
.clone_ref(py) | ||
.into_bound(py); | ||
cls.as_type_ptr() | ||
} | ||
} | ||
|
||
/// Methods for [`PyBitGenerator`]. | ||
pub trait PyBitGeneratorMethods { | ||
/// Acquire a lock on the BitGenerator to allow calling its methods in. | ||
fn lock(&self) -> PyResult<PyBitGeneratorGuard>; | ||
} | ||
|
||
impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { | ||
fn lock(&self) -> PyResult<PyBitGeneratorGuard> { | ||
let capsule = self.getattr("capsule")?.downcast_into::<PyCapsule>()?; | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let lock = self.getattr("lock")?; | ||
if lock.getattr("locked")?.call0()?.extract()? { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Err(PyRuntimeError::new_err("BitGenerator is already locked")); | ||
} | ||
lock.getattr("acquire")?.call0()?; | ||
|
||
assert_eq!(capsule.name()?, Some(c"BitGenerator")); | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let ptr = capsule.pointer() as *mut npy_bitgen; | ||
let non_null = match NonNull::new(ptr) { | ||
Some(non_null) => non_null, | ||
None => { | ||
lock.getattr("release")?.call0()?; | ||
return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); | ||
} | ||
}; | ||
Ok(PyBitGeneratorGuard { | ||
raw_bitgen: non_null, | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
lock: lock.unbind(), | ||
}) | ||
} | ||
} | ||
|
||
impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { | ||
type Error = PyErr; | ||
fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result<Self, Self::Error> { | ||
value.lock() | ||
} | ||
} | ||
|
||
/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pub struct PyBitGeneratorGuard { | ||
raw_bitgen: NonNull<npy_bitgen>, | ||
lock: Py<PyAny>, | ||
} | ||
|
||
// SAFETY: We hold the `BitGenerator.lock`, | ||
// so nothing apart from us is allowed to change its state. | ||
impl PyBitGeneratorGuard { | ||
/// Returns the next random unsigned 64 bit integer. | ||
pub fn next_uint64(&mut self) -> u64 { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
unsafe { | ||
let bitgen = self.raw_bitgen.as_mut(); | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(bitgen.next_uint64)(bitgen.state) | ||
} | ||
} | ||
/// Returns the next random unsigned 32 bit integer. | ||
pub fn next_uint32(&mut self) -> u32 { | ||
unsafe { | ||
let bitgen = self.raw_bitgen.as_mut(); | ||
(bitgen.next_uint32)(bitgen.state) | ||
} | ||
} | ||
/// Returns the next random double. | ||
pub fn next_double(&mut self) -> libc::c_double { | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
unsafe { | ||
let bitgen = self.raw_bitgen.as_mut(); | ||
(bitgen.next_double)(bitgen.state) | ||
} | ||
} | ||
/// Returns the next raw value (can be used for testing). | ||
pub fn next_raw(&mut self) -> u64 { | ||
unsafe { | ||
let bitgen = self.raw_bitgen.as_mut(); | ||
(bitgen.next_raw)(bitgen.state) | ||
} | ||
} | ||
} | ||
|
||
impl Drop for PyBitGeneratorGuard { | ||
fn drop(&mut self) { | ||
let r = Python::with_gil(|py| -> PyResult<()> { | ||
self.lock.bind(py).getattr("release")?.call0()?; | ||
Ok(()) | ||
}); | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if let Err(e) = r { | ||
eprintln!("Failed to release BitGenerator lock: {e}"); | ||
} | ||
flying-sheep marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
#[cfg(feature = "rand")] | ||
impl rand::RngCore for PyBitGeneratorGuard { | ||
fn next_u32(&mut self) -> u32 { | ||
self.next_uint32() | ||
} | ||
fn next_u64(&mut self) -> u64 { | ||
self.next_uint64() | ||
} | ||
fn fill_bytes(&mut self, dst: &mut [u8]) { | ||
rand::rand_core::impls::fill_bytes_via_next(self, dst) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
fn get_bit_generator<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyBitGenerator>> { | ||
let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?; | ||
let bit_generator = default_rng | ||
.getattr("bit_generator")? | ||
.downcast_into::<PyBitGenerator>()?; | ||
Ok(bit_generator) | ||
} | ||
|
||
#[test] | ||
fn bitgen() -> PyResult<()> { | ||
let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; | ||
let _ = bitgen.next_raw(); | ||
std::mem::drop(bitgen); | ||
Ok(()) | ||
} | ||
|
||
/// Test that the `rand::Rng` APIs work | ||
#[cfg(feature = "rand")] | ||
#[test] | ||
fn rand() -> PyResult<()> { | ||
use rand::Rng as _; | ||
|
||
let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; | ||
assert!(bitgen.random_ratio(1, 1)); | ||
assert!(!bitgen.random_ratio(0, 1)); | ||
std::mem::drop(bitgen); | ||
Ok(()) | ||
} | ||
|
||
/// Test that releasing the lock works while holding the GIL | ||
#[test] | ||
fn unlock_with_held_gil() -> PyResult<()> { | ||
Python::with_gil(|py| { | ||
let generator = get_bit_generator(py)?; | ||
let mut bitgen = generator.lock()?; | ||
let _ = bitgen.next_raw(); | ||
std::mem::drop(bitgen); | ||
assert!(!generator | ||
.getattr("lock")? | ||
.getattr("locked")? | ||
.call0()? | ||
.extract()?); | ||
Ok(()) | ||
}) | ||
} | ||
|
||
#[test] | ||
fn double_lock_fails() -> PyResult<()> { | ||
Python::with_gil(|py| { | ||
let generator = get_bit_generator(py)?; | ||
let d1 = generator.lock()?; | ||
assert!(generator.lock().is_err()); | ||
std::mem::drop(d1); | ||
Ok(()) | ||
}) | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be removed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, will do when I’m done. I like working on multiple machines, and I don’t like re-doing settings for individual projects