Skip to content

Commit a659235

Browse files
committed
FEAT: Keep track of "layout tendency" in Zip for better performance
Support unrolling over either the c- or the f-layout preferred axis in the Zip inner loop (the fallback used when the inputs are not all contiguous with the same layout). Keep a tendency score while building the Zip, so that we know whether the inputs tend towards c- or f-layout. This improves performance on the just-added zip_indexed_ff benchmark, so that it seems to match its (already fast) cc counterpart.
1 parent 90ef196 commit a659235

File tree

3 files changed

+120
-32
lines changed

3 files changed

+120
-32
lines changed

src/layout/layoutfmt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
use super::Layout;
1010

11-
const LAYOUT_NAMES: &[&str] = &["C", "F"];
11+
const LAYOUT_NAMES: &[&str] = &["C", "F", "c", "f"];
1212

1313
use std::fmt;
1414

src/layout/mod.rs

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,71 @@
11
mod layoutfmt;
22

3-
// public struct but users don't interact with it
3+
// Layout is a bitset used for internal layout description of
4+
// arrays, producers and sets of producers.
5+
// The type is public but users don't interact with it.
46
#[doc(hidden)]
57
/// Memory layout description
68
#[derive(Copy, Clone)]
79
pub struct Layout(u32);
810

911
impl Layout {
10-
#[inline(always)]
11-
pub(crate) fn new(x: u32) -> Self {
12-
Layout(x)
13-
}
14-
1512
#[inline(always)]
1613
pub(crate) fn is(self, flag: u32) -> bool {
1714
self.0 & flag != 0
1815
}
16+
17+
/// Return layout common to both inputs
1918
#[inline(always)]
20-
pub(crate) fn and(self, flag: Layout) -> Layout {
21-
Layout(self.0 & flag.0)
19+
pub(crate) fn intersect(self, other: Layout) -> Layout {
20+
Layout(self.0 & other.0)
2221
}
2322

23+
/// Return a layout that simultaneously "is" what both of the inputs are
2424
#[inline(always)]
25-
pub(crate) fn flag(self) -> u32 {
26-
self.0
25+
pub(crate) fn also(self, other: Layout) -> Layout {
26+
Layout(self.0 | other.0)
2727
}
2828

2929
#[inline(always)]
3030
pub(crate) fn one_dimensional() -> Layout {
31-
Layout(CORDER | FORDER)
31+
Layout::c().also(Layout::f())
3232
}
3333

3434
#[inline(always)]
3535
pub(crate) fn c() -> Layout {
36-
Layout(CORDER)
36+
Layout(CORDER | CPREFER)
3737
}
3838

3939
#[inline(always)]
4040
pub(crate) fn f() -> Layout {
41-
Layout(FORDER)
41+
Layout(FORDER | FPREFER)
42+
}
43+
44+
#[inline(always)]
45+
pub(crate) fn cpref() -> Layout {
46+
Layout(CPREFER)
47+
}
48+
49+
#[inline(always)]
50+
pub(crate) fn fpref() -> Layout {
51+
Layout(FPREFER)
4252
}
4353

4454
#[inline(always)]
4555
pub(crate) fn none() -> Layout {
4656
Layout(0)
4757
}
58+
59+
/// A simple "score" method which scores positive for preferring C-order, negative for F-order
60+
/// Subject to change when we can describe other layouts
61+
pub(crate) fn tendency(self) -> i32 {
62+
(self.is(CORDER) as i32 - self.is(FORDER) as i32) +
63+
(self.is(CPREFER) as i32 - self.is(FPREFER) as i32)
64+
65+
}
4866
}
4967

5068
pub const CORDER: u32 = 0b01;
5169
pub const FORDER: u32 = 0b10;
70+
pub const CPREFER: u32 = 0b0100;
71+
pub const FPREFER: u32 = 0b1000;

src/zip/mod.rs

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,26 @@ where
5353
D: Dimension,
5454
{
5555
pub(crate) fn layout_impl(&self) -> Layout {
56-
Layout::new(if self.is_standard_layout() {
57-
if self.ndim() <= 1 {
58-
FORDER | CORDER
56+
let n = self.ndim();
57+
if self.is_standard_layout() {
58+
if n <= 1 {
59+
Layout::one_dimensional()
5960
} else {
60-
CORDER
61+
Layout::c()
62+
}
63+
} else if n > 1 && self.raw_view().reversed_axes().is_standard_layout() {
64+
Layout::f()
65+
} else if n > 1 {
66+
if self.stride_of(Axis(0)) == 1 {
67+
Layout::fpref()
68+
} else if self.stride_of(Axis(n - 1)) == 1 {
69+
Layout::cpref()
70+
} else {
71+
Layout::none()
6172
}
62-
} else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
63-
FORDER
6473
} else {
65-
0
66-
})
74+
Layout::none()
75+
}
6776
}
6877
}
6978

@@ -587,6 +596,9 @@ pub struct Zip<Parts, D> {
587596
parts: Parts,
588597
dimension: D,
589598
layout: Layout,
599+
/// The sum of the layout tendencies of the parts;
600+
/// positive for c- and negative for f-layout preference.
601+
layout_tendency: i32,
590602
}
591603

592604

@@ -605,10 +617,12 @@ where
605617
{
606618
let array = p.into_producer();
607619
let dim = array.raw_dim();
620+
let layout = array.layout();
608621
Zip {
609622
dimension: dim,
610-
layout: array.layout(),
623+
layout,
611624
parts: (array,),
625+
layout_tendency: layout.tendency(),
612626
}
613627
}
614628
}
@@ -661,24 +675,29 @@ where
661675
self.dimension[axis.index()]
662676
}
663677

