Skip to content

Commit 2afaa93

Browse files
committed
Add support for arbitrary arrays
1 parent eaf516d commit 2afaa93

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

build.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,27 @@ fn abi3_without_interpreter() -> Result<()> {
902902
Ok(())
903903
}
904904

905+
fn rustc_minor_version() -> Option<u32> {
906+
let rustc = env::var_os("RUSTC")?;
907+
let output = Command::new(rustc).arg("--version").output().ok()?;
908+
let version = core::str::from_utf8(&output.stdout).ok()?;
909+
let mut pieces = version.split('.');
910+
if pieces.next() != Some("rustc 1") {
911+
return None;
912+
}
913+
pieces.next()?.parse().ok()
914+
}
915+
916+
fn manage_min_const_generics() {
917+
let rustc_minor_version = match rustc_minor_version() {
918+
Some(inner) => inner,
919+
None => return,
920+
};
921+
if rustc_minor_version >= 51 {
922+
println!("cargo:rustc-cfg=min_const_generics");
923+
}
924+
}
925+
905926
fn main() -> Result<()> {
906927
// If PYO3_NO_PYTHON is set with abi3, we can build PyO3 without calling Python.
907928
// We only check for the abi3-py3{ABI3_MAX_MINOR} because lower versions depend on it.
@@ -961,5 +982,7 @@ fn main() -> Result<()> {
961982
println!("cargo:rustc-cfg=__pyo3_ci");
962983
}
963984

985+
manage_min_const_generics();
986+
964987
Ok(())
965988
}

src/types/list.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ where
178178
}
179179
}
180180

181+
#[cfg(min_const_generics)]
181182
macro_rules! array_impls {
182183
($($N:expr),+) => {
183184
$(
@@ -193,11 +194,22 @@ macro_rules! array_impls {
193194
}
194195
}
195196

197+
#[cfg(min_const_generics)]
196198
array_impls!(
197199
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
198200
26, 27, 28, 29, 30, 31, 32
199201
);
200202

203+
#[cfg(not(min_const_generics))]
204+
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
205+
where
206+
T: ToPyObject,
207+
{
208+
fn into_py(self, py: Python) -> PyObject {
209+
self.as_ref().to_object(py)
210+
}
211+
}
212+
201213
impl<T> ToPyObject for Vec<T>
202214
where
203215
T: ToPyObject,

src/types/mod.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,40 @@ mod slice;
244244
mod string;
245245
mod tuple;
246246
mod typeobject;
247+
248+
#[cfg(min_const_generics)]
249+
struct ArrayGuard<T, const N: usize> {
250+
dst: *mut T,
251+
initialized: usize,
252+
}
253+
254+
#[cfg(min_const_generics)]
255+
impl<T, const N: usize> Drop for ArrayGuard<T, N> {
256+
fn drop(&mut self) {
257+
debug_assert!(self.initialized <= N);
258+
let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
259+
unsafe {
260+
core::ptr::drop_in_place(initialized_part);
261+
}
262+
}
263+
}
264+
265+
#[cfg(min_const_generics)]
266+
fn try_create_array<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
267+
where
268+
F: FnMut(usize) -> Result<T, E>,
269+
{
270+
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
271+
let mut guard: ArrayGuard<T, N> = ArrayGuard {
272+
dst: array.as_mut_ptr() as _,
273+
initialized: 0,
274+
};
275+
unsafe {
276+
for (idx, value_ptr) in (&mut *array.as_mut_ptr()).iter_mut().enumerate() {
277+
core::ptr::write(value_ptr, cb(idx)?);
278+
guard.initialized += 1;
279+
}
280+
core::mem::forget(guard);
281+
Ok(array.assume_init())
282+
}
283+
}

src/types/sequence.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ impl PySequence {
257257
}
258258
}
259259

260+
#[cfg(not(min_const_generics))]
260261
macro_rules! array_impls {
261262
($($N:expr),+) => {
262263
$(
@@ -305,11 +306,46 @@ macro_rules! array_impls {
305306
}
306307
}
307308

309+
#[cfg(not(min_const_generics))]
308310
array_impls!(
309311
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
310312
26, 27, 28, 29, 30, 31, 32
311313
);
312314

315+
#[cfg(all(min_const_generics, not(feature = "nightly")))]
316+
impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
317+
where
318+
T: FromPyObject<'a>,
319+
{
320+
#[cfg(not(feature = "nightly"))]
321+
fn extract(obj: &'a PyAny) -> PyResult<Self> {
322+
create_array_from_obj(obj)
323+
}
324+
325+
#[cfg(feature = "nightly")]
326+
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
327+
create_array_from_obj(obj)
328+
}
329+
}
330+
331+
#[cfg(all(min_const_generics, feature = "nightly"))]
332+
impl<'source, T, const N: usize> FromPyObject<'source> for [T; N]
333+
where
334+
for<'a> T: FromPyObject<'a> + crate::buffer::Element,
335+
{
336+
fn extract(obj: &'source PyAny) -> PyResult<Self> {
337+
let mut array = create_array_from_obj(obj)?;
338+
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
339+
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
340+
buf.release(obj.py());
341+
return Ok(array);
342+
}
343+
buf.release(obj.py());
344+
}
345+
Ok(array)
346+
}
347+
}
348+
313349
impl<'a, T> FromPyObject<'a> for Vec<T>
314350
where
315351
T: FromPyObject<'a>,
@@ -345,6 +381,21 @@ where
345381
}
346382
}
347383

384+
#[cfg(min_const_generics)]
385+
fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]>
386+
where
387+
T: FromPyObject<'s>,
388+
{
389+
let seq = <PySequence as PyTryFrom>::try_from(obj)?;
390+
crate::types::try_create_array(|idx| {
391+
seq.get_item(idx as isize)
392+
.map_err(|_| {
393+
exceptions::PyBufferError::new_err("Slice length does not match buffer length.")
394+
})?
395+
.extract::<T>()
396+
})
397+
}
398+
348399
fn extract_sequence<'s, T>(obj: &'s PyAny) -> PyResult<Vec<T>>
349400
where
350401
T: FromPyObject<'s>,
@@ -357,6 +408,7 @@ where
357408
Ok(v)
358409
}
359410

411+
#[cfg(not(min_const_generics))]
360412
fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
361413
where
362414
T: FromPyObject<'s>,
@@ -706,6 +758,7 @@ mod test {
706758
assert!(v == [1, 2, 3, 4]);
707759
}
708760

761+
#[cfg(not(min_const_generics))]
709762
#[test]
710763
fn test_extract_bytearray_to_array() {
711764
let gil = Python::acquire_gil();
@@ -718,6 +771,23 @@ mod test {
718771
assert!(&v == b"abc");
719772
}
720773

774+
#[cfg(min_const_generics)]
775+
#[test]
776+
fn test_extract_bytearray_to_array() {
777+
let gil = Python::acquire_gil();
778+
let py = gil.python();
779+
let v: [u8; 33] = py
780+
.eval(
781+
"bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
782+
None,
783+
None,
784+
)
785+
.unwrap()
786+
.extract()
787+
.unwrap();
788+
assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
789+
}
790+
721791
#[test]
722792
fn test_extract_bytearray_to_vec() {
723793
let gil = Python::acquire_gil();

0 commit comments

Comments
 (0)