Skip to content

Commit 2fab15d

Browse files
committed
fix stdsimd, add mask opt notes
1 parent f89f15f commit 2fab15d

File tree

4 files changed

+39
-32
lines changed

4 files changed

+39
-32
lines changed

src/distributions/integer.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl_nzint!(NonZeroU64, NonZeroU64::new);
114114
impl_nzint!(NonZeroU128, NonZeroU128::new);
115115
impl_nzint!(NonZeroUsize, NonZeroUsize::new);
116116

117-
macro_rules! intrinsic_impl {
117+
macro_rules! x86_intrinsic_impl {
118118
($($intrinsic:ident),+) => {$(
119119
/// Available only on x86/64 platforms
120120
impl Distribution<$intrinsic> for Standard {
@@ -156,12 +156,12 @@ macro_rules! simd_impl {
156156
simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64, usize, isize);
157157

158158
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
159-
intrinsic_impl!(__m128i, __m256i);
159+
x86_intrinsic_impl!(__m128i, __m256i);
160160
#[cfg(all(
161161
any(target_arch = "x86", target_arch = "x86_64"),
162162
feature = "simd_support"
163163
))]
164-
intrinsic_impl!(__m512i);
164+
x86_intrinsic_impl!(__m512i);
165165

166166
#[cfg(test)]
167167
mod tests {

src/distributions/other.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use serde::{Serialize, Deserialize};
2323
#[cfg(feature = "min_const_gen")]
2424
use core::mem::{self, MaybeUninit};
2525
#[cfg(feature = "simd_support")]
26-
use core::simd::{Mask, Simd, LaneCount, SupportedLaneCount, MaskElement, SimdElement};
26+
use core::simd::*;
2727

2828

2929
// ----- Sampling distributions -----
@@ -161,22 +161,29 @@ impl Distribution<bool> for Standard {
161161
/// let x = rng.gen::<mask8x16>().select(b, a);
162162
/// ```
163163
///
164-
/// Since most bits are unused you could also generate only as many bits as you need.
164+
/// Since most bits are unused you could also generate only as many bits as you need, e.g.:
165+
/// ```
166+
/// let x = u16x8::splat(rng.gen::<u8>() as u16);
167+
/// let mask = u16x8::splat(1) << u16x8::from([0, 1, 2, 3, 4, 5, 6, 7]);
168+
/// let rand_mask = (x & mask).simd_eq(mask);
169+
/// ```
165170
///
166171
/// [`_mm_blendv_epi8`]: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_epi8&ig_expand=514/
167172
/// [`simd_support`]: https://github.com/rust-random/rand#crate-features
168173
#[cfg(feature = "simd_support")]
169174
impl<T, const LANES: usize> Distribution<Mask<T, LANES>> for Standard
170175
where
171-
T: MaskElement + PartialOrd + SimdElement<Mask = T> + Default,
176+
T: MaskElement + Default,
172177
LaneCount<LANES>: SupportedLaneCount,
173178
Standard: Distribution<Simd<T, LANES>>,
179+
Simd<T, LANES>: SimdPartialOrd<Mask = Mask<T, LANES>>,
174180
{
175181
#[inline]
176182
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mask<T, LANES> {
177183
// `MaskElement` must be a signed integer, so this is equivalent
178184
// to the scalar `i32 < 0` method
179-
rng.gen().lanes_lt(Simd::default())
185+
let var = rng.gen::<Simd<T, LANES>>();
186+
var.simd_lt(Simd::default())
180187
}
181188
}
182189

src/distributions/uniform.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ macro_rules! uniform_simd_int_impl {
606606
{
607607
let low = *low_b.borrow();
608608
let high = *high_b.borrow();
609-
assert!(low.lanes_lt(high).all(), "Uniform::new called with `low >= high`");
609+
assert!(low.simd_lt(high).all(), "Uniform::new called with `low >= high`");
610610
UniformSampler::new_inclusive(low, high - Simd::splat(1))
611611
}
612612

@@ -618,15 +618,15 @@ macro_rules! uniform_simd_int_impl {
618618
{
619619
let low = *low_b.borrow();
620620
let high = *high_b.borrow();
621-
assert!(low.lanes_le(high).all(),
621+
assert!(low.simd_le(high).all(),
622622
"Uniform::new_inclusive called with `low > high`");
623623
let unsigned_max = Simd::splat(::core::$unsigned::MAX);
624624

625625
// NOTE: all `Simd` operations are inherently wrapping,
626626
// see https://doc.rust-lang.org/std/simd/struct.Simd.html
627627
let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast();
628628
// `% 0` will panic at runtime.
629-
let not_full_range = range.lanes_gt(Simd::splat(0));
629+
let not_full_range = range.simd_gt(Simd::splat(0));
630630
// replacing 0 with `unsigned_max` allows a faster `select`
631631
// with bitwise OR
632632
let modulo = not_full_range.select(range, unsigned_max);
@@ -660,7 +660,7 @@ macro_rules! uniform_simd_int_impl {
660660
let mut v: Simd<$unsigned, LANES> = rng.gen();
661661
loop {
662662
let (hi, lo) = v.wmul(range);
663-
let mask = lo.lanes_le(zone);
663+
let mask = lo.simd_le(zone);
664664
if mask.all() {
665665
let hi: Simd<$ty, LANES> = hi.cast();
666666
// wrapping addition
@@ -669,7 +669,7 @@ macro_rules! uniform_simd_int_impl {
669669
// When `range.eq(0).none()` the compare and blend
670670
// operations are avoided.
671671
let v: Simd<$ty, LANES> = v.cast();
672-
return range.lanes_gt(Simd::splat(0)).select(result, v);
672+
return range.simd_gt(Simd::splat(0)).select(result, v);
673673
}
674674
// Replace only the failing lanes
675675
v = mask.select(v, rng.gen());
@@ -1265,8 +1265,8 @@ mod tests {
12651265
($ty::splat(10), $ty::splat(127)),
12661266
($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
12671267
],
1268-
|x: $ty, y| x.lanes_le(y).all(),
1269-
|x: $ty, y| x.lanes_lt(y).all()
1268+
|x: $ty, y| x.simd_le(y).all(),
1269+
|x: $ty, y| x.simd_lt(y).all()
12701270
);)*
12711271
}};
12721272
}

src/distributions/utils.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,20 @@ macro_rules! wmul_impl_large {
9999
#[inline(always)]
100100
fn wmul(self, b: $ty) -> Self::Output {
101101
// needs wrapping multiplication
102-
const LOWER_MASK: $ty = <$ty>::splat(!0 >> $half);
103-
const HALF: $ty = <$ty>::splat($half);
104-
let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
105-
let mut t = low >> HALF;
106-
low &= LOWER_MASK;
107-
t += (self >> HALF) * (b & LOWER_MASK);
108-
low += (t & LOWER_MASK) << HALF;
109-
let mut high = t >> HALF;
110-
t = low >> HALF;
111-
low &= LOWER_MASK;
112-
t += (b >> HALF) * (self & LOWER_MASK);
113-
low += (t & LOWER_MASK) << HALF;
114-
high += t >> HALF;
115-
high += (self >> HALF) * (b >> HALF);
102+
let lower_mask = <$ty>::splat(!0 >> $half);
103+
let half = <$ty>::splat($half);
104+
let mut low = (self & lower_mask) * (b & lower_mask);
105+
let mut t = low >> half;
106+
low &= lower_mask;
107+
t += (self >> half) * (b & lower_mask);
108+
low += (t & lower_mask) << half;
109+
let mut high = t >> half;
110+
t = low >> half;
111+
low &= lower_mask;
112+
t += (b >> half) * (self & lower_mask);
113+
low += (t & lower_mask) << half;
114+
high += t >> half;
115+
high += (self >> half) * (b >> half);
116116

117117
(high, low)
118118
}
@@ -385,12 +385,12 @@ macro_rules! simd_impl {
385385

386386
#[inline(always)]
387387
fn all_lt(self, other: Self) -> bool {
388-
self.lanes_lt(other).all()
388+
self.simd_lt(other).all()
389389
}
390390

391391
#[inline(always)]
392392
fn all_le(self, other: Self) -> bool {
393-
self.lanes_le(other).all()
393+
self.simd_le(other).all()
394394
}
395395

396396
#[inline(always)]
@@ -405,12 +405,12 @@ macro_rules! simd_impl {
405405

406406
#[inline(always)]
407407
fn gt_mask(self, other: Self) -> Self::Mask {
408-
self.lanes_gt(other)
408+
self.simd_gt(other)
409409
}
410410

411411
#[inline(always)]
412412
fn ge_mask(self, other: Self) -> Self::Mask {
413-
self.lanes_ge(other)
413+
self.simd_ge(other)
414414
}
415415

416416
#[inline(always)]

0 commit comments

Comments
 (0)