Skip to content

Commit 84295c4

Browse files
committed
FEAT: Factor out traits SplitAt and SplitPreference
To be used by Zip and parallel Zip
1 parent 35e89f8 commit 84295c4

File tree

4 files changed

+77
-20
lines changed

4 files changed

+77
-20
lines changed

src/indexes.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
// except according to those terms.
88
use super::Dimension;
99
use crate::dimension::IntoDimension;
10-
use crate::zip::{Offset, Splittable};
10+
use crate::zip::Offset;
11+
use crate::split_at::SplitAt;
1112
use crate::Axis;
1213
use crate::Layout;
1314
use crate::NdProducer;

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ mod numeric_util;
179179
mod shape_builder;
180180
#[macro_use]
181181
mod slice;
182+
mod split_at;
182183
mod stacking;
183184
#[macro_use]
184185
mod zip;

src/split_at.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
2+
use crate::imp_prelude::*;
3+
4+
/// Arrays and similar that can be split along an axis
5+
pub(crate) trait SplitAt {
6+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) where Self: Sized;
7+
}
8+
9+
pub(crate) trait SplitPreference : SplitAt {
10+
fn can_split(&self) -> bool;
11+
fn size(&self) -> usize;
12+
fn split_preference(&self) -> (Axis, usize);
13+
fn split(self) -> (Self, Self) where Self: Sized {
14+
let (axis, index) = self.split_preference();
15+
self.split_at(axis, index)
16+
}
17+
}
18+
19+
impl<D> SplitAt for D
20+
where
21+
D: Dimension,
22+
{
23+
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
24+
let mut d1 = self;
25+
let mut d2 = d1.clone();
26+
let i = axis.index();
27+
let len = d1[i];
28+
d1[i] = index;
29+
d2[i] = len - index;
30+
(d1, d2)
31+
}
32+
}
33+
34+
impl<'a, A, D> SplitAt for ArrayViewMut<'a, A, D>
35+
where D: Dimension
36+
{
37+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
38+
self.split_at(axis, index)
39+
}
40+
}
41+
42+
43+
impl<A, D> SplitAt for RawArrayViewMut<A, D>
44+
where D: Dimension
45+
{
46+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
47+
self.split_at(axis, index)
48+
}
49+
}

src/zip/mod.rs

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use crate::NdIndex;
2020

2121
use crate::indexes::{indices, Indices};
2222
use crate::layout::{CORDER, FORDER};
23+
use crate::split_at::{SplitPreference, SplitAt};
2324

2425
use partial_array::PartialArray;
2526

@@ -92,25 +93,6 @@ where
9293
private_impl! {}
9394
}
9495

95-
pub trait Splittable: Sized {
96-
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self);
97-
}
98-
99-
impl<D> Splittable for D
100-
where
101-
D: Dimension,
102-
{
103-
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
104-
let mut d1 = self;
105-
let mut d2 = d1.clone();
106-
let i = axis.index();
107-
let len = d1[i];
108-
d1[i] = index;
109-
d2[i] = len - index;
110-
(d1, d2)
111-
}
112-
}
113-
11496
/// Argument conversion into a producer.
11597
///
11698
/// Slices and vectors can be used (equivalent to 1-dimensional array views).
@@ -1121,9 +1103,31 @@ macro_rules! map_impl {
11211103
pub fn split(self) -> (Self, Self) {
11221104
debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
11231105
debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
1106+
SplitPreference::split(self)
1107+
}
1108+
}
1109+
1110+
impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
1111+
where D: Dimension,
1112+
$($p: NdProducer<Dim=D> ,)*
1113+
{
1114+
fn can_split(&self) -> bool { self.size() > 1 }
1115+
1116+
fn size(&self) -> usize { self.size() }
1117+
1118+
fn split_preference(&self) -> (Axis, usize) {
11241119
// Always split in a way that preserves layout (if any)
11251120
let axis = self.max_stride_axis();
11261121
let index = self.len_of(axis) / 2;
1122+
(axis, index)
1123+
}
1124+
}
1125+
1126+
impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
1127+
where D: Dimension,
1128+
$($p: NdProducer<Dim=D> ,)*
1129+
{
1130+
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
11271131
let (p1, p2) = self.parts.split_at(axis, index);
11281132
let (d1, d2) = self.dimension.split_at(axis, index);
11291133
(Zip {
@@ -1139,7 +1143,9 @@ macro_rules! map_impl {
11391143
layout_tendency: self.layout_tendency,
11401144
})
11411145
}
1146+
11421147
}
1148+
11431149
)+
11441150
}
11451151
}

0 commit comments

Comments
 (0)