Skip to content

Commit 6bf5128

Browse files
committed
Simplify interleave/deinterleave and fix for odd-length vectors.
1 parent 3183afb commit 6bf5128

File tree

1 file changed

+26
-48
lines changed

1 file changed

+26
-48
lines changed

crates/core_simd/src/swizzle.rs

Lines changed: 26 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -265,13 +265,10 @@ where
265265

266266
/// Interleave two vectors.
267267
///
268-
/// Produces two vectors with lanes taken alternately from `self` and `other`.
268+
/// The resulting vectors contain lanes taken alternatively from `self` and `other`, first
269+
/// filling the first result, and then the second.
269270
///
270-
/// The first result contains the first `LANES / 2` lanes from `self` and `other`,
271-
/// alternating, starting with the first lane of `self`.
272-
///
273-
/// The second result contains the last `LANES / 2` lanes from `self` and `other`,
274-
/// alternating, starting with the lane `LANES / 2` from the start of `self`.
271+
/// The reverse of this operation is [`Simd::deinterleave`].
275272
///
276273
/// ```
277274
/// #![feature(portable_simd)]
@@ -285,29 +282,17 @@ where
285282
#[inline]
286283
#[must_use = "method returns a new vector and does not mutate the original inputs"]
287284
pub fn interleave(self, other: Self) -> (Self, Self) {
288-
const fn lo<const LANES: usize>() -> [Which; LANES] {
289-
let mut idx = [Which::First(0); LANES];
290-
let mut i = 0;
291-
while i < LANES {
292-
let offset = i / 2;
293-
idx[i] = if i % 2 == 0 {
294-
Which::First(offset)
295-
} else {
296-
Which::Second(offset)
297-
};
298-
i += 1;
299-
}
300-
idx
301-
}
302-
const fn hi<const LANES: usize>() -> [Which; LANES] {
285+
const fn interleave<const LANES: usize>(high: bool) -> [Which; LANES] {
303286
let mut idx = [Which::First(0); LANES];
304287
let mut i = 0;
305288
while i < LANES {
306-
let offset = (LANES + i) / 2;
307-
idx[i] = if i % 2 == 0 {
308-
Which::First(offset)
289+
// Treat the source as a concatenated vector
290+
let dst_index = if high { i + LANES } else { i };
291+
let src_index = dst_index / 2 + (dst_index % 2) * LANES;
292+
idx[i] = if src_index < LANES {
293+
Which::First(src_index)
309294
} else {
310-
Which::Second(offset)
295+
Which::Second(src_index % LANES)
311296
};
312297
i += 1;
313298
}
@@ -318,18 +303,14 @@ where
318303
struct Hi;
319304

320305
impl<const LANES: usize> Swizzle2<LANES, LANES> for Lo {
321-
const INDEX: [Which; LANES] = lo::<LANES>();
306+
const INDEX: [Which; LANES] = interleave::<LANES>(false);
322307
}
323308

324309
impl<const LANES: usize> Swizzle2<LANES, LANES> for Hi {
325-
const INDEX: [Which; LANES] = hi::<LANES>();
310+
const INDEX: [Which; LANES] = interleave::<LANES>(true);
326311
}
327312

328-
if LANES == 1 {
329-
(self, other)
330-
} else {
331-
(Lo::swizzle2(self, other), Hi::swizzle2(self, other))
332-
}
313+
(Lo::swizzle2(self, other), Hi::swizzle2(self, other))
333314
}
334315

335316
/// Deinterleave two vectors.
@@ -340,6 +321,8 @@ where
340321
/// The second result takes every other lane of `self` and then `other`, starting with
341322
/// the second lane.
342323
///
324+
/// The reverse of this operation is [`Simd::interleave`].
325+
///
343326
/// ```
344327
/// #![feature(portable_simd)]
345328
/// # use core::simd::Simd;
@@ -352,22 +335,17 @@ where
352335
#[inline]
353336
#[must_use = "method returns a new vector and does not mutate the original inputs"]
354337
pub fn deinterleave(self, other: Self) -> (Self, Self) {
355-
const fn even<const LANES: usize>() -> [Which; LANES] {
356-
let mut idx = [Which::First(0); LANES];
357-
let mut i = 0;
358-
while i < LANES / 2 {
359-
idx[i] = Which::First(2 * i);
360-
idx[i + LANES / 2] = Which::Second(2 * i);
361-
i += 1;
362-
}
363-
idx
364-
}
365-
const fn odd<const LANES: usize>() -> [Which; LANES] {
338+
const fn deinterleave<const LANES: usize>(second: bool) -> [Which; LANES] {
366339
let mut idx = [Which::First(0); LANES];
367340
let mut i = 0;
368-
while i < LANES / 2 {
369-
idx[i] = Which::First(2 * i + 1);
370-
idx[i + LANES / 2] = Which::Second(2 * i + 1);
341+
while i < LANES {
342+
// Treat the source as a concatenated vector
343+
let src_index = i * 2 + if second { 1 } else { 0 };
344+
idx[i] = if src_index < LANES {
345+
Which::First(src_index)
346+
} else {
347+
Which::Second(src_index % LANES)
348+
};
371349
i += 1;
372350
}
373351
idx
@@ -377,11 +355,11 @@ where
377355
struct Odd;
378356

379357
impl<const LANES: usize> Swizzle2<LANES, LANES> for Even {
380-
const INDEX: [Which; LANES] = even::<LANES>();
358+
const INDEX: [Which; LANES] = deinterleave::<LANES>(false);
381359
}
382360

383361
impl<const LANES: usize> Swizzle2<LANES, LANES> for Odd {
384-
const INDEX: [Which; LANES] = odd::<LANES>();
362+
const INDEX: [Which; LANES] = deinterleave::<LANES>(true);
385363
}
386364

387365
if LANES == 1 {

0 commit comments

Comments
 (0)