From 06d6ce1fd8414e7260f3c488835c174a434dc7eb Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 6 Jun 2025 18:28:54 +0200 Subject: [PATCH 01/46] WIP bitgen --- .vscode/settings.json | 3 +++ examples/simple/src/lib.rs | 2 +- src/lib.rs | 1 + src/npyffi/mod.rs | 2 ++ src/npyffi/objects.rs | 2 +- src/npyffi/random.rs | 19 +++++++++++++++++++ src/random.rs | 35 +++++++++++++++++++++++++++++++++++ 7 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 src/npyffi/random.rs create mode 100644 src/random.rs diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..7b2b22c3a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "rust-analyzer.cargo.features": "all" +} \ No newline at end of file diff --git a/examples/simple/src/lib.rs b/examples/simple/src/lib.rs index 3bb29e3e1..b9cb28907 100644 --- a/examples/simple/src/lib.rs +++ b/examples/simple/src/lib.rs @@ -113,7 +113,7 @@ fn rust_ext<'py>(m: &Bound<'py, PyModule>) -> PyResult<()> { // This crate follows a strongly-typed approach to wrapping NumPy arrays // while Python API are often expected to work with multiple element types. // - // That kind of limited polymorphis can be recovered by accepting an enumerated type + // That kind of limited polymorphism can be recovered by accepting an enumerated type // covering the supported element types and dispatching into a generic implementation. #[derive(FromPyObject)] enum SupportedArray<'py> { diff --git a/src/lib.rs b/src/lib.rs index 195465022..219f7efe1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,6 +79,7 @@ pub mod datetime; mod dtype; mod error; pub mod npyffi; +pub mod random; mod slice_container; mod strings; mod sum_products; diff --git a/src/npyffi/mod.rs b/src/npyffi/mod.rs index bf846f8e2..efd214a80 100644 --- a/src/npyffi/mod.rs +++ b/src/npyffi/mod.rs @@ -94,11 +94,13 @@ macro_rules! impl_api { pub mod array; pub mod flags; pub mod objects; +pub mod random; pub mod types; pub mod ufunc; pub use self::array::*; pub use self::flags::*; pub use self::objects::*; +pub use self::random::*; pub use self::types::*; pub use self::ufunc::*; diff --git a/src/npyffi/objects.rs b/src/npyffi/objects.rs index d28e88b7e..5a00539fa 100644 --- a/src/npyffi/objects.rs +++ b/src/npyffi/objects.rs @@ -1,4 +1,4 @@ -//! Low-Lebel binding for NumPy C API C-objects +//! Low-Level binding for NumPy C API C-objects //! //! #![allow(non_camel_case_types)] diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs new file mode 100644 index 000000000..96259d2f1 --- /dev/null +++ b/src/npyffi/random.rs @@ -0,0 +1,19 @@ +use std::{ffi::c_void, ptr::NonNull}; + +use pyo3::{prelude::*, types::PyCapsule}; + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct npy_bitgen { + pub state: *mut c_void, + pub next_uint64: NonNull super::npy_uint64>, //nogil + pub next_uint32: NonNull super::npy_uint32>, //nogil + pub next_double: NonNull libc::c_double>, //nogil + pub next_raw: NonNull super::npy_uint64>, //nogil +} + +pub fn get_bitgen_api<'py>(bitgen: Bound<'py, PyAny>) -> PyResult<*mut npy_bitgen> { + let capsule = bitgen.getattr("capsule")?.downcast_into::()?; + assert_eq!(capsule.name()?, Some(c"BitGenerator")); + Ok(capsule.pointer() as *mut npy_bitgen) +} diff --git a/src/random.rs b/src/random.rs new file mode 100644 index 000000000..868132fd5 --- /dev/null +++ b/src/random.rs @@ -0,0 +1,35 @@ +//! Safe interface for NumPy's random [`BitGenerator`][] +//! +//! `BitGenerator`: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html + +use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::PyType, PyTypeInfo}; + +use crate::npyffi::get_bitgen_api; + +///! Wrapper for NumPy's random [`BitGenerator`][] +/// +///! [BitGenerator]: https://numpy.org/doc/stable/reference/random/bit_generators/generated/numpy.random.BitGenerator.html +#[repr(transparent)] +pub struct BitGenerator(PyAny); + +unsafe impl PyTypeInfo for BitGenerator { + const NAME: &'static str = "BitGenerator"; + const MODULE: Option<&'static str> = Some("numpy.random"); + + fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject { + const CLS: GILOnceCell> = GILOnceCell::new(); + let cls = CLS + .get_or_try_init::<_, PyErr>(py, || { + Ok(py + .import("numpy.random")? + .getattr("BitGenerator")? + .downcast_into::()? + .unbind()) + }) + .expect("Failed to get BitGenerator type object") + .clone_ref(py) + .into_bound(py); + cls.as_type_ptr() + } +} + From 07e24161fbafc887adc5ff4aed673045d2100350 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 6 Jun 2025 19:11:35 +0200 Subject: [PATCH 02/46] nonnull --- src/npyffi/random.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs index 96259d2f1..ccb456713 100644 --- a/src/npyffi/random.rs +++ b/src/npyffi/random.rs @@ -1,6 +1,6 @@ use std::{ffi::c_void, ptr::NonNull}; -use pyo3::{prelude::*, types::PyCapsule}; +use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule}; #[repr(C)] #[derive(Debug, Clone, Copy)] @@ -12,8 +12,9 @@ pub struct npy_bitgen { pub next_raw: NonNull super::npy_uint64>, //nogil } -pub fn get_bitgen_api<'py>(bitgen: Bound<'py, PyAny>) -> PyResult<*mut npy_bitgen> { +pub fn get_bitgen_api<'py>(bitgen: Bound<'py, PyAny>) -> PyResult> { let capsule = bitgen.getattr("capsule")?.downcast_into::()?; assert_eq!(capsule.name()?, Some(c"BitGenerator")); - Ok(capsule.pointer() as *mut npy_bitgen) + let ptr = capsule.pointer() as *mut npy_bitgen; + NonNull::new(ptr).ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule")) } From b61194376f27f43b5c27bfe94aa633cec3b76e66 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 6 Jun 2025 19:36:28 +0200 Subject: [PATCH 03/46] fix and test --- src/npyffi/random.rs | 10 ++++---- src/random.rs | 57 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs index ccb456713..43efb9072 100644 --- a/src/npyffi/random.rs +++ b/src/npyffi/random.rs @@ -6,13 +6,13 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule}; #[derive(Debug, Clone, Copy)] pub struct npy_bitgen { pub state: *mut c_void, - pub next_uint64: NonNull super::npy_uint64>, //nogil - pub next_uint32: NonNull super::npy_uint32>, //nogil - pub next_double: NonNull libc::c_double>, //nogil - pub next_raw: NonNull super::npy_uint64>, //nogil + pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil + pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil + pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil + pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil } -pub fn get_bitgen_api<'py>(bitgen: Bound<'py, PyAny>) -> PyResult> { +pub fn get_bitgen_api<'py>(bitgen: &Bound<'py, PyAny>) -> PyResult> { let capsule = bitgen.getattr("capsule")?.downcast_into::()?; assert_eq!(capsule.name()?, Some(c"BitGenerator")); let ptr = capsule.pointer() as *mut npy_bitgen; diff --git a/src/random.rs b/src/random.rs index 868132fd5..cac5f3657 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,14 +1,12 @@ -//! Safe interface for NumPy's random [`BitGenerator`][] -//! -//! `BitGenerator`: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html +//! Safe interface for NumPy's random [`BitGenerator`] use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::PyType, PyTypeInfo}; use crate::npyffi::get_bitgen_api; -///! Wrapper for NumPy's random [`BitGenerator`][] -/// -///! [BitGenerator]: https://numpy.org/doc/stable/reference/random/bit_generators/generated/numpy.random.BitGenerator.html +///! Wrapper for NumPy's random [`BitGenerator`][bg] +///! +///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html #[repr(transparent)] pub struct BitGenerator(PyAny); @@ -33,3 +31,50 @@ unsafe impl PyTypeInfo for BitGenerator { } } +/// Methods for [`BitGenerator`] +pub trait BitGeneratorMethods { + /// Returns the next random unsigned 64 bit integer + fn next_uint64(&self) -> u64; + /// Returns the next random unsigned 32 bit integer + fn next_uint32(&self) -> u32; + /// Returns the next random double + fn next_double(&self) -> libc::c_double; + /// Returns the next raw value (can be used for testing) + fn next_raw(&self) -> u64; +} + +// TODO: cache npy_bitgen pointer +impl<'py> BitGeneratorMethods for Bound<'py, BitGenerator> { + fn next_uint64(&self) -> u64 { + todo!() + } + fn next_uint32(&self) -> u32 { + todo!() + } + fn next_double(&self) -> libc::c_double { + todo!() + } + fn next_raw(&self) -> u64 { + let mut api = get_bitgen_api(self.as_any()).expect("Could not get bitgen"); + unsafe { + let api = api.as_mut(); + (api.next_raw)(api.state) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bitgen() -> PyResult<()> { + Python::with_gil(|py| { + let default_rng = py.import("numpy.random")?.getattr("default_rng")?; + let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::()?; + let res = bitgen.next_raw(); + dbg!(res); + Ok(()) + }) + } +} From d93a2643481b07c63d996ac68e2340b568f18a21 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 6 Jun 2025 19:45:53 +0200 Subject: [PATCH 04/46] cmt --- src/npyffi/random.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs index 43efb9072..818cfc299 100644 --- a/src/npyffi/random.rs +++ b/src/npyffi/random.rs @@ -3,7 +3,7 @@ use std::{ffi::c_void, ptr::NonNull}; use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule}; #[repr(C)] -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy? pub struct npy_bitgen { pub state: *mut c_void, pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil From f52b2fa38b81ab3bf5fd7251c282df87801e6e2c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sat, 7 Jun 2025 12:25:32 +0200 Subject: [PATCH 05/46] =?UTF-8?q?safer:=20don=E2=80=99t=20allow=20trying?= =?UTF-8?q?=20to=20get=20`BitGen`=20from=20any=20PyAny?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/npyffi/random.rs | 10 +--------- src/random.rs | 45 +++++++++++++++++++++++++++++--------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs index 818cfc299..dc5af9070 100644 --- a/src/npyffi/random.rs +++ b/src/npyffi/random.rs @@ -1,6 +1,5 @@ -use std::{ffi::c_void, ptr::NonNull}; +use std::ffi::c_void; -use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule}; #[repr(C)] #[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy? @@ -11,10 +10,3 @@ pub struct npy_bitgen { pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil } - -pub fn get_bitgen_api<'py>(bitgen: &Bound<'py, PyAny>) -> PyResult> { - let capsule = bitgen.getattr("capsule")?.downcast_into::()?; - assert_eq!(capsule.name()?, Some(c"BitGenerator")); - let ptr = capsule.pointer() as *mut npy_bitgen; - NonNull::new(ptr).ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule")) -} diff --git a/src/random.rs b/src/random.rs index cac5f3657..219e56ebc 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,10 +1,10 @@ //! Safe interface for NumPy's random [`BitGenerator`] -use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::PyType, PyTypeInfo}; +use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::{PyCapsule, PyType}, PyTypeInfo, exceptions::PyRuntimeError}; -use crate::npyffi::get_bitgen_api; +use crate::npyffi::npy_bitgen; -///! Wrapper for NumPy's random [`BitGenerator`][bg] +///! Wrapper for [`np.random.BitGenerator`][bg] ///! ///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html #[repr(transparent)] @@ -32,7 +32,27 @@ unsafe impl PyTypeInfo for BitGenerator { } /// Methods for [`BitGenerator`] -pub trait BitGeneratorMethods { +pub trait BitGeneratorMethods<'py> { + /// Returns a new [`BitGen`] + fn bit_gen(&self) -> PyResult>; +} + +impl<'py> BitGeneratorMethods<'py> for Bound<'py, BitGenerator> { + fn bit_gen(&self) -> PyResult> { + let capsule = self.as_any().getattr("capsule")?.downcast_into::()?; + assert_eq!(capsule.name()?, Some(c"BitGenerator")); + let ptr = capsule.pointer() as *mut npy_bitgen; + // SAFETY: the lifetime of `ptr` is derived from the lifetime of `self` + let ref_ = unsafe { ptr.as_mut::<'py>() }.ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule"))?; + Ok(BitGen(ref_)) + } +} + +/// Wrapper for [`npy_bitgen`] +pub struct BitGen<'a>(&'a mut npy_bitgen); + +/// Methods for [`BitGen`] +pub trait BitGenMethods { /// Returns the next random unsigned 64 bit integer fn next_uint64(&self) -> u64; /// Returns the next random unsigned 32 bit integer @@ -43,23 +63,18 @@ pub trait BitGeneratorMethods { fn next_raw(&self) -> u64; } -// TODO: cache npy_bitgen pointer -impl<'py> BitGeneratorMethods for Bound<'py, BitGenerator> { +impl<'py> BitGenMethods for BitGen<'py> { fn next_uint64(&self) -> u64 { - todo!() + unsafe { (self.0.next_uint64)(self.0.state) } } fn next_uint32(&self) -> u32 { - todo!() + unsafe { (self.0.next_uint32)(self.0.state) } } fn next_double(&self) -> libc::c_double { - todo!() + unsafe { (self.0.next_double)(self.0.state) } } fn next_raw(&self) -> u64 { - let mut api = get_bitgen_api(self.as_any()).expect("Could not get bitgen"); - unsafe { - let api = api.as_mut(); - (api.next_raw)(api.state) - } + unsafe { (self.0.next_raw)(self.0.state) } } } @@ -71,7 +86,7 @@ mod tests { fn test_bitgen() -> PyResult<()> { Python::with_gil(|py| { let default_rng = py.import("numpy.random")?.getattr("default_rng")?; - let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::()?; + let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::()?.bit_gen()?; let res = bitgen.next_raw(); dbg!(res); Ok(()) From 05814d6e7754b60765355d4fad951c9adde952e2 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sat, 7 Jun 2025 12:27:00 +0200 Subject: [PATCH 06/46] less indirection --- src/random.rs | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/random.rs b/src/random.rs index 219e56ebc..45e1eb32d 100644 --- a/src/random.rs +++ b/src/random.rs @@ -51,29 +51,21 @@ impl<'py> BitGeneratorMethods<'py> for Bound<'py, BitGenerator> { /// Wrapper for [`npy_bitgen`] pub struct BitGen<'a>(&'a mut npy_bitgen); -/// Methods for [`BitGen`] -pub trait BitGenMethods { +impl<'py> BitGen<'py> { /// Returns the next random unsigned 64 bit integer - fn next_uint64(&self) -> u64; - /// Returns the next random unsigned 32 bit integer - fn next_uint32(&self) -> u32; - /// Returns the next random double - fn next_double(&self) -> libc::c_double; - /// Returns the next raw value (can be used for testing) - fn next_raw(&self) -> u64; -} - -impl<'py> BitGenMethods for BitGen<'py> { - fn next_uint64(&self) -> u64 { + pub fn next_uint64(&self) -> u64 { unsafe { (self.0.next_uint64)(self.0.state) } } - fn next_uint32(&self) -> u32 { + /// Returns the next random unsigned 32 bit integer + pub fn next_uint32(&self) -> u32 { unsafe { (self.0.next_uint32)(self.0.state) } } - fn next_double(&self) -> libc::c_double { + /// Returns the next random double + pub fn next_double(&self) -> libc::c_double { unsafe { (self.0.next_double)(self.0.state) } } - fn next_raw(&self) -> u64 { + /// Returns the next raw value (can be used for testing) + pub fn next_raw(&self) -> u64 { unsafe { (self.0.next_raw)(self.0.state) } } } From 37d360ebffca4eb2f8ac3196cb44baa2478239d8 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sat, 7 Jun 2025 12:28:40 +0200 Subject: [PATCH 07/46] add tryfrom --- src/random.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/random.rs b/src/random.rs index 45e1eb32d..2f6892ebc 100644 --- a/src/random.rs +++ b/src/random.rs @@ -48,6 +48,13 @@ impl<'py> BitGeneratorMethods<'py> for Bound<'py, BitGenerator> { } } +impl<'py> TryFrom<&Bound<'py, BitGenerator>> for BitGen<'py> { + type Error = PyErr; + fn try_from(value: &Bound<'py, BitGenerator>) -> Result { + value.bit_gen() + } +} + /// Wrapper for [`npy_bitgen`] pub struct BitGen<'a>(&'a mut npy_bitgen); From eed5b1999d624cc87832686a57fc29a8e25ffe00 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sat, 7 Jun 2025 19:35:03 +0200 Subject: [PATCH 08/46] implement rand --- Cargo.toml | 1 + src/random.rs | 41 +++++++++++++++++++++++++++++++++++------ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ec359edd3..636d17288 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ num-integer = "0.1" num-traits = "0.2" ndarray = ">= 0.15, < 0.17" pyo3 = { version = "0.25.0", default-features = false, features = ["macros"] } +rand = { version = "0.9.1", default-features = false, optional = true } rustc-hash = "2.0" [dev-dependencies] diff --git a/src/random.rs b/src/random.rs index 2f6892ebc..dd0dcb1d0 100644 --- a/src/random.rs +++ b/src/random.rs @@ -77,18 +77,47 @@ impl<'py> BitGen<'py> { } } +#[cfg(feature = "rand")] +impl rand::RngCore for BitGen<'_> { + fn next_u32(&mut self) -> u32 { + self.next_uint32() + } + fn next_u64(&mut self) -> u64 { + self.next_uint64() + } + fn fill_bytes(&mut self, dst: &mut [u8]) { + rand::rand_core::impls::fill_bytes_via_next(self, dst) + } +} + #[cfg(test)] mod tests { use super::*; + fn get_bit_generator<'py>(py: Python<'py>) -> PyResult> { + let default_rng = py.import("numpy.random")?.getattr("default_rng")?; + let bit_generator = default_rng.call0()?.getattr("bit_generator")?.downcast_into::()?; + Ok(bit_generator) + } + #[test] - fn test_bitgen() -> PyResult<()> { + fn bitgen() -> PyResult<()> { Python::with_gil(|py| { - let default_rng = py.import("numpy.random")?.getattr("default_rng")?; - let bitgen = default_rng.call0()?.getattr("bit_generator")?.downcast_into::()?.bit_gen()?; - let res = bitgen.next_raw(); - dbg!(res); + let bitgen = get_bit_generator(py)?.bit_gen()?; + let _ = bitgen.next_raw(); Ok(()) }) - } + } + + #[cfg(feature = "rand")] + #[test] + fn rand() -> PyResult<()> { + use rand::Rng as _; + + Python::with_gil(|py| { + let mut bitgen = get_bit_generator(py)?.bit_gen()?; + let _ = bitgen.random_ratio(2, 3); + Ok(()) + }) + } } From 6c1a89b9601084283f4959b6f7d984ea0e136391 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sat, 7 Jun 2025 19:36:02 +0200 Subject: [PATCH 09/46] fmt --- .vscode/settings.json | 6 +++++- src/npyffi/random.rs | 5 ++--- src/random.rs | 28 +++++++++++++++++++++------- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 7b2b22c3a..1b5f3234a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,7 @@ { - "rust-analyzer.cargo.features": "all" + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer", + "editor.formatOnSave": true, + }, + "rust-analyzer.cargo.features": "all", } \ No newline at end of file diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs index dc5af9070..6c401ac7a 100644 --- a/src/npyffi/random.rs +++ b/src/npyffi/random.rs @@ -1,12 +1,11 @@ use std::ffi::c_void; - #[repr(C)] #[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy? pub struct npy_bitgen { pub state: *mut c_void, pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil - pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil - pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil + pub next_double: unsafe extern "C" fn(*mut c_void) -> libc::c_double, //nogil + pub next_raw: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil } diff --git a/src/random.rs b/src/random.rs index dd0dcb1d0..ec755f690 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,6 +1,13 @@ //! Safe interface for NumPy's random [`BitGenerator`] -use pyo3::{ffi, prelude::*, sync::GILOnceCell, types::{PyCapsule, PyType}, PyTypeInfo, exceptions::PyRuntimeError}; +use pyo3::{ + exceptions::PyRuntimeError, + ffi, + prelude::*, + sync::GILOnceCell, + types::{PyCapsule, PyType}, + PyTypeInfo, +}; use crate::npyffi::npy_bitgen; @@ -39,11 +46,15 @@ pub trait BitGeneratorMethods<'py> { impl<'py> BitGeneratorMethods<'py> for Bound<'py, BitGenerator> { fn bit_gen(&self) -> PyResult> { - let capsule = self.as_any().getattr("capsule")?.downcast_into::()?; + let capsule = self + .as_any() + .getattr("capsule")? + .downcast_into::()?; assert_eq!(capsule.name()?, Some(c"BitGenerator")); let ptr = capsule.pointer() as *mut npy_bitgen; // SAFETY: the lifetime of `ptr` is derived from the lifetime of `self` - let ref_ = unsafe { ptr.as_mut::<'py>() }.ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule"))?; + let ref_ = unsafe { ptr.as_mut::<'py>() } + .ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule"))?; Ok(BitGen(ref_)) } } @@ -93,13 +104,16 @@ impl rand::RngCore for BitGen<'_> { #[cfg(test)] mod tests { use super::*; - + fn get_bit_generator<'py>(py: Python<'py>) -> PyResult> { let default_rng = py.import("numpy.random")?.getattr("default_rng")?; - let bit_generator = default_rng.call0()?.getattr("bit_generator")?.downcast_into::()?; + let bit_generator = default_rng + .call0()? + .getattr("bit_generator")? + .downcast_into::()?; Ok(bit_generator) } - + #[test] fn bitgen() -> PyResult<()> { Python::with_gil(|py| { @@ -113,7 +127,7 @@ mod tests { #[test] fn rand() -> PyResult<()> { use rand::Rng as _; - + Python::with_gil(|py| { let mut bitgen = get_bit_generator(py)?.bit_gen()?; let _ = bitgen.random_ratio(2, 3); From d1909d3558d7bcc68636cbf56cc168fcbc1b5b61 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 12:22:08 +0200 Subject: [PATCH 10/46] rename and deref --- src/random.rs | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/random.rs b/src/random.rs index ec755f690..28e472e69 100644 --- a/src/random.rs +++ b/src/random.rs @@ -5,7 +5,7 @@ use pyo3::{ ffi, prelude::*, sync::GILOnceCell, - types::{PyCapsule, PyType}, + types::{DerefToPyAny, PyCapsule, PyType}, PyTypeInfo, }; @@ -15,10 +15,10 @@ use crate::npyffi::npy_bitgen; ///! ///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html #[repr(transparent)] -pub struct BitGenerator(PyAny); +pub struct PyBitGenerator(PyAny); -unsafe impl PyTypeInfo for BitGenerator { - const NAME: &'static str = "BitGenerator"; +unsafe impl PyTypeInfo for PyBitGenerator { + const NAME: &'static str = "PyBitGenerator"; const MODULE: Option<&'static str> = Some("numpy.random"); fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject { @@ -38,18 +38,17 @@ unsafe impl PyTypeInfo for BitGenerator { } } +impl DerefToPyAny for PyBitGenerator {} + /// Methods for [`BitGenerator`] pub trait BitGeneratorMethods<'py> { /// Returns a new [`BitGen`] fn bit_gen(&self) -> PyResult>; } -impl<'py> BitGeneratorMethods<'py> for Bound<'py, BitGenerator> { +impl<'py> BitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { fn bit_gen(&self) -> PyResult> { - let capsule = self - .as_any() - .getattr("capsule")? - .downcast_into::()?; + let capsule = self.getattr("capsule")?.downcast_into::()?; assert_eq!(capsule.name()?, Some(c"BitGenerator")); let ptr = capsule.pointer() as *mut npy_bitgen; // SAFETY: the lifetime of `ptr` is derived from the lifetime of `self` @@ -59,9 +58,9 @@ impl<'py> BitGeneratorMethods<'py> for Bound<'py, BitGenerator> { } } -impl<'py> TryFrom<&Bound<'py, BitGenerator>> for BitGen<'py> { +impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for BitGen<'py> { type Error = PyErr; - fn try_from(value: &Bound<'py, BitGenerator>) -> Result { + fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result { value.bit_gen() } } @@ -105,12 +104,12 @@ impl rand::RngCore for BitGen<'_> { mod tests { use super::*; - fn get_bit_generator<'py>(py: Python<'py>) -> PyResult> { + fn get_bit_generator<'py>(py: Python<'py>) -> PyResult> { let default_rng = py.import("numpy.random")?.getattr("default_rng")?; let bit_generator = default_rng .call0()? .getattr("bit_generator")? - .downcast_into::()?; + .downcast_into::()?; Ok(bit_generator) } From bde2553e1b89bf99408b69120b90b05e20a7c5bc Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 12:23:28 +0200 Subject: [PATCH 11/46] order --- src/random.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/random.rs b/src/random.rs index 28e472e69..0f88abac0 100644 --- a/src/random.rs +++ b/src/random.rs @@ -17,6 +17,8 @@ use crate::npyffi::npy_bitgen; #[repr(transparent)] pub struct PyBitGenerator(PyAny); +impl DerefToPyAny for PyBitGenerator {} + unsafe impl PyTypeInfo for PyBitGenerator { const NAME: &'static str = "PyBitGenerator"; const MODULE: Option<&'static str> = Some("numpy.random"); @@ -38,8 +40,6 @@ unsafe impl PyTypeInfo for PyBitGenerator { } } -impl DerefToPyAny for PyBitGenerator {} - /// Methods for [`BitGenerator`] pub trait BitGeneratorMethods<'py> { /// Returns a new [`BitGen`] From a0b9ec574346751b9f33509ee79667a243ba4468 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 14:04:41 +0200 Subject: [PATCH 12/46] make into lock --- src/random.rs | 127 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 93 insertions(+), 34 deletions(-) diff --git a/src/random.rs b/src/random.rs index 0f88abac0..56eddcf8b 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,4 +1,8 @@ -//! Safe interface for NumPy's random [`BitGenerator`] +//! Safe interface for NumPy's random [`BitGenerator`][bg] +//! +//! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html + +use std::ptr::NonNull; use pyo3::{ exceptions::PyRuntimeError, @@ -12,8 +16,6 @@ use pyo3::{ use crate::npyffi::npy_bitgen; ///! Wrapper for [`np.random.BitGenerator`][bg] -///! -///! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html #[repr(transparent)] pub struct PyBitGenerator(PyAny); @@ -40,55 +42,90 @@ unsafe impl PyTypeInfo for PyBitGenerator { } } -/// Methods for [`BitGenerator`] -pub trait BitGeneratorMethods<'py> { - /// Returns a new [`BitGen`] - fn bit_gen(&self) -> PyResult>; +/// Methods for [`PyBitGenerator`] +pub trait BitGeneratorMethods { + /// Acquire a lock on the BitGenerator to allow calling its methods in + fn lock(&self) -> PyResult; } -impl<'py> BitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { - fn bit_gen(&self) -> PyResult> { +impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> { + fn lock(&self) -> PyResult { let capsule = self.getattr("capsule")?.downcast_into::()?; + let lock = self.getattr("lock")?; + if lock.getattr("locked")?.call0()?.extract()? { + return Err(PyRuntimeError::new_err("BitGenerator is already locked")); + } + lock.getattr("acquire")?.call0()?; + assert_eq!(capsule.name()?, Some(c"BitGenerator")); let ptr = capsule.pointer() as *mut npy_bitgen; - // SAFETY: the lifetime of `ptr` is derived from the lifetime of `self` - let ref_ = unsafe { ptr.as_mut::<'py>() } - .ok_or_else(|| PyRuntimeError::new_err("Invalid BitGenerator capsule"))?; - Ok(BitGen(ref_)) + let non_null = match NonNull::new(ptr) { + Some(non_null) => non_null, + None => { + lock.getattr("release")?.call0()?; + return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); + } + }; + Ok(PyBitGeneratorLock(non_null, lock.unbind())) } } -impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for BitGen<'py> { +impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock { type Error = PyErr; fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result { - value.bit_gen() + value.lock() } } -/// Wrapper for [`npy_bitgen`] -pub struct BitGen<'a>(&'a mut npy_bitgen); +/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL +pub struct PyBitGeneratorLock(NonNull, Py); -impl<'py> BitGen<'py> { +// SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state +impl PyBitGeneratorLock { /// Returns the next random unsigned 64 bit integer - pub fn next_uint64(&self) -> u64 { - unsafe { (self.0.next_uint64)(self.0.state) } + pub fn next_uint64(&mut self) -> u64 { + unsafe { + let bitgen = self.0.as_mut(); + (bitgen.next_uint64)(bitgen.state) + } } /// Returns the next random unsigned 32 bit integer - pub fn next_uint32(&self) -> u32 { - unsafe { (self.0.next_uint32)(self.0.state) } + pub fn next_uint32(&mut self) -> u32 { + unsafe { + let bitgen = self.0.as_mut(); + (bitgen.next_uint32)(bitgen.state) + } } /// Returns the next random double - pub fn next_double(&self) -> libc::c_double { - unsafe { (self.0.next_double)(self.0.state) } + pub fn next_double(&mut self) -> libc::c_double { + unsafe { + let bitgen = self.0.as_mut(); + (bitgen.next_double)(bitgen.state) + } } /// Returns the next raw value (can be used for testing) - pub fn next_raw(&self) -> u64 { - unsafe { (self.0.next_raw)(self.0.state) } + pub fn next_raw(&mut self) -> u64 { + unsafe { + let bitgen = self.0.as_mut(); + (bitgen.next_raw)(bitgen.state) + } + } +} + +impl Drop for PyBitGeneratorLock { + fn drop(&mut self) { + let r = Python::with_gil(|py| -> PyResult<()> { + self.1.bind(py).getattr("release")?.call0()?; + Ok(()) + }); + if let Err(e) = r { + eprintln!("Failed to release BitGenerator lock: {e}"); + } } } #[cfg(feature = "rand")] -impl rand::RngCore for BitGen<'_> { +impl rand::RngCore for PyBitGeneratorLock { fn next_u32(&mut self) -> u32 { self.next_uint32() } @@ -115,21 +152,43 @@ mod tests { #[test] fn bitgen() -> PyResult<()> { - Python::with_gil(|py| { - let bitgen = get_bit_generator(py)?.bit_gen()?; - let _ = bitgen.next_raw(); - Ok(()) - }) + let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; + let _ = bitgen.next_raw(); + std::mem::drop(bitgen); + Ok(()) } + /// Test that the `rand::Rng` APIs work #[cfg(feature = "rand")] #[test] fn rand() -> PyResult<()> { use rand::Rng as _; + let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; + assert!(bitgen.random_ratio(1, 1)); + assert!(!bitgen.random_ratio(0, 1)); + std::mem::drop(bitgen); + Ok(()) + } + /// Test that dropping the lock works while holding the GIL + #[test] + fn unlock_with_held_gil() -> PyResult<()> { + Python::with_gil(|py| { + let generator = get_bit_generator(py)?; + let mut bitgen = generator.lock()?; + let _ = bitgen.next_raw(); + std::mem::drop(bitgen); + Ok(()) + }) + } + + #[test] + fn double_lock_fails() -> PyResult<()> { Python::with_gil(|py| { - let mut bitgen = get_bit_generator(py)?.bit_gen()?; - let _ = bitgen.random_ratio(2, 3); + let generator = get_bit_generator(py)?; + let d1 = generator.lock()?; + assert!(generator.lock().is_err()); + std::mem::drop(d1); Ok(()) }) } From ee32246a5d7927ddfa39e3d40024981b0dbb7a38 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 14:31:39 +0200 Subject: [PATCH 13/46] docs --- src/random.rs | 75 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/src/random.rs b/src/random.rs index 56eddcf8b..c56b55acf 100644 --- a/src/random.rs +++ b/src/random.rs @@ -1,6 +1,33 @@ -//! Safe interface for NumPy's random [`BitGenerator`][bg] +//! Safe interface for NumPy's random [`BitGenerator`][bg]. +//! +//! Using the patterns described in [“Extending `numpy.random`”][ext], +//! you can generate random numbers without holding the GIL, +//! by [acquiring][`PyBitGeneratorMethods::lock`] a [lock][`PyBitGeneratorLock`] for the [`PyBitGenerator`]: +//! +//! ```rust +//! use pyo3::prelude::*; +//! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; +//! +//! let mut bitgen = Python::with_gil(|py| -> PyResult<_> { +//! let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?; +//! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into::()?; +//! bit_generator.lock() +//! })?; +//! let random_number = bitgen.next_u64(); +//! ``` +//! +//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorLock`]: +//! +//! ```rust +//! use rand::Rng as _; +//! +//! if bitgen.random_ratio(1, 1_000_000) { +//! println!("a sure thing"); +//! } +//! ``` //! //! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html +//! [ext]: https://numpy.org/doc/stable/reference/random/extending.html use std::ptr::NonNull; @@ -15,7 +42,9 @@ use pyo3::{ use crate::npyffi::npy_bitgen; -///! Wrapper for [`np.random.BitGenerator`][bg] +/// Wrapper for [`np.random.BitGenerator`][bg]. +/// +/// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html #[repr(transparent)] pub struct PyBitGenerator(PyAny); @@ -42,13 +71,13 @@ unsafe impl PyTypeInfo for PyBitGenerator { } } -/// Methods for [`PyBitGenerator`] -pub trait BitGeneratorMethods { - /// Acquire a lock on the BitGenerator to allow calling its methods in +/// Methods for [`PyBitGenerator`]. +pub trait PyBitGeneratorMethods { + /// Acquire a lock on the BitGenerator to allow calling its methods in. fn lock(&self) -> PyResult; } -impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> { +impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { fn lock(&self) -> PyResult { let capsule = self.getattr("capsule")?.downcast_into::()?; let lock = self.getattr("lock")?; @@ -66,7 +95,10 @@ impl<'py> BitGeneratorMethods for Bound<'py, PyBitGenerator> { return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); } }; - Ok(PyBitGeneratorLock(non_null, lock.unbind())) + Ok(PyBitGeneratorLock { + raw_bitgen: non_null, + lock: lock.unbind(), + }) } } @@ -77,36 +109,39 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock { } } -/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL -pub struct PyBitGeneratorLock(NonNull, Py); +/// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. +pub struct PyBitGeneratorLock { + raw_bitgen: NonNull, + lock: Py, +} // SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state impl PyBitGeneratorLock { - /// Returns the next random unsigned 64 bit integer + /// Returns the next random unsigned 64 bit integer. pub fn next_uint64(&mut self) -> u64 { unsafe { - let bitgen = self.0.as_mut(); + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_uint64)(bitgen.state) } } - /// Returns the next random unsigned 32 bit integer + /// Returns the next random unsigned 32 bit integer. pub fn next_uint32(&mut self) -> u32 { unsafe { - let bitgen = self.0.as_mut(); + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_uint32)(bitgen.state) } } - /// Returns the next random double + /// Returns the next random double. pub fn next_double(&mut self) -> libc::c_double { unsafe { - let bitgen = self.0.as_mut(); + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_double)(bitgen.state) } } - /// Returns the next raw value (can be used for testing) + /// Returns the next raw value (can be used for testing). pub fn next_raw(&mut self) -> u64 { unsafe { - let bitgen = self.0.as_mut(); + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_raw)(bitgen.state) } } @@ -115,7 +150,7 @@ impl PyBitGeneratorLock { impl Drop for PyBitGeneratorLock { fn drop(&mut self) { let r = Python::with_gil(|py| -> PyResult<()> { - self.1.bind(py).getattr("release")?.call0()?; + self.lock.bind(py).getattr("release")?.call0()?; Ok(()) }); if let Err(e) = r { @@ -142,9 +177,8 @@ mod tests { use super::*; fn get_bit_generator<'py>(py: Python<'py>) -> PyResult> { - let default_rng = py.import("numpy.random")?.getattr("default_rng")?; + let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?; let bit_generator = default_rng - .call0()? .getattr("bit_generator")? .downcast_into::()?; Ok(bit_generator) @@ -163,6 +197,7 @@ mod tests { #[test] fn rand() -> PyResult<()> { use rand::Rng as _; + let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; assert!(bitgen.random_ratio(1, 1)); assert!(!bitgen.random_ratio(0, 1)); From 1be6838e469b1f17c362b0447b6184a3dab9f2ae Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 14:32:35 +0200 Subject: [PATCH 14/46] more docs --- src/random.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/random.rs b/src/random.rs index c56b55acf..2f3a44019 100644 --- a/src/random.rs +++ b/src/random.rs @@ -44,6 +44,8 @@ use crate::npyffi::npy_bitgen; /// Wrapper for [`np.random.BitGenerator`][bg]. /// +/// See also [`PyBitGeneratorMethods`]. +/// /// [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html #[repr(transparent)] pub struct PyBitGenerator(PyAny); From 2aa3d900e7390e1eaa5b543348b899c78e7a92a1 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 14:43:39 +0200 Subject: [PATCH 15/46] guard --- src/random.rs | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/random.rs b/src/random.rs index 2f3a44019..168aa7983 100644 --- a/src/random.rs +++ b/src/random.rs @@ -2,7 +2,7 @@ //! //! Using the patterns described in [“Extending `numpy.random`”][ext], //! you can generate random numbers without holding the GIL, -//! by [acquiring][`PyBitGeneratorMethods::lock`] a [lock][`PyBitGeneratorLock`] for the [`PyBitGenerator`]: +//! by [acquiring][`PyBitGeneratorMethods::lock`] a lock [guard][`PyBitGeneratorGuard`] for the [`PyBitGenerator`]: //! //! ```rust //! use pyo3::prelude::*; @@ -16,7 +16,7 @@ //! let random_number = bitgen.next_u64(); //! ``` //! -//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorLock`]: +//! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]: //! //! ```rust //! use rand::Rng as _; @@ -76,11 +76,11 @@ unsafe impl PyTypeInfo for PyBitGenerator { /// Methods for [`PyBitGenerator`]. pub trait PyBitGeneratorMethods { /// Acquire a lock on the BitGenerator to allow calling its methods in. - fn lock(&self) -> PyResult; + fn lock(&self) -> PyResult; } impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { - fn lock(&self) -> PyResult { + fn lock(&self) -> PyResult { let capsule = self.getattr("capsule")?.downcast_into::()?; let lock = self.getattr("lock")?; if lock.getattr("locked")?.call0()?.extract()? { @@ -97,14 +97,14 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); } }; - Ok(PyBitGeneratorLock { + Ok(PyBitGeneratorGuard { raw_bitgen: non_null, lock: lock.unbind(), }) } } -impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock { +impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { type Error = PyErr; fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result { value.lock() @@ -112,13 +112,14 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorLock { } /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. -pub struct PyBitGeneratorLock { +pub struct PyBitGeneratorGuard { raw_bitgen: NonNull, lock: Py, } -// SAFETY for all methods: We hold the BitGenerator lock, so nothing is allowed to change its state -impl PyBitGeneratorLock { +// SAFETY: We hold the `BitGenerator.lock`, +// so nothing apart from us is allowed to change its state. +impl PyBitGeneratorGuard { /// Returns the next random unsigned 64 bit integer. pub fn next_uint64(&mut self) -> u64 { unsafe { @@ -149,7 +150,7 @@ impl PyBitGeneratorLock { } } -impl Drop for PyBitGeneratorLock { +impl Drop for PyBitGeneratorGuard { fn drop(&mut self) { let r = Python::with_gil(|py| -> PyResult<()> { self.lock.bind(py).getattr("release")?.call0()?; @@ -162,7 +163,7 @@ impl Drop for PyBitGeneratorLock { } #[cfg(feature = "rand")] -impl rand::RngCore for PyBitGeneratorLock { +impl rand::RngCore for PyBitGeneratorGuard { fn next_u32(&mut self) -> u32 { self.next_uint32() } @@ -207,7 +208,7 @@ mod tests { Ok(()) } - /// Test that dropping the lock works while holding the GIL + /// Test that releasing the lock works while holding the GIL #[test] fn unlock_with_held_gil() -> PyResult<()> { Python::with_gil(|py| { @@ -215,6 +216,11 @@ mod tests { let mut bitgen = generator.lock()?; let _ = bitgen.next_raw(); std::mem::drop(bitgen); + assert!(!generator + .getattr("lock")? + .getattr("locked")? + .call0()? + .extract()?); Ok(()) }) } From 0258e6d1187e05016032581e1226cac7cb208485 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 17:01:27 +0200 Subject: [PATCH 16/46] call_method0 --- src/random.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/random.rs b/src/random.rs index 168aa7983..99f1813be 100644 --- a/src/random.rs +++ b/src/random.rs @@ -9,7 +9,7 @@ //! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; //! //! let mut bitgen = Python::with_gil(|py| -> PyResult<_> { -//! let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?; +//! let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; //! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into::()?; //! bit_generator.lock() //! })?; @@ -83,17 +83,17 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { fn lock(&self) -> PyResult { let capsule = self.getattr("capsule")?.downcast_into::()?; let lock = self.getattr("lock")?; - if lock.getattr("locked")?.call0()?.extract()? { + if lock.call_method0("locked")?.extract()? { return Err(PyRuntimeError::new_err("BitGenerator is already locked")); } - lock.getattr("acquire")?.call0()?; + lock.call_method0("acquire")?; assert_eq!(capsule.name()?, Some(c"BitGenerator")); let ptr = capsule.pointer() as *mut npy_bitgen; let non_null = match NonNull::new(ptr) { Some(non_null) => non_null, None => { - lock.getattr("release")?.call0()?; + lock.call_method0("release")?; return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); } }; @@ -153,7 +153,7 @@ impl PyBitGeneratorGuard { impl Drop for PyBitGeneratorGuard { fn drop(&mut self) { let r = Python::with_gil(|py| -> PyResult<()> { - self.lock.bind(py).getattr("release")?.call0()?; + self.lock.bind(py).call_method0("release")?; Ok(()) }); if let Err(e) = r { @@ -180,7 +180,7 @@ mod tests { use super::*; fn get_bit_generator<'py>(py: Python<'py>) -> PyResult> { - let default_rng = py.import("numpy.random")?.getattr("default_rng")?.call0()?; + let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; let bit_generator = default_rng .getattr("bit_generator")? .downcast_into::()?; @@ -218,8 +218,7 @@ mod tests { std::mem::drop(bitgen); assert!(!generator .getattr("lock")? - .getattr("locked")? - .call0()? + .call_method0("locked")? .extract()?); Ok(()) }) From 876001bf25a7b2d5e39c20d9499468a9b9fdd289 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 17:08:11 +0200 Subject: [PATCH 17/46] reaname test --- src/random.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index 99f1813be..ea45a2b5b 100644 --- a/src/random.rs +++ b/src/random.rs @@ -187,8 +187,9 @@ mod tests { Ok(bit_generator) } + /// Test the primary use case: acquire the lock, release the GIL, then use the lock #[test] - fn bitgen() -> PyResult<()> { + fn use_outside_gil() -> PyResult<()> { let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; let _ = bitgen.next_raw(); std::mem::drop(bitgen); From 71ce8be2117256537bc7b428c2fa0b2b7f59fec1 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 17:33:01 +0200 Subject: [PATCH 18/46] manually drop and capsule --- src/random.rs | 85 ++++++++++++++++++++++++++++----------------------- 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/src/random.rs b/src/random.rs index ea45a2b5b..491552256 100644 --- a/src/random.rs +++ b/src/random.rs @@ -74,13 +74,13 @@ unsafe impl PyTypeInfo for PyBitGenerator { } /// Methods for [`PyBitGenerator`]. -pub trait PyBitGeneratorMethods { +pub trait PyBitGeneratorMethods<'py> { /// Acquire a lock on the BitGenerator to allow calling its methods in. - fn lock(&self) -> PyResult; + fn lock(&self) -> PyResult>; } -impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { - fn lock(&self) -> PyResult { +impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { + fn lock(&self) -> PyResult> { let capsule = self.getattr("capsule")?.downcast_into::()?; let lock = self.getattr("lock")?; if lock.call_method0("locked")?.extract()? { @@ -99,12 +99,13 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { }; Ok(PyBitGeneratorGuard { raw_bitgen: non_null, - lock: lock.unbind(), + capsule, + lock, }) } } -impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { +impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> { type Error = PyErr; fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result { value.lock() @@ -112,14 +113,30 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { } /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. -pub struct PyBitGeneratorGuard { +pub struct PyBitGeneratorGuard<'py> { raw_bitgen: NonNull, - lock: Py, + capsule: Bound<'py, PyCapsule>, + lock: Bound<'py, PyAny>, +} + +unsafe impl Send for PyBitGeneratorGuard<'_> {} + +impl Drop for PyBitGeneratorGuard<'_> { + fn drop(&mut self) { + // ignore errors. This includes when `try_drop` was called manually + let _ = self.lock.call_method0("release"); + } } // SAFETY: We hold the `BitGenerator.lock`, // so nothing apart from us is allowed to change its state. -impl PyBitGeneratorGuard { +impl PyBitGeneratorGuard<'_> { + /// Drop the lock, allowing access to. + pub fn try_drop(self) -> PyResult<()> { + self.lock.call_method0("release")?; + Ok(()) + } + /// Returns the next random unsigned 64 bit integer. pub fn next_uint64(&mut self) -> u64 { unsafe { @@ -150,20 +167,8 @@ impl PyBitGeneratorGuard { } } -impl Drop for PyBitGeneratorGuard { - fn drop(&mut self) { - let r = Python::with_gil(|py| -> PyResult<()> { - self.lock.bind(py).call_method0("release")?; - Ok(()) - }); - if let Err(e) = r { - eprintln!("Failed to release BitGenerator lock: {e}"); - } - } -} - #[cfg(feature = "rand")] -impl rand::RngCore for PyBitGeneratorGuard { +impl rand::RngCore for PyBitGeneratorGuard<'_> { fn next_u32(&mut self) -> u32 { self.next_uint32() } @@ -190,10 +195,14 @@ mod tests { /// Test the primary use case: acquire the lock, release the GIL, then use the lock #[test] fn use_outside_gil() -> PyResult<()> { - let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; - let _ = bitgen.next_raw(); - std::mem::drop(bitgen); - Ok(()) + Python::with_gil(|py| { + let mut bitgen = get_bit_generator(py)?.lock()?; + py.allow_threads(|| { + let _ = bitgen.next_raw(); + }); + assert!(bitgen.try_drop().is_ok()); + Ok(()) + }) } /// Test that the `rand::Rng` APIs work @@ -202,11 +211,15 @@ mod tests { fn rand() -> PyResult<()> { use rand::Rng as _; - let mut bitgen = Python::with_gil(|py| get_bit_generator(py)?.lock())?; - assert!(bitgen.random_ratio(1, 1)); - assert!(!bitgen.random_ratio(0, 1)); - std::mem::drop(bitgen); - Ok(()) + Python::with_gil(|py| { + let mut bitgen = get_bit_generator(py)?.lock()?; + py.allow_threads(|| { + assert!(bitgen.random_ratio(1, 1)); + assert!(!bitgen.random_ratio(0, 1)); + }); + assert!(bitgen.try_drop().is_ok()); + Ok(()) + }) } /// Test that releasing the lock works while holding the GIL @@ -216,11 +229,7 @@ mod tests { let generator = get_bit_generator(py)?; let mut bitgen = generator.lock()?; let _ = bitgen.next_raw(); - std::mem::drop(bitgen); - assert!(!generator - .getattr("lock")? - .call_method0("locked")? - .extract()?); + assert!(bitgen.try_drop().is_ok()); Ok(()) }) } @@ -229,9 +238,9 @@ mod tests { fn double_lock_fails() -> PyResult<()> { Python::with_gil(|py| { let generator = get_bit_generator(py)?; - let d1 = generator.lock()?; + let bitgen = generator.lock()?; assert!(generator.lock().is_err()); - std::mem::drop(d1); + assert!(bitgen.try_drop().is_ok()); Ok(()) }) } From 2de7072787a85c87125cfa2224e9c795e6bd7391 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 17:34:45 +0200 Subject: [PATCH 19/46] remove useless test --- src/random.rs | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/random.rs b/src/random.rs index 491552256..5ad691feb 100644 --- a/src/random.rs +++ b/src/random.rs @@ -222,18 +222,6 @@ mod tests { }) } - /// Test that releasing the lock works while holding the GIL - #[test] - fn unlock_with_held_gil() -> PyResult<()> { - Python::with_gil(|py| { - let generator = get_bit_generator(py)?; - let mut bitgen = generator.lock()?; - let _ = bitgen.next_raw(); - assert!(bitgen.try_drop().is_ok()); - Ok(()) - }) - } - #[test] fn double_lock_fails() -> PyResult<()> { Python::with_gil(|py| { From 016eb7ae88d0d0a6a3274121166d73836562d1db Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 17:56:42 +0200 Subject: [PATCH 20/46] doctests --- src/random.rs | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/random.rs b/src/random.rs index 5ad691feb..6db525923 100644 --- a/src/random.rs +++ b/src/random.rs @@ -4,26 +4,44 @@ //! you can generate random numbers without holding the GIL, //! by [acquiring][`PyBitGeneratorMethods::lock`] a lock [guard][`PyBitGeneratorGuard`] for the [`PyBitGenerator`]: //! -//! ```rust +//! ``` //! use pyo3::prelude::*; //! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; //! -//! let mut bitgen = Python::with_gil(|py| -> PyResult<_> { +//! fn default_bit_gen<'py>(py: Python<'py>) -> PyResult> { //! let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; -//! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into::()?; -//! bit_generator.lock() +//! let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?; +//! Ok(bit_generator) +//! } +//! +//! let random_number = Python::with_gil(|py| -> PyResult<_> { +//! let mut bitgen = default_bit_gen(py)?.lock()?; +//! Ok(bitgen.next_uint64()) //! })?; -//! let random_number = bitgen.next_u64(); +//! # Ok::<(), PyErr>(()) //! ``` //! //! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]: //! -//! ```rust +//! ``` +//! use pyo3::prelude::*; //! use rand::Rng as _; +//! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; +//! # // TODO: reuse function definition from above? +//! # fn default_bit_gen<'py>(py: Python<'py>) -> PyResult> { +//! # let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; +//! # let bit_generator = default_rng.getattr("bit_generator")?.downcast_into()?; +//! # Ok(bit_generator) +//! # } //! -//! if bitgen.random_ratio(1, 1_000_000) { -//! println!("a sure thing"); -//! } +//! Python::with_gil(|py| -> PyResult<_> { +//! let mut bitgen = default_bit_gen(py)?.lock()?; +//! if bitgen.random_ratio(1, 1_000_000) { +//! println!("a sure thing"); +//! } +//! Ok(()) +//! })?; +//! # Ok::<(), PyErr>(()) //! ``` //! //! [bg]: https://numpy.org/doc/stable//reference/random/bit_generators/generated/numpy.random.BitGenerator.html @@ -99,7 +117,7 @@ impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { }; Ok(PyBitGeneratorGuard { raw_bitgen: non_null, - capsule, + _capsule: capsule, lock, }) } @@ -115,7 +133,7 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> { /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. pub struct PyBitGeneratorGuard<'py> { raw_bitgen: NonNull, - capsule: Bound<'py, PyCapsule>, + _capsule: Bound<'py, PyCapsule>, lock: Bound<'py, PyAny>, } @@ -140,28 +158,28 @@ impl PyBitGeneratorGuard<'_> { /// Returns the next random unsigned 64 bit integer. pub fn next_uint64(&mut self) -> u64 { unsafe { - let bitgen = self.raw_bitgen.as_mut(); + let bitgen = *self.raw_bitgen.as_ptr(); (bitgen.next_uint64)(bitgen.state) } } /// Returns the next random unsigned 32 bit integer. pub fn next_uint32(&mut self) -> u32 { unsafe { - let bitgen = self.raw_bitgen.as_mut(); + let bitgen = *self.raw_bitgen.as_ptr(); (bitgen.next_uint32)(bitgen.state) } } /// Returns the next random double. pub fn next_double(&mut self) -> libc::c_double { unsafe { - let bitgen = self.raw_bitgen.as_mut(); + let bitgen = *self.raw_bitgen.as_ptr(); (bitgen.next_double)(bitgen.state) } } /// Returns the next raw value (can be used for testing). pub fn next_raw(&mut self) -> u64 { unsafe { - let bitgen = self.raw_bitgen.as_mut(); + let bitgen = *self.raw_bitgen.as_ptr(); (bitgen.next_raw)(bitgen.state) } } From 1f7f37fa2f16603fdb5303c4183975399aec7489 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 17:58:41 +0200 Subject: [PATCH 21/46] smaller --- src/random.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/random.rs b/src/random.rs index 6db525923..75a6e0214 100644 --- a/src/random.rs +++ b/src/random.rs @@ -24,9 +24,9 @@ //! With the [`rand`] crate installed, you can also use the [`rand::Rng`] APIs from the [`PyBitGeneratorGuard`]: //! //! ``` -//! use pyo3::prelude::*; +//! # use pyo3::prelude::*; //! use rand::Rng as _; -//! use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; +//! # use numpy::random::{PyBitGenerator, PyBitGeneratorMethods as _}; //! # // TODO: reuse function definition from above? //! # fn default_bit_gen<'py>(py: Python<'py>) -> PyResult> { //! # let default_rng = py.import("numpy.random")?.call_method0("default_rng")?; From 1d01c7a601ac1bdf3ec0ec47e0f9ad12c7920ad0 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 18:02:42 +0200 Subject: [PATCH 22/46] clarify where to release the GIL --- src/random.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index 75a6e0214..6691aa181 100644 --- a/src/random.rs +++ b/src/random.rs @@ -16,7 +16,8 @@ //! //! let random_number = Python::with_gil(|py| -> PyResult<_> { //! let mut bitgen = default_bit_gen(py)?.lock()?; -//! Ok(bitgen.next_uint64()) +//! // use bitgen without holding the GIL +//! Ok(py.allow_threads(|| bitgen.next_uint64())) //! })?; //! # Ok::<(), PyErr>(()) //! ``` From c90176ae45314fcac44312c94f7120cd5169e232 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 18:09:37 +0200 Subject: [PATCH 23/46] safety --- src/random.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/random.rs b/src/random.rs index 6691aa181..8df9a5d05 100644 --- a/src/random.rs +++ b/src/random.rs @@ -138,6 +138,8 @@ pub struct PyBitGeneratorGuard<'py> { lock: Bound<'py, PyAny>, } +// SAFETY: we can’t have public APIs that access the Python objects, +// only the `raw_bitgen` pointer. unsafe impl Send for PyBitGeneratorGuard<'_> {} impl Drop for PyBitGeneratorGuard<'_> { @@ -150,8 +152,10 @@ impl Drop for PyBitGeneratorGuard<'_> { // SAFETY: We hold the `BitGenerator.lock`, // so nothing apart from us is allowed to change its state. impl PyBitGeneratorGuard<'_> { - /// Drop the lock, allowing access to. - pub fn try_drop(self) -> PyResult<()> { + /// Drop the lock manually before `Drop::drop` tries to do it (used for testing). + /// SAFETY: Can’t be used inside of a + #[allow(dead_code)] + unsafe fn try_drop(self) -> PyResult<()> { self.lock.call_method0("release")?; Ok(()) } @@ -219,7 +223,7 @@ mod tests { py.allow_threads(|| { let _ = bitgen.next_raw(); }); - assert!(bitgen.try_drop().is_ok()); + assert!(unsafe { bitgen.try_drop() }.is_ok()); Ok(()) }) } @@ -236,7 +240,7 @@ mod tests { assert!(bitgen.random_ratio(1, 1)); assert!(!bitgen.random_ratio(0, 1)); }); - assert!(bitgen.try_drop().is_ok()); + assert!(unsafe { bitgen.try_drop() }.is_ok()); Ok(()) }) } @@ -247,7 +251,7 @@ mod tests { let generator = get_bit_generator(py)?; let bitgen = generator.lock()?; assert!(generator.lock().is_err()); - assert!(bitgen.try_drop().is_ok()); + assert!(unsafe { bitgen.try_drop() }.is_ok()); Ok(()) }) } From f49d3fa19a0b6747b80d93ec49ad9e44fb950031 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 18:10:57 +0200 Subject: [PATCH 24/46] oops --- src/random.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index 8df9a5d05..dfec42d7e 100644 --- a/src/random.rs +++ b/src/random.rs @@ -153,7 +153,7 @@ impl Drop for PyBitGeneratorGuard<'_> { // so nothing apart from us is allowed to change its state. impl PyBitGeneratorGuard<'_> { /// Drop the lock manually before `Drop::drop` tries to do it (used for testing). - /// SAFETY: Can’t be used inside of a + /// SAFETY: Can’t be used inside of a `Python::allow_threads` block. #[allow(dead_code)] unsafe fn try_drop(self) -> PyResult<()> { self.lock.call_method0("release")?; From a16846db852b24ed7846922146ea45e20a50ec4f Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 18:30:33 +0200 Subject: [PATCH 25/46] less unsafe --- src/random.rs | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/random.rs b/src/random.rs index dfec42d7e..1fd47a755 100644 --- a/src/random.rs +++ b/src/random.rs @@ -118,8 +118,9 @@ impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { }; Ok(PyBitGeneratorGuard { raw_bitgen: non_null, - _capsule: capsule, - lock, + _capsule: capsule.unbind(), + lock: lock.unbind(), + py: self.py(), }) } } @@ -134,8 +135,14 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> { /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. pub struct PyBitGeneratorGuard<'py> { raw_bitgen: NonNull, - _capsule: Bound<'py, PyCapsule>, - lock: Bound<'py, PyAny>, + /// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated. + _capsule: Py, + /// This lock makes sure no other threads try to use the BitGenerator while we do. + lock: Py, + /// This should be an unsafe field (https://github.com/rust-lang/rust/issues/132922) + /// + /// SAFETY: only use this in `Drop::drop` (when we are sure the GIL is held). + py: Python<'py>, } // SAFETY: we can’t have public APIs that access the Python objects, @@ -145,18 +152,17 @@ unsafe impl Send for PyBitGeneratorGuard<'_> {} impl Drop for PyBitGeneratorGuard<'_> { fn drop(&mut self) { // ignore errors. This includes when `try_drop` was called manually - let _ = self.lock.call_method0("release"); + let _ = self.lock.bind(self.py).call_method0("release"); } } // SAFETY: We hold the `BitGenerator.lock`, // so nothing apart from us is allowed to change its state. -impl PyBitGeneratorGuard<'_> { +impl<'py> PyBitGeneratorGuard<'py> { /// Drop the lock manually before `Drop::drop` tries to do it (used for testing). - /// SAFETY: Can’t be used inside of a `Python::allow_threads` block. #[allow(dead_code)] - unsafe fn try_drop(self) -> PyResult<()> { - self.lock.call_method0("release")?; + fn try_drop(self, py: Python<'py>) -> PyResult<()> { + self.lock.bind(py).call_method0("release")?; Ok(()) } @@ -223,7 +229,7 @@ mod tests { py.allow_threads(|| { let _ = bitgen.next_raw(); }); - assert!(unsafe { bitgen.try_drop() }.is_ok()); + assert!(bitgen.try_drop(py).is_ok()); Ok(()) }) } @@ -240,7 +246,7 @@ mod tests { assert!(bitgen.random_ratio(1, 1)); assert!(!bitgen.random_ratio(0, 1)); }); - assert!(unsafe { bitgen.try_drop() }.is_ok()); + assert!(bitgen.try_drop(py).is_ok()); Ok(()) }) } @@ -251,7 +257,7 @@ mod tests { let generator = get_bit_generator(py)?; let bitgen = generator.lock()?; assert!(generator.lock().is_err()); - assert!(unsafe { bitgen.try_drop() }.is_ok()); + assert!(bitgen.try_drop(py).is_ok()); Ok(()) }) } From 573d890208d067f71a5178829b3ea79f119268ab Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 20:03:59 +0200 Subject: [PATCH 26/46] add thread test --- src/random.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/random.rs b/src/random.rs index 1fd47a755..b8a855f0c 100644 --- a/src/random.rs +++ b/src/random.rs @@ -234,6 +234,38 @@ mod tests { }) } + /// More complex version of primary use case: use from multiple threads + #[cfg(feature = "rand")] + #[test] + fn use_parallel() -> PyResult<()> { + use crate::array::{PyArray2, PyArrayMethods as _}; + use ndarray::Dimension; + use rand::Rng; + use std::sync::{Arc, Mutex}; + + Python::with_gil(|py| -> PyResult<_> { + let mut arr = PyArray2::::zeros(py, (2, 300), false).readwrite(); + let bitgen = get_bit_generator(py)?.lock()?; + let bitgen = Arc::new(Mutex::new(bitgen)); + + let (_n_threads, chunk_size) = arr.dims().into_pattern(); + let slice = arr.as_slice_mut()?; + + Python::allow_threads(py, || { + std::thread::scope(|s| { + for chunk in slice.chunks_exact_mut(chunk_size) { + let bitgen = Arc::clone(&bitgen); + s.spawn(move || { + let mut bitgen = bitgen.lock().unwrap(); + chunk.fill_with(|| bitgen.random_range(10..200)); + }); + } + }) + }); + Ok(()) + }) + } + /// Test that the `rand::Rng` APIs work #[cfg(feature = "rand")] #[test] From 06bb693dd895f49e83f90d664dfdfdac27970816 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 8 Jun 2025 20:27:32 +0200 Subject: [PATCH 27/46] back to lock acquiring --- src/random.rs | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/random.rs b/src/random.rs index b8a855f0c..f9c02f79c 100644 --- a/src/random.rs +++ b/src/random.rs @@ -95,11 +95,11 @@ unsafe impl PyTypeInfo for PyBitGenerator { /// Methods for [`PyBitGenerator`]. pub trait PyBitGeneratorMethods<'py> { /// Acquire a lock on the BitGenerator to allow calling its methods in. - fn lock(&self) -> PyResult>; + fn lock(&self) -> PyResult; } impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { - fn lock(&self) -> PyResult> { + fn lock(&self) -> PyResult { let capsule = self.getattr("capsule")?.downcast_into::()?; let lock = self.getattr("lock")?; if lock.call_method0("locked")?.extract()? { @@ -120,12 +120,11 @@ impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { raw_bitgen: non_null, _capsule: capsule.unbind(), lock: lock.unbind(), - py: self.py(), }) } } -impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> { +impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { type Error = PyErr; fn try_from(value: &Bound<'py, PyBitGenerator>) -> Result { value.lock() @@ -133,35 +132,34 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard<'py> { } /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. -pub struct PyBitGeneratorGuard<'py> { +pub struct PyBitGeneratorGuard { raw_bitgen: NonNull, /// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated. _capsule: Py, /// This lock makes sure no other threads try to use the BitGenerator while we do. lock: Py, - /// This should be an unsafe field (https://github.com/rust-lang/rust/issues/132922) - /// - /// SAFETY: only use this in `Drop::drop` (when we are sure the GIL is held). - py: Python<'py>, } // SAFETY: we can’t have public APIs that access the Python objects, // only the `raw_bitgen` pointer. -unsafe impl Send for PyBitGeneratorGuard<'_> {} +unsafe impl Send for PyBitGeneratorGuard {} -impl Drop for PyBitGeneratorGuard<'_> { +impl Drop for PyBitGeneratorGuard { fn drop(&mut self) { - // ignore errors. This includes when `try_drop` was called manually - let _ = self.lock.bind(self.py).call_method0("release"); + // ignore errors. This includes when `try_release` was called manually. + let _ = Python::with_gil(|py| -> PyResult<_> { + self.lock.bind(py).call_method0("release")?; + Ok(()) + }); } } // SAFETY: We hold the `BitGenerator.lock`, // so nothing apart from us is allowed to change its state. -impl<'py> PyBitGeneratorGuard<'py> { - /// Drop the lock manually before `Drop::drop` tries to do it (used for testing). +impl<'py> PyBitGeneratorGuard { + /// Release the lock, allowing for checking for errors. #[allow(dead_code)] - fn try_drop(self, py: Python<'py>) -> PyResult<()> { + pub fn try_release(self, py: Python<'py>) -> PyResult<()> { self.lock.bind(py).call_method0("release")?; Ok(()) } @@ -197,7 +195,7 @@ impl<'py> PyBitGeneratorGuard<'py> { } #[cfg(feature = "rand")] -impl rand::RngCore for PyBitGeneratorGuard<'_> { +impl rand::RngCore for PyBitGeneratorGuard { fn next_u32(&mut self) -> u32 { self.next_uint32() } @@ -229,7 +227,7 @@ mod tests { py.allow_threads(|| { let _ = bitgen.next_raw(); }); - assert!(bitgen.try_drop(py).is_ok()); + assert!(bitgen.try_release(py).is_ok()); Ok(()) }) } @@ -278,7 +276,7 @@ mod tests { assert!(bitgen.random_ratio(1, 1)); assert!(!bitgen.random_ratio(0, 1)); }); - assert!(bitgen.try_drop(py).is_ok()); + assert!(bitgen.try_release(py).is_ok()); Ok(()) }) } @@ -289,7 +287,7 @@ mod tests { let generator = get_bit_generator(py)?; let bitgen = generator.lock()?; assert!(generator.lock().is_err()); - assert!(bitgen.try_drop(py).is_ok()); + assert!(bitgen.try_release(py).is_ok()); Ok(()) }) } From 663fa291d319e7a04579f67adf9f602dc328054d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 9 Jun 2025 13:44:52 +0200 Subject: [PATCH 28/46] docs --- src/random.rs | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/random.rs b/src/random.rs index f9c02f79c..0819a3ff8 100644 --- a/src/random.rs +++ b/src/random.rs @@ -102,6 +102,7 @@ impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { fn lock(&self) -> PyResult { let capsule = self.getattr("capsule")?.downcast_into::()?; let lock = self.getattr("lock")?; + // we’re holding the GIL, so there’s no race condition checking the lock and acquiring it later. if lock.call_method0("locked")?.extract()? { return Err(PyRuntimeError::new_err("BitGenerator is already locked")); } @@ -109,12 +110,9 @@ impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { assert_eq!(capsule.name()?, Some(c"BitGenerator")); let ptr = capsule.pointer() as *mut npy_bitgen; - let non_null = match NonNull::new(ptr) { - Some(non_null) => non_null, - None => { - lock.call_method0("release")?; - return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); - } + let Some(non_null) = NonNull::new(ptr) else { + lock.call_method0("release")?; + return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); }; Ok(PyBitGeneratorGuard { raw_bitgen: non_null, @@ -140,8 +138,8 @@ pub struct PyBitGeneratorGuard { lock: Py, } -// SAFETY: we can’t have public APIs that access the Python objects, -// only the `raw_bitgen` pointer. +// SAFETY: 1. We don’t hold the GIL, so we can’t access the Python objects. +// 2. We only access `raw_bitgen` from `&mut self`, which protects it from parallel access. unsafe impl Send for PyBitGeneratorGuard {} impl Drop for PyBitGeneratorGuard { @@ -154,11 +152,10 @@ impl Drop for PyBitGeneratorGuard { } } -// SAFETY: We hold the `BitGenerator.lock`, -// so nothing apart from us is allowed to change its state. +// SAFETY: 1. We hold the `BitGenerator.lock`, so nothing apart from us is allowed to change its state. +// 2. We hold the `BitGenerator.capsule`, so it can’t be deallocated. impl<'py> PyBitGeneratorGuard { /// Release the lock, allowing for checking for errors. - #[allow(dead_code)] pub fn try_release(self, py: Python<'py>) -> PyResult<()> { self.lock.bind(py).call_method0("release")?; Ok(()) @@ -249,7 +246,7 @@ mod tests { let (_n_threads, chunk_size) = arr.dims().into_pattern(); let slice = arr.as_slice_mut()?; - Python::allow_threads(py, || { + py.allow_threads(|| { std::thread::scope(|s| { for chunk in slice.chunks_exact_mut(chunk_size) { let bitgen = Arc::clone(&bitgen); @@ -260,6 +257,8 @@ mod tests { } }) }); + + std::mem::drop(bitgen); Ok(()) }) } From c6105c91a6ef33b4dc5d30e78972a1ff53435b33 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 10:11:41 +0200 Subject: [PATCH 29/46] no copy/clone --- src/npyffi/random.rs | 2 +- src/random.rs | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs index 6c401ac7a..66e1ec59d 100644 --- a/src/npyffi/random.rs +++ b/src/npyffi/random.rs @@ -1,7 +1,7 @@ use std::ffi::c_void; #[repr(C)] -#[derive(Debug, Clone, Copy)] // TODO: can it be Clone and/or Copy? +#[derive(Debug)] pub struct npy_bitgen { pub state: *mut c_void, pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil diff --git a/src/random.rs b/src/random.rs index 0819a3ff8..a8fd91608 100644 --- a/src/random.rs +++ b/src/random.rs @@ -162,30 +162,31 @@ impl<'py> PyBitGeneratorGuard { } /// Returns the next random unsigned 64 bit integer. - pub fn next_uint64(&mut self) -> u64 { + pub fn next_u64(&mut self) -> u64 { unsafe { - let bitgen = *self.raw_bitgen.as_ptr(); + // TODO: maybe use pointer offsets instead of `mut` + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_uint64)(bitgen.state) } } /// Returns the next random unsigned 32 bit integer. - pub fn next_uint32(&mut self) -> u32 { + pub fn next_u32(&mut self) -> u32 { unsafe { - let bitgen = *self.raw_bitgen.as_ptr(); + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_uint32)(bitgen.state) } } /// Returns the next random double. pub fn next_double(&mut self) -> libc::c_double { unsafe { - let bitgen = *self.raw_bitgen.as_ptr(); + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_double)(bitgen.state) } } /// Returns the next raw value (can be used for testing). pub fn next_raw(&mut self) -> u64 { unsafe { - let bitgen = *self.raw_bitgen.as_ptr(); + let bitgen = self.raw_bitgen.as_mut(); (bitgen.next_raw)(bitgen.state) } } @@ -194,10 +195,10 @@ impl<'py> PyBitGeneratorGuard { #[cfg(feature = "rand")] impl rand::RngCore for PyBitGeneratorGuard { fn next_u32(&mut self) -> u32 { - self.next_uint32() + self.next_u32() } fn next_u64(&mut self) -> u64 { - self.next_uint64() + self.next_u64() } fn fill_bytes(&mut self, dst: &mut [u8]) { rand::rand_core::impls::fill_bytes_via_next(self, dst) From 3a0aa925d60d75c028be052f8cdced5ba1b473cb Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 10:12:12 +0200 Subject: [PATCH 30/46] rename to release --- src/random.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/random.rs b/src/random.rs index a8fd91608..0d9ccf025 100644 --- a/src/random.rs +++ b/src/random.rs @@ -156,7 +156,7 @@ impl Drop for PyBitGeneratorGuard { // 2. We hold the `BitGenerator.capsule`, so it can’t be deallocated. impl<'py> PyBitGeneratorGuard { /// Release the lock, allowing for checking for errors. - pub fn try_release(self, py: Python<'py>) -> PyResult<()> { + pub fn release(self, py: Python<'py>) -> PyResult<()> { self.lock.bind(py).call_method0("release")?; Ok(()) } @@ -225,7 +225,7 @@ mod tests { py.allow_threads(|| { let _ = bitgen.next_raw(); }); - assert!(bitgen.try_release(py).is_ok()); + assert!(bitgen.release(py).is_ok()); Ok(()) }) } @@ -276,7 +276,7 @@ mod tests { assert!(bitgen.random_ratio(1, 1)); assert!(!bitgen.random_ratio(0, 1)); }); - assert!(bitgen.try_release(py).is_ok()); + assert!(bitgen.release(py).is_ok()); Ok(()) }) } @@ -287,7 +287,7 @@ mod tests { let generator = get_bit_generator(py)?; let bitgen = generator.lock()?; assert!(generator.lock().is_err()); - assert!(bitgen.try_release(py).is_ok()); + assert!(bitgen.release(py).is_ok()); Ok(()) }) } From a92861a6b8c5580952fbb37c448884849a702eac Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 10:13:24 +0200 Subject: [PATCH 31/46] remove lifetime --- src/random.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/random.rs b/src/random.rs index 0d9ccf025..b13b74140 100644 --- a/src/random.rs +++ b/src/random.rs @@ -93,12 +93,12 @@ unsafe impl PyTypeInfo for PyBitGenerator { } /// Methods for [`PyBitGenerator`]. -pub trait PyBitGeneratorMethods<'py> { +pub trait PyBitGeneratorMethods { /// Acquire a lock on the BitGenerator to allow calling its methods in. fn lock(&self) -> PyResult; } -impl<'py> PyBitGeneratorMethods<'py> for Bound<'py, PyBitGenerator> { +impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { fn lock(&self) -> PyResult { let capsule = self.getattr("capsule")?.downcast_into::()?; let lock = self.getattr("lock")?; From 6dbb6dc92b5de5bc9958186014f02c6a7e710de5 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 10:13:51 +0200 Subject: [PATCH 32/46] static --- src/random.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index b13b74140..220b40357 100644 --- a/src/random.rs +++ b/src/random.rs @@ -76,7 +76,7 @@ unsafe impl PyTypeInfo for PyBitGenerator { const MODULE: Option<&'static str> = Some("numpy.random"); fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject { - const CLS: GILOnceCell> = GILOnceCell::new(); + static CLS: GILOnceCell> = GILOnceCell::new(); let cls = CLS .get_or_try_init::<_, PyErr>(py, || { Ok(py From b102d205fcddd0f4ce750eb3a35f3f215d7da940 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 10:42:51 +0200 Subject: [PATCH 33/46] no mut ref conversion --- src/random.rs | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/random.rs b/src/random.rs index 220b40357..e24eea5ac 100644 --- a/src/random.rs +++ b/src/random.rs @@ -164,30 +164,29 @@ impl<'py> PyBitGeneratorGuard { /// Returns the next random unsigned 64 bit integer. pub fn next_u64(&mut self) -> u64 { unsafe { - // TODO: maybe use pointer offsets instead of `mut` - let bitgen = self.raw_bitgen.as_mut(); - (bitgen.next_uint64)(bitgen.state) + let bitgen = self.raw_bitgen.as_ptr(); + ((*bitgen).next_uint64)((*bitgen).state) } } /// Returns the next random unsigned 32 bit integer. pub fn next_u32(&mut self) -> u32 { unsafe { - let bitgen = self.raw_bitgen.as_mut(); - (bitgen.next_uint32)(bitgen.state) + let bitgen = self.raw_bitgen.as_ptr(); + ((*bitgen).next_uint32)((*bitgen).state) } } /// Returns the next random double. pub fn next_double(&mut self) -> libc::c_double { unsafe { - let bitgen = self.raw_bitgen.as_mut(); - (bitgen.next_double)(bitgen.state) + let bitgen = self.raw_bitgen.as_ptr(); + ((*bitgen).next_double)((*bitgen).state) } } /// Returns the next raw value (can be used for testing). pub fn next_raw(&mut self) -> u64 { unsafe { - let bitgen = self.raw_bitgen.as_mut(); - (bitgen.next_raw)(bitgen.state) + let bitgen = self.raw_bitgen.as_ptr(); + ((*bitgen).next_raw)((*bitgen).state) } } } From e5e440eafd43d7368f7641fcd06ae400225b260b Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 10:47:36 +0200 Subject: [PATCH 34/46] disambiguate --- src/random.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/random.rs b/src/random.rs index e24eea5ac..1034270bb 100644 --- a/src/random.rs +++ b/src/random.rs @@ -194,10 +194,10 @@ impl<'py> PyBitGeneratorGuard { #[cfg(feature = "rand")] impl rand::RngCore for PyBitGeneratorGuard { fn next_u32(&mut self) -> u32 { - self.next_u32() + PyBitGeneratorGuard::next_u32(self) } fn next_u64(&mut self) -> u64 { - self.next_u64() + PyBitGeneratorGuard::next_u64(self) } fn fill_bytes(&mut self, dst: &mut [u8]) { rand::rand_core::impls::fill_bytes_via_next(self, dst) From e73e3a27bbf8e406586d61db4b4b8695dbe581ee Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 14:59:35 +0200 Subject: [PATCH 35/46] rand_core only --- Cargo.toml | 3 ++- src/random.rs | 8 +++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 636d17288..49f3716f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ num-integer = "0.1" num-traits = "0.2" ndarray = ">= 0.15, < 0.17" pyo3 = { version = "0.25.0", default-features = false, features = ["macros"] } -rand = { version = "0.9.1", default-features = false, optional = true } +rand_core = { version = "0.9.3", default-features = false, optional = true } rustc-hash = "2.0" [dev-dependencies] @@ -33,6 +33,7 @@ pyo3 = { version = "0.25", default-features = false, features = [ nalgebra = { version = ">=0.30, <0.34", default-features = false, features = [ "std", ] } +rand = { version = "0.9.1", default-features = false } [build-dependencies] pyo3-build-config = { version = "0.25", features = ["resolve-config"] } diff --git a/src/random.rs b/src/random.rs index 1034270bb..5c9bbea0e 100644 --- a/src/random.rs +++ b/src/random.rs @@ -191,8 +191,8 @@ impl<'py> PyBitGeneratorGuard { } } -#[cfg(feature = "rand")] -impl rand::RngCore for PyBitGeneratorGuard { +#[cfg(feature = "rand_core")] +impl rand_core::RngCore for PyBitGeneratorGuard { fn next_u32(&mut self) -> u32 { PyBitGeneratorGuard::next_u32(self) } @@ -200,7 +200,7 @@ impl rand::RngCore for PyBitGeneratorGuard { PyBitGeneratorGuard::next_u64(self) } fn fill_bytes(&mut self, dst: &mut [u8]) { - rand::rand_core::impls::fill_bytes_via_next(self, dst) + rand_core::impls::fill_bytes_via_next(self, dst) } } @@ -230,7 +230,6 @@ mod tests { } /// More complex version of primary use case: use from multiple threads - #[cfg(feature = "rand")] #[test] fn use_parallel() -> PyResult<()> { use crate::array::{PyArray2, PyArrayMethods as _}; @@ -264,7 +263,6 @@ mod tests { } /// Test that the `rand::Rng` APIs work - #[cfg(feature = "rand")] #[test] fn rand() -> PyResult<()> { use rand::Rng as _; From c6493dfc62f3f5071e15454f8328dfcd79b6a4bb Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 15:00:47 +0200 Subject: [PATCH 36/46] rename bitgen type --- src/npyffi/random.rs | 2 +- src/random.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/npyffi/random.rs b/src/npyffi/random.rs index 66e1ec59d..e618768aa 100644 --- a/src/npyffi/random.rs +++ b/src/npyffi/random.rs @@ -2,7 +2,7 @@ use std::ffi::c_void; #[repr(C)] #[derive(Debug)] -pub struct npy_bitgen { +pub struct bitgen_t { pub state: *mut c_void, pub next_uint64: unsafe extern "C" fn(*mut c_void) -> super::npy_uint64, //nogil pub next_uint32: unsafe extern "C" fn(*mut c_void) -> super::npy_uint32, //nogil diff --git a/src/random.rs b/src/random.rs index 5c9bbea0e..616f64e04 100644 --- a/src/random.rs +++ b/src/random.rs @@ -59,7 +59,7 @@ use pyo3::{ PyTypeInfo, }; -use crate::npyffi::npy_bitgen; +use crate::npyffi::bitgen_t; /// Wrapper for [`np.random.BitGenerator`][bg]. /// @@ -109,7 +109,7 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { lock.call_method0("acquire")?; assert_eq!(capsule.name()?, Some(c"BitGenerator")); - let ptr = capsule.pointer() as *mut npy_bitgen; + let ptr = capsule.pointer() as *mut bitgen_t; let Some(non_null) = NonNull::new(ptr) else { lock.call_method0("release")?; return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); @@ -131,7 +131,7 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. pub struct PyBitGeneratorGuard { - raw_bitgen: NonNull, + raw_bitgen: NonNull, /// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated. _capsule: Py, /// This lock makes sure no other threads try to use the BitGenerator while we do. From 2327f360130a280a9cf5d883a81e3fa952b06558 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 15:02:59 +0200 Subject: [PATCH 37/46] c_str macro --- src/random.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index 616f64e04..0dcfee96c 100644 --- a/src/random.rs +++ b/src/random.rs @@ -108,7 +108,7 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { } lock.call_method0("acquire")?; - assert_eq!(capsule.name()?, Some(c"BitGenerator")); + assert_eq!(capsule.name()?, Some(ffi::c_str!("BitGenerator"))); let ptr = capsule.pointer() as *mut bitgen_t; let Some(non_null) = NonNull::new(ptr) else { lock.call_method0("release")?; From e5c64588cca05378a3d718e4e2fa79b4da750b22 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 15:08:26 +0200 Subject: [PATCH 38/46] intern strings --- src/random.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/random.rs b/src/random.rs index 0dcfee96c..eb966a7e5 100644 --- a/src/random.rs +++ b/src/random.rs @@ -52,7 +52,7 @@ use std::ptr::NonNull; use pyo3::{ exceptions::PyRuntimeError, - ffi, + ffi, intern, prelude::*, sync::GILOnceCell, types::{DerefToPyAny, PyCapsule, PyType}, @@ -100,18 +100,21 @@ pub trait PyBitGeneratorMethods { impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { fn lock(&self) -> PyResult { - let capsule = self.getattr("capsule")?.downcast_into::()?; - let lock = self.getattr("lock")?; + let py = self.py(); + let capsule = self + .getattr(intern!(py, "capsule"))? + .downcast_into::()?; + let lock = self.getattr(intern!(py, "lock"))?; // we’re holding the GIL, so there’s no race condition checking the lock and acquiring it later. - if lock.call_method0("locked")?.extract()? { + if lock.call_method0(intern!(py, "locked"))?.extract()? { return Err(PyRuntimeError::new_err("BitGenerator is already locked")); } - lock.call_method0("acquire")?; + lock.call_method0(intern!(py, "acquire"))?; assert_eq!(capsule.name()?, Some(ffi::c_str!("BitGenerator"))); let ptr = capsule.pointer() as *mut bitgen_t; let Some(non_null) = NonNull::new(ptr) else { - lock.call_method0("release")?; + lock.call_method0(intern!(py, "release"))?; return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); }; Ok(PyBitGeneratorGuard { @@ -146,7 +149,7 @@ impl Drop for PyBitGeneratorGuard { fn drop(&mut self) { // ignore errors. This includes when `try_release` was called manually. let _ = Python::with_gil(|py| -> PyResult<_> { - self.lock.bind(py).call_method0("release")?; + self.lock.bind(py).call_method0(intern!(py, "release"))?; Ok(()) }); } @@ -157,7 +160,7 @@ impl Drop for PyBitGeneratorGuard { impl<'py> PyBitGeneratorGuard { /// Release the lock, allowing for checking for errors. pub fn release(self, py: Python<'py>) -> PyResult<()> { - self.lock.bind(py).call_method0("release")?; + self.lock.bind(py).call_method0(intern!(py, "release"))?; Ok(()) } From e8cd5e8b5fb7fa7d056fb05e04341b64823d28e4 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 15:15:53 +0200 Subject: [PATCH 39/46] docs --- src/random.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index eb966a7e5..f681fe795 100644 --- a/src/random.rs +++ b/src/random.rs @@ -133,6 +133,9 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { } /// [`PyBitGenerator`] lock allowing to access its methods without holding the GIL. +/// +/// Since [dropping](`Drop::drop`) this acquires the GIL, +/// prefer to call [`release`][`PyBitGeneratorGuard::release`] manually to release the lock. pub struct PyBitGeneratorGuard { raw_bitgen: NonNull, /// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated. @@ -147,7 +150,7 @@ unsafe impl Send for PyBitGeneratorGuard {} impl Drop for PyBitGeneratorGuard { fn drop(&mut self) { - // ignore errors. This includes when `try_release` was called manually. + // ignore errors. This includes when `release` was called manually. let _ = Python::with_gil(|py| -> PyResult<_> { self.lock.bind(py).call_method0(intern!(py, "release"))?; Ok(()) @@ -219,6 +222,7 @@ mod tests { Ok(bit_generator) } + /* /// Test the primary use case: acquire the lock, release the GIL, then use the lock #[test] fn use_outside_gil() -> PyResult<()> { @@ -231,6 +235,7 @@ mod tests { Ok(()) }) } + */ /// More complex version of primary use case: use from multiple threads #[test] @@ -265,6 +270,7 @@ mod tests { }) } + /* /// Test that the `rand::Rng` APIs work #[test] fn rand() -> PyResult<()> { @@ -291,4 +297,5 @@ mod tests { Ok(()) }) } + */ } From 0868405e4f062d7d4372219ced0ef83d78760002 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 15:17:38 +0200 Subject: [PATCH 40/46] more doc --- src/random.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index f681fe795..10854545a 100644 --- a/src/random.rs +++ b/src/random.rs @@ -17,7 +17,10 @@ //! let random_number = Python::with_gil(|py| -> PyResult<_> { //! let mut bitgen = default_bit_gen(py)?.lock()?; //! // use bitgen without holding the GIL -//! Ok(py.allow_threads(|| bitgen.next_uint64())) +//! let r = py.allow_threads(|| bitgen.next_uint64())?; +//! // release the lock manually while holding the GIL again +//! bitgen.release(py)?; +//! Ok(r) //! })?; //! # Ok::<(), PyErr>(()) //! ``` @@ -40,6 +43,7 @@ //! if bitgen.random_ratio(1, 1_000_000) { //! println!("a sure thing"); //! } +//! bitgen.release(py)?; //! Ok(()) //! })?; //! # Ok::<(), PyErr>(()) From 8667203624ffcd1bbaa47cc7033aeba8094fe840 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 15:45:39 +0200 Subject: [PATCH 41/46] clean up tests --- src/random.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/random.rs b/src/random.rs index 10854545a..21e4986df 100644 --- a/src/random.rs +++ b/src/random.rs @@ -17,7 +17,7 @@ //! let random_number = Python::with_gil(|py| -> PyResult<_> { //! let mut bitgen = default_bit_gen(py)?.lock()?; //! // use bitgen without holding the GIL -//! let r = py.allow_threads(|| bitgen.next_uint64())?; +//! let r = py.allow_threads(|| bitgen.next_u64()); //! // release the lock manually while holding the GIL again //! bitgen.release(py)?; //! Ok(r) @@ -175,6 +175,7 @@ impl<'py> PyBitGeneratorGuard { pub fn next_u64(&mut self) -> u64 { unsafe { let bitgen = self.raw_bitgen.as_ptr(); + debug_assert_ne!((*bitgen).state, std::ptr::null_mut()); ((*bitgen).next_uint64)((*bitgen).state) } } @@ -182,6 +183,7 @@ impl<'py> PyBitGeneratorGuard { pub fn next_u32(&mut self) -> u32 { unsafe { let bitgen = self.raw_bitgen.as_ptr(); + debug_assert_ne!((*bitgen).state, std::ptr::null_mut()); ((*bitgen).next_uint32)((*bitgen).state) } } @@ -189,6 +191,7 @@ impl<'py> PyBitGeneratorGuard { pub fn next_double(&mut self) -> libc::c_double { unsafe { let bitgen = self.raw_bitgen.as_ptr(); + debug_assert_ne!((*bitgen).state, std::ptr::null_mut()); ((*bitgen).next_double)((*bitgen).state) } } @@ -196,6 +199,7 @@ impl<'py> PyBitGeneratorGuard { pub fn next_raw(&mut self) -> u64 { unsafe { let bitgen = self.raw_bitgen.as_ptr(); + debug_assert_ne!((*bitgen).state, std::ptr::null_mut()); ((*bitgen).next_raw)((*bitgen).state) } } @@ -226,7 +230,6 @@ mod tests { Ok(bit_generator) } - /* /// Test the primary use case: acquire the lock, release the GIL, then use the lock #[test] fn use_outside_gil() -> PyResult<()> { @@ -239,9 +242,9 @@ mod tests { Ok(()) }) } - */ /// More complex version of primary use case: use from multiple threads + #[cfg(feature = "rand_core")] #[test] fn use_parallel() -> PyResult<()> { use crate::array::{PyArray2, PyArrayMethods as _}; @@ -274,8 +277,8 @@ mod tests { }) } - /* /// Test that the `rand::Rng` APIs work + #[cfg(feature = "rand_core")] #[test] fn rand() -> PyResult<()> { use rand::Rng as _; @@ -301,5 +304,4 @@ mod tests { Ok(()) }) } - */ } From 1fd7bb57d69d74050cccf191978f98eb38b77abb Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 17:43:05 +0200 Subject: [PATCH 42/46] no let-else --- src/random.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/random.rs b/src/random.rs index 21e4986df..8b46398f1 100644 --- a/src/random.rs +++ b/src/random.rs @@ -117,9 +117,12 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { assert_eq!(capsule.name()?, Some(ffi::c_str!("BitGenerator"))); let ptr = capsule.pointer() as *mut bitgen_t; - let Some(non_null) = NonNull::new(ptr) else { - lock.call_method0(intern!(py, "release"))?; - return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); + let non_null = match NonNull::new(ptr) { + Some(non_null) => non_null, + None => { + lock.call_method0(intern!(py, "release"))?; + return Err(PyRuntimeError::new_err("Invalid BitGenerator capsule")); + } }; Ok(PyBitGeneratorGuard { raw_bitgen: non_null, From 3913171806bc58f80ddb5e1361f25fd2dc883801 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 18:55:29 +0200 Subject: [PATCH 43/46] use GILOnceCell::import --- src/random.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/random.rs b/src/random.rs index 8b46398f1..5bee3fb15 100644 --- a/src/random.rs +++ b/src/random.rs @@ -82,16 +82,8 @@ unsafe impl PyTypeInfo for PyBitGenerator { fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject { static CLS: GILOnceCell> = GILOnceCell::new(); let cls = CLS - .get_or_try_init::<_, PyErr>(py, || { - Ok(py - .import("numpy.random")? - .getattr("BitGenerator")? - .downcast_into::()? - .unbind()) - }) - .expect("Failed to get BitGenerator type object") - .clone_ref(py) - .into_bound(py); + .import(py, "numpy.random", "BitGenerator") + .expect("Failed to get BitGenerator type object"); cls.as_type_ptr() } } From 7bc0be8e659f2112b9a9c6474770a5912a02bae6 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 19:04:27 +0200 Subject: [PATCH 44/46] add `released` attr --- src/random.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/random.rs b/src/random.rs index 5bee3fb15..7a0832eb2 100644 --- a/src/random.rs +++ b/src/random.rs @@ -118,6 +118,7 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { }; Ok(PyBitGeneratorGuard { raw_bitgen: non_null, + released: false, _capsule: capsule.unbind(), lock: lock.unbind(), }) @@ -137,6 +138,8 @@ impl<'py> TryFrom<&Bound<'py, PyBitGenerator>> for PyBitGeneratorGuard { /// prefer to call [`release`][`PyBitGeneratorGuard::release`] manually to release the lock. pub struct PyBitGeneratorGuard { raw_bitgen: NonNull, + /// Whether this guard has been manually released. + released: bool, /// This field makes sure the `raw_bitgen` inside the capsule doesn’t get deallocated. _capsule: Py, /// This lock makes sure no other threads try to use the BitGenerator while we do. @@ -149,7 +152,10 @@ unsafe impl Send for PyBitGeneratorGuard {} impl Drop for PyBitGeneratorGuard { fn drop(&mut self) { - // ignore errors. This includes when `release` was called manually. + if self.released { + return; + } + // ignore errors because `drop` can’t fail let _ = Python::with_gil(|py| -> PyResult<_> { self.lock.bind(py).call_method0(intern!(py, "release"))?; Ok(()) @@ -161,7 +167,8 @@ impl Drop for PyBitGeneratorGuard { // 2. We hold the `BitGenerator.capsule`, so it can’t be deallocated. impl<'py> PyBitGeneratorGuard { /// Release the lock, allowing for checking for errors. - pub fn release(self, py: Python<'py>) -> PyResult<()> { + pub fn release(mut self, py: Python<'py>) -> PyResult<()> { + self.released = true; // only ever read by drop at the end of a scope (like this one). self.lock.bind(py).call_method0(intern!(py, "release"))?; Ok(()) } @@ -267,7 +274,11 @@ mod tests { }) }); - std::mem::drop(bitgen); + Arc::into_inner(bitgen) + .unwrap() + .into_inner() + .unwrap() + .release(py)?; Ok(()) }) } From 8caf05458e6469fce38bab13a6a846545ecbf0bd Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 10 Jun 2025 19:06:05 +0200 Subject: [PATCH 45/46] f64 --- src/random.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/random.rs b/src/random.rs index 7a0832eb2..2837b3afe 100644 --- a/src/random.rs +++ b/src/random.rs @@ -190,7 +190,7 @@ impl<'py> PyBitGeneratorGuard { } } /// Returns the next random double. - pub fn next_double(&mut self) -> libc::c_double { + pub fn next_double(&mut self) -> f64 { unsafe { let bitgen = self.raw_bitgen.as_ptr(); debug_assert_ne!((*bitgen).state, std::ptr::null_mut()); From d8b62ac75739d31172318b9759a0d436e7e0041d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Sun, 15 Jun 2025 19:00:31 +0200 Subject: [PATCH 46/46] correct locking --- src/random.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/random.rs b/src/random.rs index 2837b3afe..db0012a24 100644 --- a/src/random.rs +++ b/src/random.rs @@ -100,14 +100,18 @@ impl<'py> PyBitGeneratorMethods for Bound<'py, PyBitGenerator> { let capsule = self .getattr(intern!(py, "capsule"))? .downcast_into::()?; + assert_eq!(capsule.name()?, Some(ffi::c_str!("BitGenerator"))); let lock = self.getattr(intern!(py, "lock"))?; - // we’re holding the GIL, so there’s no race condition checking the lock and acquiring it later. - if lock.call_method0(intern!(py, "locked"))?.extract()? { - return Err(PyRuntimeError::new_err("BitGenerator is already locked")); + // Acquire the lock in non-blocking mode or return an error + if !lock + .call_method(intern!(py, "acquire"), (false,), None)? + .extract()? + { + return Err(PyRuntimeError::new_err( + "Failed to acquire BitGenerator lock", + )); } - lock.call_method0(intern!(py, "acquire"))?; - - assert_eq!(capsule.name()?, Some(ffi::c_str!("BitGenerator"))); + // Return the guard or release the lock if the capsule is invalid let ptr = capsule.pointer() as *mut bitgen_t; let non_null = match NonNull::new(ptr) { Some(non_null) => non_null,