Skip to content

Commit fd53445

Browse files
committed
Add pointer scatter/gather
1 parent ecc2875 commit fd53445

File tree

1 file changed

+65
-3
lines changed

1 file changed

+65
-3
lines changed

crates/core_simd/src/vector.rs

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,44 @@ where
364364
let base_ptr = Simd::<*const T, LANES>::splat(slice.as_ptr());
365365
// Ferris forgive me, I have done pointer arithmetic here.
366366
let ptrs = base_ptr.wrapping_add(idxs);
367-
// Safety: The ptrs have been bounds-masked to prevent memory-unsafe reads insha'allah
368-
unsafe { intrinsics::simd_gather(or, ptrs, enable.to_int()) }
367+
// Safety: The caller is responsible for determining the indices are okay to read
368+
unsafe { Self::gather_select_ptr(ptrs, enable, or) }
369+
}
370+
371+
/// Read pointers elementwise into a SIMD vector vector.
372+
///
373+
/// # Safety
374+
///
375+
/// Each read must satisfy the same conditions as [`core::ptr::read`].
376+
#[must_use]
377+
#[inline]
378+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
379+
pub unsafe fn gather_ptr(source: Simd<*const T, LANES>) -> Self
380+
where
381+
T: Default,
382+
{
383+
// TODO: add an intrinsic that doesn't use a passthru vector, and remove the T: Default bound
384+
// Safety: The caller is responsible for upholding all invariants
385+
unsafe { Self::gather_select_ptr(source, Mask::splat(true), Self::default()) }
386+
}
387+
388+
/// Conditionally read pointers elementwise into a SIMD vector vector.
389+
/// The mask `enable`s all `true` lanes and disables all `false` lanes.
390+
/// If a lane is disabled, the lane is selected from the `or` vector and no read is performed.
391+
///
392+
/// # Safety
393+
///
394+
/// Enabled lanes must satisfy the same conditions as [`core::ptr::read`].
395+
#[must_use]
396+
#[inline]
397+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
398+
pub unsafe fn gather_select_ptr(
399+
source: Simd<*const T, LANES>,
400+
enable: Mask<isize, LANES>,
401+
or: Self,
402+
) -> Self {
403+
// Safety: The caller is responsible for upholding all invariants
404+
unsafe { intrinsics::simd_gather(or, source, enable.to_int()) }
369405
}
370406

371407
/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
@@ -473,10 +509,36 @@ where
473509
// Ferris forgive me, I have done pointer arithmetic here.
474510
let ptrs = base_ptr.wrapping_add(idxs);
475511
// The ptrs have been bounds-masked to prevent memory-unsafe writes insha'allah
476-
intrinsics::simd_scatter(self, ptrs, enable.to_int())
512+
self.scatter_select_ptr(ptrs, enable);
477513
// Cleared ☢️ *mut T Zone
478514
}
479515
}
516+
517+
/// Write pointers elementwise into a SIMD vector vector.
518+
///
519+
/// # Safety
520+
///
521+
/// Each write must satisfy the same conditions as [`core::ptr::write`].
522+
#[inline]
523+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
524+
pub unsafe fn scatter_ptr(self, dest: Simd<*mut T, LANES>) {
525+
// Safety: The caller is responsible for upholding all invariants
526+
unsafe { self.scatter_select_ptr(dest, Mask::splat(true)) }
527+
}
528+
529+
/// Conditionally write pointers elementwise into a SIMD vector vector.
530+
/// The mask `enable`s all `true` lanes and disables all `false` lanes.
531+
/// If a lane is disabled, the writing that lane is skipped.
532+
///
533+
/// # Safety
534+
///
535+
/// Enabled lanes must satisfy the same conditions as [`core::ptr::write`].
536+
#[inline]
537+
#[cfg_attr(miri, track_caller)] // even without panics, this helps for Miri backtraces
538+
pub unsafe fn scatter_select_ptr(self, dest: Simd<*mut T, LANES>, enable: Mask<isize, LANES>) {
539+
// Safety: The caller is responsible for upholding all invariants
540+
unsafe { intrinsics::simd_scatter(self, dest, enable.to_int()) }
541+
}
480542
}
481543

482544
impl<T, const LANES: usize> Copy for Simd<T, LANES>

0 commit comments

Comments
 (0)