Skip to content

Commit 7ead166

Browse files
committed
array: safer implementation of try_create_array
1 parent 9bc5089 commit 7ead166

File tree

4 files changed

+83
-95
lines changed

4 files changed

+83
-95
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1313
- Support PyPy 3.7. [#1538](https://github.com/PyO3/pyo3/pull/1538)
1414

1515
### Added
16-
- Add conversion for `[T; N]` for all `N` on Rust 1.51 and up. [#1128](https://github.com/PyO3/pyo3/pull/1128)
16+
- Add conversions for `[T; N]` for all `N` on Rust 1.51 and up. [#1128](https://github.com/PyO3/pyo3/pull/1128)
1717
- Add conversions between `OsStr`/`OsString`/`Path`/`PathBuf` and Python strings. [#1379](https://github.com/PyO3/pyo3/pull/1379)
1818
- Add `#[pyo3(from_py_with = "...")]` attribute for function arguments and struct fields to override the default from-Python conversion. [#1411](https://github.com/PyO3/pyo3/pull/1411)
1919
- Add FFI definition `PyCFunction_CheckExact` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425)

src/conversions/array.rs

Lines changed: 82 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
1-
use crate::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject};
1+
use crate::{
2+
exceptions, FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, PyTryFrom, Python,
3+
ToPyObject,
4+
};
25

36
#[cfg(not(min_const_generics))]
47
macro_rules! array_impls {
58
($($N:expr),+) => {
69
$(
10+
impl<T> IntoPy<PyObject> for [T; $N]
11+
where
12+
T: ToPyObject
13+
{
14+
fn into_py(self, py: Python) -> PyObject {
15+
self.as_ref().to_object(py)
16+
}
17+
}
18+
719
impl<'a, T> FromPyObject<'a> for [T; $N]
820
where
921
T: Copy + Default + FromPyObject<'a>,
@@ -55,6 +67,16 @@ array_impls!(
5567
26, 27, 28, 29, 30, 31, 32
5668
);
5769

70+
#[cfg(min_const_generics)]
71+
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
72+
where
73+
T: ToPyObject,
74+
{
75+
fn into_py(self, py: Python) -> PyObject {
76+
self.as_ref().to_object(py)
77+
}
78+
}
79+
5880
#[cfg(min_const_generics)]
5981
impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
6082
where
@@ -71,60 +93,27 @@ where
7193
}
7294
}
7395

74-
#[cfg(not(min_const_generics))]
75-
macro_rules! array_impls {
76-
($($N:expr),+) => {
77-
$(
78-
impl<T> IntoPy<PyObject> for [T; $N]
79-
where
80-
T: ToPyObject
81-
{
82-
fn into_py(self, py: Python) -> PyObject {
83-
self.as_ref().to_object(py)
84-
}
85-
}
86-
)+
87-
}
88-
}
89-
90-
#[cfg(not(min_const_generics))]
91-
array_impls!(
92-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
93-
26, 27, 28, 29, 30, 31, 32
94-
);
95-
96-
#[cfg(min_const_generics)]
97-
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
98-
where
99-
T: ToPyObject,
100-
{
101-
fn into_py(self, py: Python) -> PyObject {
102-
self.as_ref().to_object(py)
103-
}
104-
}
105-
10696
#[cfg(all(min_const_generics, feature = "nightly"))]
10797
impl<'source, T, const N: usize> FromPyObject<'source> for [T; N]
10898
where
109-
for<'a> T: FromPyObject<'a> + crate::buffer::Element,
99+
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
110100
{
111101
fn extract(obj: &'source PyAny) -> PyResult<Self> {
112-
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
102+
use crate::{AsPyPointer, PyNativeType};
103+
let mut array = [T::default(); N];
113104
// first try buffer protocol
114105
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
115106
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
116107
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
117108
buf.release(obj.py());
118-
// SAFETY: The array should be fully filled by `copy_to_slice`
119-
return Ok(unsafe { array.assume_init() });
109+
return Ok(array);
120110
}
121111
buf.release(obj.py());
122112
}
123113
}
124114
// fall back to sequence protocol
125115
_extract_sequence_into_slice(obj, &mut array)?;
126-
// SAFETY: The array should be fully filled by `_extract_sequence_into_slice`
127-
Ok(unsafe { array.assume_init() })
116+
Ok(array)
128117
}
129118
}
130119

@@ -135,102 +124,110 @@ where
135124
{
136125
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
137126
let expected_len = seq.len()? as usize;
138-
let mut counter = 0;
139-
try_create_array(&mut counter, |idx| {
127+
array_try_from_fn(|idx| {
140128
seq.get_item(idx as isize)
141-
.map_err(|_| crate::utils::invalid_sequence_length(expected_len, idx + 1))?
129+
.map_err(|_| invalid_sequence_length(expected_len, idx + 1))?
142130
.extract::<T>()
143131
})
144132
}
145133

146-
fn _extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
147-
where
148-
T: FromPyObject<'s>,
149-
{
150-
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
151-
let expected_len = seq.len()? as usize;
152-
if expected_len != slice.len() {
153-
return Err(crate::utils::invalid_sequence_length(
154-
expected_len,
155-
slice.len(),
156-
));
157-
}
158-
for (value, item) in slice.iter_mut().zip(seq.iter()?) {
159-
*value = item?.extract::<T>()?;
160-
}
161-
Ok(())
162-
}
163-
134+
// TODO use std::array::try_from_fn, if that stabilises:
135+
// (https://github.com/rust-lang/rust/pull/75644)
164136
#[cfg(min_const_generics)]
165-
fn try_create_array<E, F, T, const N: usize>(counter: &mut usize, mut cb: F) -> Result<[T; N], E>
137+
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
166138
where
167139
F: FnMut(usize) -> Result<T, E>,
168140
{
169141
// Helper to safely create arrays since the standard library doesn't
170142
// provide one yet. Shouldn't be necessary in the future.
171-
struct ArrayGuard<'a, T, const N: usize> {
143+
struct ArrayGuard<T, const N: usize> {
172144
dst: *mut T,
173-
initialized: &'a mut usize,
145+
initialized: usize,
174146
}
175147

176-
impl<T, const N: usize> Drop for ArrayGuard<'_, T, N> {
148+
impl<T, const N: usize> Drop for ArrayGuard<T, N> {
177149
fn drop(&mut self) {
178-
debug_assert!(*self.initialized <= N);
179-
let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, *self.initialized);
150+
debug_assert!(self.initialized <= N);
151+
let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
180152
unsafe {
181153
core::ptr::drop_in_place(initialized_part);
182154
}
183155
}
184156
}
185157

158+
// [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
159+
// APIs which would make this easier.
186160
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
187-
let guard: ArrayGuard<T, N> = ArrayGuard {
161+
let mut guard: ArrayGuard<T, N> = ArrayGuard {
188162
dst: array.as_mut_ptr() as _,
189-
initialized: counter,
163+
initialized: 0,
190164
};
191165
unsafe {
192-
for (idx, value_ptr) in (&mut *array.as_mut_ptr()).iter_mut().enumerate() {
193-
core::ptr::write(value_ptr, cb(idx)?);
194-
*guard.initialized += 1;
166+
let mut value_ptr = array.as_mut_ptr() as *mut T;
167+
for i in 0..N {
168+
core::ptr::write(value_ptr, cb(i)?);
169+
value_ptr = value_ptr.offset(1);
170+
guard.initialized += 1;
195171
}
196172
core::mem::forget(guard);
197173
Ok(array.assume_init())
198174
}
199175
}
200176

177+
fn _extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
178+
where
179+
T: FromPyObject<'s>,
180+
{
181+
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
182+
let expected_len = seq.len()? as usize;
183+
if expected_len != slice.len() {
184+
return Err(invalid_sequence_length(expected_len, slice.len()));
185+
}
186+
for (value, item) in slice.iter_mut().zip(seq.iter()?) {
187+
*value = item?.extract::<T>()?;
188+
}
189+
Ok(())
190+
}
191+
192+
pub fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
193+
exceptions::PyValueError::new_err(format!(
194+
"expected a sequence of length {} (got {})",
195+
expected, actual
196+
))
197+
}
198+
201199
#[cfg(test)]
202200
mod test {
203201
use crate::Python;
204202
#[cfg(min_const_generics)]
205203
use std::{
206204
panic,
207-
sync::{Arc, Mutex},
208-
thread::sleep,
209-
time,
205+
sync::atomic::{AtomicUsize, Ordering},
210206
};
211207

212208
#[cfg(min_const_generics)]
213209
#[test]
214-
fn try_create_array() {
215-
#[allow(clippy::mutex_atomic)]
216-
let counter = Arc::new(Mutex::new(0));
217-
let counter_unwind = Arc::clone(&counter);
210+
fn array_try_from_fn() {
211+
static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
212+
struct CountDrop;
213+
impl Drop for CountDrop {
214+
fn drop(&mut self) {
215+
DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
216+
}
217+
}
218218
let _ = catch_unwind_silent(move || {
219-
let mut locked = counter_unwind.lock().unwrap();
220-
let _: Result<[i32; 4], _> = super::try_create_array(&mut *locked, |idx| {
219+
let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
221220
if idx == 2 {
222221
panic!("peek a boo");
223222
}
224-
Ok::<_, ()>(1)
223+
Ok(CountDrop)
225224
});
226225
});
227-
sleep(time::Duration::from_secs(2));
228-
assert_eq!(*counter.lock().unwrap_err().into_inner(), 2);
226+
assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
229227
}
230228

231-
#[cfg(not(min_const_generics))]
232229
#[test]
233-
fn test_extract_bytearray_to_array() {
230+
fn test_extract_small_bytearray_to_array() {
234231
let gil = Python::acquire_gil();
235232
let py = gil.python();
236233
let v: [u8; 3] = py

src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ pub mod pyclass_slots;
186186
mod python;
187187
pub mod type_object;
188188
pub mod types;
189-
mod utils;
190189

191190
#[cfg(feature = "serde")]
192191
pub mod serde;

src/utils.rs

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)