Skip to content

Commit 0e35860

Browse files
committed
Add support for arbitrary arrays
1 parent eaf516d commit 0e35860

File tree

3 files changed

+80
-79
lines changed

3 files changed

+80
-79
lines changed

src/types/list.rs

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -178,26 +178,15 @@ where
178178
}
179179
}
180180

181-
macro_rules! array_impls {
182-
($($N:expr),+) => {
183-
$(
184-
impl<T> IntoPy<PyObject> for [T; $N]
185-
where
186-
T: ToPyObject
187-
{
188-
fn into_py(self, py: Python) -> PyObject {
189-
self.as_ref().to_object(py)
190-
}
191-
}
192-
)+
181+
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
182+
where
183+
T: ToPyObject,
184+
{
185+
fn into_py(self, py: Python) -> PyObject {
186+
self.as_ref().to_object(py)
193187
}
194188
}
195189

196-
array_impls!(
197-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
198-
26, 27, 28, 29, 30, 31, 32
199-
);
200-
201190
impl<T> ToPyObject for Vec<T>
202191
where
203192
T: ToPyObject,

src/types/mod.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,37 @@ mod slice;
244244
mod string;
245245
mod tuple;
246246
mod typeobject;
247+
248+
struct ArrayGuard<T, const N: usize> {
249+
dst: *mut T,
250+
initialized: usize,
251+
}
252+
253+
impl<T, const N: usize> Drop for ArrayGuard<T, N> {
254+
fn drop(&mut self) {
255+
debug_assert!(self.initialized <= N);
256+
let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
257+
unsafe {
258+
core::ptr::drop_in_place(initialized_part);
259+
}
260+
}
261+
}
262+
263+
fn try_create_array<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
264+
where
265+
F: FnMut(usize) -> Result<T, E>,
266+
{
267+
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
268+
let mut guard: ArrayGuard<T, N> = ArrayGuard {
269+
dst: array.as_mut_ptr() as _,
270+
initialized: 0,
271+
};
272+
unsafe {
273+
for (idx, value_ptr) in (&mut *array.as_mut_ptr()).iter_mut().enumerate() {
274+
core::ptr::write(value_ptr, cb(idx)?);
275+
guard.initialized += 1;
276+
}
277+
core::mem::forget(guard);
278+
Ok(array.assume_init())
279+
}
280+
}

src/types/sequence.rs

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -257,59 +257,39 @@ impl PySequence {
257257
}
258258
}
259259

260-
macro_rules! array_impls {
261-
($($N:expr),+) => {
262-
$(
263-
impl<'a, T> FromPyObject<'a> for [T; $N]
264-
where
265-
T: Copy + Default + FromPyObject<'a>,
266-
{
267-
#[cfg(not(feature = "nightly"))]
268-
fn extract(obj: &'a PyAny) -> PyResult<Self> {
269-
let mut array = [T::default(); $N];
270-
extract_sequence_into_slice(obj, &mut array)?;
271-
Ok(array)
272-
}
260+
impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
261+
where
262+
T: FromPyObject<'a>,
263+
{
264+
#[cfg(not(feature = "nightly"))]
265+
fn extract(obj: &'a PyAny) -> PyResult<Self> {
266+
create_array_from_obj(obj)
267+
}
273268

274-
#[cfg(feature = "nightly")]
275-
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
276-
let mut array = [T::default(); $N];
277-
extract_sequence_into_slice(obj, &mut array)?;
278-
Ok(array)
279-
}
280-
}
269+
#[cfg(feature = "nightly")]
270+
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
271+
create_array_from_obj(obj)
272+
}
273+
}
281274

282-
#[cfg(feature = "nightly")]
283-
impl<'source, T> FromPyObject<'source> for [T; $N]
284-
where
285-
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
286-
{
287-
fn extract(obj: &'source PyAny) -> PyResult<Self> {
288-
let mut array = [T::default(); $N];
289-
// first try buffer protocol
290-
if unsafe { ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
291-
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
292-
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
293-
buf.release(obj.py());
294-
return Ok(array);
295-
}
296-
buf.release(obj.py());
297-
}
298-
}
299-
// fall back to sequence protocol
300-
extract_sequence_into_slice(obj, &mut array)?;
301-
Ok(array)
302-
}
275+
#[cfg(feature = "nightly")]
276+
impl<'source, T, const N: usize> FromPyObject<'source> for [T; N]
277+
where
278+
for<'a> T: FromPyObject<'a> + crate::buffer::Element,
279+
{
280+
fn extract(obj: &'source PyAny) -> PyResult<Self> {
281+
let mut array = create_array_from_obj(obj)?;
282+
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
283+
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
284+
buf.release(obj.py());
285+
return Ok(array);
303286
}
304-
)+
287+
buf.release(obj.py());
288+
}
289+
Ok(array)
305290
}
306291
}
307292

308-
array_impls!(
309-
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
310-
26, 27, 28, 29, 30, 31, 32
311-
);
312-
313293
impl<'a, T> FromPyObject<'a> for Vec<T>
314294
where
315295
T: FromPyObject<'a>,
@@ -345,32 +325,30 @@ where
345325
}
346326
}
347327

348-
fn extract_sequence<'s, T>(obj: &'s PyAny) -> PyResult<Vec<T>>
328+
fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]>
349329
where
350330
T: FromPyObject<'s>,
351331
{
352332
let seq = <PySequence as PyTryFrom>::try_from(obj)?;
353-
let mut v = Vec::with_capacity(seq.len().unwrap_or(0) as usize);
354-
for item in seq.iter()? {
355-
v.push(item?.extract::<T>()?);
356-
}
357-
Ok(v)
333+
crate::types::try_create_array(|idx| {
334+
seq.get_item(idx as isize)
335+
.map_err(|_| {
336+
exceptions::PyBufferError::new_err("Slice length does not match buffer length.")
337+
})?
338+
.extract::<T>()
339+
})
358340
}
359341

360-
fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
342+
fn extract_sequence<'s, T>(obj: &'s PyAny) -> PyResult<Vec<T>>
361343
where
362344
T: FromPyObject<'s>,
363345
{
364346
let seq = <PySequence as PyTryFrom>::try_from(obj)?;
365-
if seq.len()? as usize != slice.len() {
366-
return Err(exceptions::PyBufferError::new_err(
367-
"Slice length does not match buffer length.",
368-
));
369-
}
370-
for (value, item) in slice.iter_mut().zip(seq.iter()?) {
371-
*value = item?.extract::<T>()?;
347+
let mut v = Vec::with_capacity(seq.len().unwrap_or(0) as usize);
348+
for item in seq.iter()? {
349+
v.push(item?.extract::<T>()?);
372350
}
373-
Ok(())
351+
Ok(v)
374352
}
375353

376354
impl<'v> PyTryFrom<'v> for PySequence {

0 commit comments

Comments
 (0)