@@ -53,17 +53,26 @@ where
53
53
D : Dimension ,
54
54
{
55
55
pub ( crate ) fn layout_impl ( & self ) -> Layout {
56
- Layout :: new ( if self . is_standard_layout ( ) {
57
- if self . ndim ( ) <= 1 {
58
- FORDER | CORDER
56
+ let n = self . ndim ( ) ;
57
+ if self . is_standard_layout ( ) {
58
+ if n <= 1 {
59
+ Layout :: one_dimensional ( )
59
60
} else {
60
- CORDER
61
+ Layout :: c ( )
62
+ }
63
+ } else if n > 1 && self . raw_view ( ) . reversed_axes ( ) . is_standard_layout ( ) {
64
+ Layout :: f ( )
65
+ } else if n > 1 {
66
+ if self . stride_of ( Axis ( 0 ) ) == 1 {
67
+ Layout :: fpref ( )
68
+ } else if self . stride_of ( Axis ( n - 1 ) ) == 1 {
69
+ Layout :: cpref ( )
70
+ } else {
71
+ Layout :: none ( )
61
72
}
62
- } else if self . ndim ( ) > 1 && self . raw_view ( ) . reversed_axes ( ) . is_standard_layout ( ) {
63
- FORDER
64
73
} else {
65
- 0
66
- } )
74
+ Layout :: none ( )
75
+ }
67
76
}
68
77
}
69
78
@@ -587,6 +596,9 @@ pub struct Zip<Parts, D> {
587
596
parts : Parts ,
588
597
dimension : D ,
589
598
layout : Layout ,
599
+ /// The sum of the layout tendencies of the parts;
600
+ /// positive for c- and negative for f-layout preference.
601
+ layout_tendency : i32 ,
590
602
}
591
603
592
604
@@ -605,10 +617,12 @@ where
605
617
{
606
618
let array = p. into_producer ( ) ;
607
619
let dim = array. raw_dim ( ) ;
620
+ let layout = array. layout ( ) ;
608
621
Zip {
609
622
dimension : dim,
610
- layout : array . layout ( ) ,
623
+ layout,
611
624
parts : ( array, ) ,
625
+ layout_tendency : layout. tendency ( ) ,
612
626
}
613
627
}
614
628
}
@@ -661,24 +675,29 @@ where
661
675
self . dimension [ axis. index ( ) ]
662
676
}
663
677
678
+ fn prefer_f ( & self ) -> bool {
679
+ !self . layout . is ( CORDER ) && ( self . layout . is ( FORDER ) || self . layout_tendency < 0 )
680
+ }
681
+
664
682
/// Return an *approximation* to the max stride axis; if
665
683
/// component arrays disagree, there may be no choice better than the
666
684
/// others.
667
685
fn max_stride_axis ( & self ) -> Axis {
668
- let i = match self . layout . flag ( ) {
669
- FORDER => self
686
+ let i = if self . prefer_f ( ) {
687
+ self
670
688
. dimension
671
689
. slice ( )
672
690
. iter ( )
673
691
. rposition ( |& len| len > 1 )
674
- . unwrap_or ( self . dimension . ndim ( ) - 1 ) ,
692
+ . unwrap_or ( self . dimension . ndim ( ) - 1 )
693
+ } else {
675
694
/* corder or default */
676
- _ => self
695
+ self
677
696
. dimension
678
697
. slice ( )
679
698
. iter ( )
680
699
. position ( |& len| len > 1 )
681
- . unwrap_or ( 0 ) ,
700
+ . unwrap_or ( 0 )
682
701
} ;
683
702
Axis ( i)
684
703
}
@@ -699,6 +718,7 @@ where
699
718
self . apply_core_strided ( acc, function)
700
719
}
701
720
}
721
+
702
722
fn apply_core_contiguous < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
703
723
where
704
724
F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
@@ -717,7 +737,7 @@ where
717
737
FoldWhile :: Continue ( acc)
718
738
}
719
739
720
- fn apply_core_strided < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
740
+ fn apply_core_strided < F , Acc > ( & mut self , acc : Acc , function : F ) -> FoldWhile < Acc >
721
741
where
722
742
F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
723
743
P : ZippableTuple < Dim = D > ,
@@ -726,13 +746,27 @@ where
726
746
if n == 0 {
727
747
panic ! ( "Unreachable: ndim == 0 is contiguous" )
728
748
}
749
+ if n == 1 || self . layout_tendency >= 0 {
750
+ self . apply_core_strided_c ( acc, function)
751
+ } else {
752
+ self . apply_core_strided_f ( acc, function)
753
+ }
754
+ }
755
+
756
+ // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
757
+ fn apply_core_strided_c < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
758
+ where
759
+ F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
760
+ P : ZippableTuple < Dim = D > ,
761
+ {
762
+ let n = self . dimension . ndim ( ) ;
729
763
let unroll_axis = n - 1 ;
730
764
let inner_len = self . dimension [ unroll_axis] ;
731
765
self . dimension [ unroll_axis] = 1 ;
732
766
let mut index_ = self . dimension . first_index ( ) ;
733
767
let inner_strides = self . parts . stride_of ( unroll_axis) ;
768
+ // Loop unrolled over closest axis
734
769
while let Some ( index) = index_ {
735
- // Let's “unroll” the loop over the innermost axis
736
770
unsafe {
737
771
let ptr = self . parts . uget_ptr ( & index) ;
738
772
for i in 0 ..inner_len {
@@ -747,9 +781,40 @@ where
747
781
FoldWhile :: Continue ( acc)
748
782
}
749
783
784
+ // Non-contiguous but preference for F - unroll over Axis(0)
785
+ fn apply_core_strided_f < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
786
+ where
787
+ F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
788
+ P : ZippableTuple < Dim = D > ,
789
+ {
790
+ let unroll_axis = 0 ;
791
+ let inner_len = self . dimension [ unroll_axis] ;
792
+ self . dimension [ unroll_axis] = 1 ;
793
+ let index_ = self . dimension . first_index ( ) ;
794
+ let inner_strides = self . parts . stride_of ( unroll_axis) ;
795
+ // Loop unrolled over closest axis
796
+ if let Some ( mut index) = index_ {
797
+ loop {
798
+ unsafe {
799
+ let ptr = self . parts . uget_ptr ( & index) ;
800
+ for i in 0 ..inner_len {
801
+ let p = ptr. stride_offset ( inner_strides, i) ;
802
+ acc = fold_while ! ( function( acc, self . parts. as_ref( p) ) ) ;
803
+ }
804
+ }
805
+
806
+ if !self . dimension . next_for_f ( & mut index) {
807
+ break ;
808
+ }
809
+ }
810
+ }
811
+ self . dimension [ unroll_axis] = inner_len;
812
+ FoldWhile :: Continue ( acc)
813
+ }
814
+
750
815
pub ( crate ) fn uninitalized_for_current_layout < T > ( & self ) -> Array < MaybeUninit < T > , D >
751
816
{
752
- let is_f = ! self . layout . is ( CORDER ) && self . layout . is ( FORDER ) ;
817
+ let is_f = self . prefer_f ( ) ;
753
818
Array :: maybe_uninit ( self . dimension . clone ( ) . set_f ( is_f) )
754
819
}
755
820
}
@@ -995,8 +1060,9 @@ macro_rules! map_impl {
995
1060
let ( $( $p, ) * ) = self . parts;
996
1061
Zip {
997
1062
parts: ( $( $p, ) * part, ) ,
998
- layout: self . layout. and ( part_layout) ,
1063
+ layout: self . layout. intersect ( part_layout) ,
999
1064
dimension: self . dimension,
1065
+ layout_tendency: self . layout_tendency + part_layout. tendency( ) ,
1000
1066
}
1001
1067
}
1002
1068
@@ -1052,11 +1118,13 @@ macro_rules! map_impl {
1052
1118
dimension: d1,
1053
1119
layout: self . layout,
1054
1120
parts: p1,
1121
+ layout_tendency: self . layout_tendency,
1055
1122
} ,
1056
1123
Zip {
1057
1124
dimension: d2,
1058
1125
layout: self . layout,
1059
1126
parts: p2,
1127
+ layout_tendency: self . layout_tendency,
1060
1128
} )
1061
1129
}
1062
1130
}
0 commit comments