678+
fn prefer_f(&self) -> bool {
679+
!self.layout.is(CORDER) && (self.layout.is(FORDER) || self.layout_tendency < 0)
680+
}
681+
664682
/// Return an *approximation* to the max stride axis; if
665683
/// component arrays disagree, there may be no choice better than the
666684
/// others.
667685
fn max_stride_axis(&self) -> Axis {
668-
let i = match self.layout.flag() {
669-
FORDER => self
686+
let i = if self.prefer_f() {
687+
self
670688
.dimension
671689
.slice()
672690
.iter()
673691
.rposition(|&len| len > 1)
674-
.unwrap_or(self.dimension.ndim() - 1),
692+
.unwrap_or(self.dimension.ndim() - 1)
693+
} else {
675694
/* corder or default */
676-
_ => self
695+
self
677696
.dimension
678697
.slice()
679698
.iter()
680699
.position(|&len| len > 1)
681-
.unwrap_or(0),
700+
.unwrap_or(0)
682701
};
683702
Axis(i)
684703
}
@@ -699,6 +718,7 @@ where
699718
self.apply_core_strided(acc, function)
700719
}
701720
}
721+
702722
fn apply_core_contiguous<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
703723
where
704724
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
@@ -717,7 +737,7 @@ where
717737
FoldWhile::Continue(acc)
718738
}
719739

720-
fn apply_core_strided<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
740+
fn apply_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
721741
where
722742
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
723743
P: ZippableTuple<Dim = D>,
@@ -726,13 +746,27 @@ where
726746
if n == 0 {
727747
panic!("Unreachable: ndim == 0 is contiguous")
728748
}
749+
if n == 1 || self.layout_tendency >= 0 {
750+
self.apply_core_strided_c(acc, function)
751+
} else {
752+
self.apply_core_strided_f(acc, function)
753+
}
754+
}
755+
756+
// Non-contiguous but preference for C - unroll over Axis(ndim - 1)
757+
fn apply_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
758+
where
759+
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
760+
P: ZippableTuple<Dim = D>,
761+
{
762+
let n = self.dimension.ndim();
729763
let unroll_axis = n - 1;
730764
let inner_len = self.dimension[unroll_axis];
731765
self.dimension[unroll_axis] = 1;
732766
let mut index_ = self.dimension.first_index();
733767
let inner_strides = self.parts.stride_of(unroll_axis);
768+
// Loop unrolled over closest axis
734769
while let Some(index) = index_ {
735-
// Let's “unroll” the loop over the innermost axis
736770
unsafe {
737771
let ptr = self.parts.uget_ptr(&index);
738772
for i in 0..inner_len {
@@ -747,9 +781,40 @@ where
747781
FoldWhile::Continue(acc)
748782
}
749783

784+
// Non-contiguous but preference for F - unroll over Axis(0)
785+
fn apply_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
786+
where
787+
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
788+
P: ZippableTuple<Dim = D>,
789+
{
790+
let unroll_axis = 0;
791+
let inner_len = self.dimension[unroll_axis];
792+
self.dimension[unroll_axis] = 1;
793+
let index_ = self.dimension.first_index();
794+
let inner_strides = self.parts.stride_of(unroll_axis);
795+
// Loop unrolled over closest axis
796+
if let Some(mut index) = index_ {
797+
loop {
798+
unsafe {
799+
let ptr = self.parts.uget_ptr(&index);
800+
for i in 0..inner_len {
801+
let p = ptr.stride_offset(inner_strides, i);
802+
acc = fold_while!(function(acc, self.parts.as_ref(p)));
803+
}
804+
}
805+
806+
if !self.dimension.next_for_f(&mut index) {
807+
break;
808+
}
809+
}
810+
}
811+
self.dimension[unroll_axis] = inner_len;
812+
FoldWhile::Continue(acc)
813+
}
814+
750815
pub(crate) fn uninitalized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
751816
{
752-
let is_f = !self.layout.is(CORDER) && self.layout.is(FORDER);
817+
let is_f = self.prefer_f();
753818
Array::maybe_uninit(self.dimension.clone().set_f(is_f))
754819
}
755820
}
@@ -995,8 +1060,9 @@ macro_rules! map_impl {
9951060
let ($($p,)*) = self.parts;
9961061
Zip {
9971062
parts: ($($p,)* part, ),
998-
layout: self.layout.and(part_layout),
1063+
layout: self.layout.intersect(part_layout),
9991064
dimension: self.dimension,
1065+
layout_tendency: self.layout_tendency + part_layout.tendency(),
10001066
}
10011067
}
10021068

@@ -1052,11 +1118,13 @@ macro_rules! map_impl {
10521118
dimension: d1,
10531119
layout: self.layout,
10541120
parts: p1,
1121+
layout_tendency: self.layout_tendency,
10551122
},
10561123
Zip {
10571124
dimension: d2,
10581125
layout: self.layout,
10591126
parts: p2,
1127+
layout_tendency: self.layout_tendency,
10601128
})
10611129
}
10621130
}

0 commit comments

Comments
 (0)