Skip to content

Commit 6e30c6e

Browse files
Merge pull request #315 from rust-lang/scatter-gather-ptr
Scatter/gather for pointers
2 parents 35c60ce + 7e614f0 commit 6e30c6e

File tree

1 file changed

+119
-3
lines changed

1 file changed

+119
-3
lines changed

crates/core_simd/src/vector.rs

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,71 @@ where
364364
let base_ptr = Simd::<*const T, LANES>::splat(slice.as_ptr());
365365
// Ferris forgive me, I have done pointer arithmetic here.
366366
let ptrs = base_ptr.wrapping_add(idxs);
367-
// Safety: The ptrs have been bounds-masked to prevent memory-unsafe reads insha'allah
368-
unsafe { intrinsics::simd_gather(or, ptrs, enable.to_int()) }
367+
// Safety: The caller is responsible for ensuring the indices are okay to read
368+
unsafe { Self::gather_select_ptr(ptrs, enable, or) }
369+
}
370+
371+
/// Reads elementwise from pointers into a SIMD vector.
372+
///
373+
/// # Safety
374+
///
375+
/// Each read must satisfy the same conditions as [`core::ptr::read`].
376+
///
377+
/// # Example
378+
/// ```
379+
/// # #![feature(portable_simd)]
380+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
381+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
382+
/// # use simd::{Simd, SimdConstPtr};
383+
/// let values = [6, 2, 4, 9];
384+
/// let offsets = Simd::from_array([1, 0, 0, 3]);
385+
/// let source = Simd::splat(values.as_ptr()).wrapping_add(offsets);
386+
/// let gathered = unsafe { Simd::gather_ptr(source) };
387+
/// assert_eq!(gathered, Simd::from_array([2, 6, 6, 9]));
388+
/// ```
389+
#[must_use]
390+
#[inline]
391+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
392+
pub unsafe fn gather_ptr(source: Simd<*const T, LANES>) -> Self
393+
where
394+
T: Default,
395+
{
396+
// TODO: add an intrinsic that doesn't use a passthru vector, and remove the T: Default bound
397+
// Safety: The caller is responsible for upholding all invariants
398+
unsafe { Self::gather_select_ptr(source, Mask::splat(true), Self::default()) }
399+
}
400+
401+
/// Conditionally reads elementwise from pointers into a SIMD vector.
402+
/// The mask `enable`s all `true` lanes and disables all `false` lanes.
403+
/// If a lane is disabled, the lane is selected from the `or` vector and no read is performed.
404+
///
405+
/// # Safety
406+
///
407+
/// Enabled lanes must satisfy the same conditions as [`core::ptr::read`].
408+
///
409+
/// # Example
410+
/// ```
411+
/// # #![feature(portable_simd)]
412+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
413+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
414+
/// # use simd::{Mask, Simd, SimdConstPtr};
415+
/// let values = [6, 2, 4, 9];
416+
/// let enable = Mask::from_array([true, true, false, true]);
417+
/// let offsets = Simd::from_array([1, 0, 0, 3]);
418+
/// let source = Simd::splat(values.as_ptr()).wrapping_add(offsets);
419+
/// let gathered = unsafe { Simd::gather_select_ptr(source, enable, Simd::splat(0)) };
420+
/// assert_eq!(gathered, Simd::from_array([2, 6, 0, 9]));
421+
/// ```
422+
#[must_use]
423+
#[inline]
424+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
425+
pub unsafe fn gather_select_ptr(
426+
source: Simd<*const T, LANES>,
427+
enable: Mask<isize, LANES>,
428+
or: Self,
429+
) -> Self {
430+
// Safety: The caller is responsible for upholding all invariants
431+
unsafe { intrinsics::simd_gather(or, source, enable.to_int()) }
369432
}
370433

371434
/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
@@ -473,10 +536,63 @@ where
473536
// Ferris forgive me, I have done pointer arithmetic here.
474537
let ptrs = base_ptr.wrapping_add(idxs);
475538
// The ptrs have been bounds-masked to prevent memory-unsafe writes insha'allah
476-
intrinsics::simd_scatter(self, ptrs, enable.to_int())
539+
self.scatter_select_ptr(ptrs, enable);
477540
// Cleared ☢️ *mut T Zone
478541
}
479542
}
543+
544+
/// Writes the lanes of a SIMD vector elementwise to pointers.
545+
///
546+
/// # Safety
547+
///
548+
/// Each write must satisfy the same conditions as [`core::ptr::write`].
549+
///
550+
/// # Example
551+
/// ```
552+
/// # #![feature(portable_simd)]
553+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
554+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
555+
/// # use simd::{Simd, SimdMutPtr};
556+
/// let mut values = [0; 4];
557+
/// let offset = Simd::from_array([3, 2, 1, 0]);
558+
/// let ptrs = Simd::splat(values.as_mut_ptr()).wrapping_add(offset);
559+
/// unsafe { Simd::from_array([6, 3, 5, 7]).scatter_ptr(ptrs); }
560+
/// assert_eq!(values, [7, 5, 3, 6]);
561+
/// ```
562+
#[inline]
563+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
564+
pub unsafe fn scatter_ptr(self, dest: Simd<*mut T, LANES>) {
565+
// Safety: The caller is responsible for upholding all invariants
566+
unsafe { self.scatter_select_ptr(dest, Mask::splat(true)) }
567+
}
568+
569+
/// Conditionally writes the lanes of a SIMD vector elementwise to pointers.
570+
/// The mask `enable`s all `true` lanes and disables all `false` lanes.
571+
/// If a lane is disabled, the write to that lane is skipped.
572+
///
573+
/// # Safety
574+
///
575+
/// Enabled lanes must satisfy the same conditions as [`core::ptr::write`].
576+
///
577+
/// # Example
578+
/// ```
579+
/// # #![feature(portable_simd)]
580+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
581+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
582+
/// # use simd::{Mask, Simd, SimdMutPtr};
583+
/// let mut values = [0; 4];
584+
/// let offset = Simd::from_array([3, 2, 1, 0]);
585+
/// let ptrs = Simd::splat(values.as_mut_ptr()).wrapping_add(offset);
586+
/// let enable = Mask::from_array([true, true, false, false]);
587+
/// unsafe { Simd::from_array([6, 3, 5, 7]).scatter_select_ptr(ptrs, enable); }
588+
/// assert_eq!(values, [0, 0, 3, 6]);
589+
/// ```
590+
#[inline]
591+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
592+
pub unsafe fn scatter_select_ptr(self, dest: Simd<*mut T, LANES>, enable: Mask<isize, LANES>) {
593+
// Safety: The caller is responsible for upholding all invariants
594+
unsafe { intrinsics::simd_scatter(self, dest, enable.to_int()) }
595+
}
480596
}
481597

482598
impl<T, const LANES: usize> Copy for Simd<T, LANES>

0 commit comments

Comments
 (0)