Skip to content

Commit 082e3c8

Browse files
committed
Workaround simd_bitmask limitations
1 parent 4ca9f04 commit 082e3c8

File tree

3 files changed

+90
-25
lines changed

3 files changed

+90
-25
lines changed

crates/core_simd/src/masks/full_masks.rs

Lines changed: 79 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -207,40 +207,108 @@ where
207207
}
208208

209209
#[inline]
210-
pub(crate) fn to_bitmask_integer(self) -> u64 {
211-
let resized = self.to_int().extend::<64>(T::FALSE);
210+
unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
211+
where
212+
LaneCount<M>: SupportedLaneCount,
213+
{
214+
let resized = self.to_int().resize::<M>(T::FALSE);
212215

213-
// SAFETY: `resized` is an integer vector with length 64
214-
let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) };
216+
// Safety: `resized` is an integer vector with length M, which must match T
217+
let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };
215218

216219
// LLVM assumes bit order should match endianness
217220
if cfg!(target_endian = "big") {
218-
bitmask.reverse_bits()
221+
bitmask.reverse_bits(M)
219222
} else {
220223
bitmask
221224
}
222225
}
223226

224227
#[inline]
225-
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
228+
unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
229+
where
230+
LaneCount<M>: SupportedLaneCount,
231+
{
226232
// LLVM assumes bit order should match endianness
227233
let bitmask = if cfg!(target_endian = "big") {
228-
bitmask.reverse_bits()
234+
bitmask.reverse_bits(M)
229235
} else {
230236
bitmask
231237
};
232238

233239
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
234-
let mask: Simd<T, 64> = unsafe {
240+
let mask: Simd<T, M> = unsafe {
235241
intrinsics::simd_select_bitmask(
236242
bitmask,
237-
Simd::<T, 64>::splat(T::TRUE),
238-
Simd::<T, 64>::splat(T::FALSE),
243+
Simd::<T, M>::splat(T::TRUE),
244+
Simd::<T, M>::splat(T::FALSE),
239245
)
240246
};
241247

242248
// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
243-
unsafe { Self::from_int_unchecked(mask.extend::<N>(T::FALSE)) }
249+
unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
250+
}
251+
252+
#[inline]
253+
pub(crate) fn to_bitmask_integer(self) -> u64 {
254+
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
255+
macro_rules! bitmask {
256+
{ $($ty:ty: $($len:literal),*;)* } => {
257+
match N {
258+
$($(
259+
// Safety: bitmask matches length
260+
$len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
261+
)*)*
262+
// Safety: bitmask matches length
263+
_ => unsafe { self.to_bitmask_impl::<u64, 64>() },
264+
}
265+
}
266+
}
267+
#[cfg(all_lane_counts)]
268+
bitmask! {
269+
u8: 1, 2, 3, 4, 5, 6, 7, 8;
270+
u16: 9, 10, 11, 12, 13, 14, 15, 16;
271+
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
272+
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
273+
}
274+
#[cfg(not(all_lane_counts))]
275+
bitmask! {
276+
u8: 1, 2, 4, 8;
277+
u16: 16;
278+
u32: 32;
279+
u64: 64;
280+
}
281+
}
282+
283+
#[inline]
284+
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
285+
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
286+
macro_rules! bitmask {
287+
{ $($ty:ty: $($len:literal),*;)* } => {
288+
match N {
289+
$($(
290+
// Safety: bitmask matches length
291+
$len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
292+
)*)*
293+
// Safety: bitmask matches length
294+
_ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
295+
}
296+
}
297+
}
298+
#[cfg(all_lane_counts)]
299+
bitmask! {
300+
u8: 1, 2, 3, 4, 5, 6, 7, 8;
301+
u16: 9, 10, 11, 12, 13, 14, 15, 16;
302+
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
303+
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
304+
}
305+
#[cfg(not(all_lane_counts))]
306+
bitmask! {
307+
u8: 1, 2, 4, 8;
308+
u16: 16;
309+
u32: 32;
310+
u64: 64;
311+
}
244312
}
245313

246314
#[inline]

crates/core_simd/src/swizzle.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,9 @@ where
350350
)
351351
}
352352

353-
/// Extend a vector.
353+
/// Resize a vector.
354354
///
355-
/// Extends the length of a vector, setting the new elements to `value`.
355+
/// If `M` > `N`, extends the length of a vector, setting the new elements to `value`.
356356
/// If `M` < `N`, truncates the vector to the first `M` elements.
357357
///
358358
/// ```
@@ -361,17 +361,17 @@ where
361361
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
362362
/// # use simd::u32x4;
363363
/// let x = u32x4::from_array([0, 1, 2, 3]);
364-
/// assert_eq!(x.extend::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
365-
/// assert_eq!(x.extend::<2>(9).to_array(), [0, 1]);
364+
/// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]);
365+
/// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]);
366366
/// ```
367367
#[inline]
368368
#[must_use = "method returns a new vector and does not mutate the original inputs"]
369-
pub fn extend<const M: usize>(self, value: T) -> Simd<T, M>
369+
pub fn resize<const M: usize>(self, value: T) -> Simd<T, M>
370370
where
371371
LaneCount<M>: SupportedLaneCount,
372372
{
373-
struct Extend<const N: usize>;
374-
impl<const N: usize, const M: usize> Swizzle<M> for Extend<N> {
373+
struct Resize<const N: usize>;
374+
impl<const N: usize, const M: usize> Swizzle<M> for Resize<N> {
375375
const INDEX: [usize; M] = const {
376376
let mut index = [0; M];
377377
let mut i = 0;
@@ -382,6 +382,6 @@ where
382382
index
383383
};
384384
}
385-
Extend::<N>::concat_swizzle(self, Simd::splat(value))
385+
Resize::<N>::concat_swizzle(self, Simd::splat(value))
386386
}
387387
}

crates/core_simd/tests/masks.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ macro_rules! test_mask_api {
1313
#[cfg(target_arch = "wasm32")]
1414
use wasm_bindgen_test::*;
1515

16-
use core_simd::simd::{Mask, Simd};
16+
use core_simd::simd::Mask;
1717

1818
#[test]
1919
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
@@ -124,17 +124,14 @@ macro_rules! test_mask_api {
124124

125125
#[test]
126126
fn roundtrip_bitmask_vector_conversion() {
127+
use core_simd::simd::ToBytes;
127128
let values = [
128129
true, false, false, true, false, false, true, false,
129130
true, true, false, false, false, false, false, true,
130131
];
131132
let mask = Mask::<$type, 16>::from_array(values);
132133
let bitmask = mask.to_bitmask_vector();
133-
if core::mem::size_of::<$type>() == 1 {
134-
assert_eq!(bitmask, Simd::from_array([0b01001001 as _, 0b10000011 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]));
135-
} else {
136-
assert_eq!(bitmask, Simd::from_array([0b1000001101001001 as _, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]));
137-
}
134+
assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]);
138135
assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask);
139136
}
140137
}

0 commit comments

Comments
 (0)