@@ -623,7 +623,7 @@ int simplify_iteration_three_strides(const int nd,
623
623
auto str3_p = strides3[p];
624
624
shape_w.push_back (sh_p);
625
625
if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
626
- std::min (std::min ( str1_p, str2_p) , str3_p) < 0 )
626
+ std::min ({ str1_p, str2_p, str3_p} ) < 0 )
627
627
{
628
628
disp1 += str1_p * (sh_p - 1 );
629
629
str1_p = -str1_p;
@@ -716,6 +716,198 @@ contract_iter3(vecT shape, vecT strides1, vecT strides2, vecT strides3)
716
716
out_strides3, disp3);
717
717
}
718
718
719
+ /*
720
+ For purposes of iterating over pairs of elements of four arrays
721
+ with `shape` and strides `strides1`, `strides2`, `strides3`,
722
+ `strides4` given as pointers `simplify_iteration_four_strides(nd,
723
+ shape_ptr, strides1_ptr, strides2_ptr, strides3_ptr, strides4_ptr,
724
+ disp1, disp2, disp3, disp4)` may modify memory and returns new
725
+ length of these arrays.
726
+
727
+ The new shape and new strides, as well as the offset
728
+ `(new_shape, new_strides1, disp1, new_stride2, disp2, new_stride3, disp3,
729
+ new_stride4, disp4)` are such that iterating over them will traverse the
730
+ same set of tuples of elements, possibly in a different order.
731
+ */
732
+ template <class ShapeTy , class StridesTy >
733
+ int simplify_iteration_four_strides (const int nd,
734
+ ShapeTy *shape,
735
+ StridesTy *strides1,
736
+ StridesTy *strides2,
737
+ StridesTy *strides3,
738
+ StridesTy *strides4,
739
+ StridesTy &disp1,
740
+ StridesTy &disp2,
741
+ StridesTy &disp3,
742
+ StridesTy &disp4)
743
+ {
744
+ disp1 = std::ptrdiff_t (0 );
745
+ disp2 = std::ptrdiff_t (0 );
746
+ if (nd < 2 )
747
+ return nd;
748
+
749
+ std::vector<int > pos (nd);
750
+ std::iota (pos.begin (), pos.end (), 0 );
751
+
752
+ std::stable_sort (
753
+ pos.begin (), pos.end (),
754
+ [&strides1, &strides2, &strides3, &strides4, &shape](int i1, int i2) {
755
+ auto abs_str1_i1 =
756
+ (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
757
+ auto abs_str1_i2 =
758
+ (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
759
+ auto abs_str2_i1 =
760
+ (strides2[i1] < 0 ) ? -strides2[i1] : strides2[i1];
761
+ auto abs_str2_i2 =
762
+ (strides2[i2] < 0 ) ? -strides2[i2] : strides2[i2];
763
+ auto abs_str3_i1 =
764
+ (strides3[i1] < 0 ) ? -strides3[i1] : strides3[i1];
765
+ auto abs_str3_i2 =
766
+ (strides3[i2] < 0 ) ? -strides3[i2] : strides3[i2];
767
+ auto abs_str4_i1 =
768
+ (strides4[i1] < 0 ) ? -strides4[i1] : strides4[i1];
769
+ auto abs_str4_i2 =
770
+ (strides4[i2] < 0 ) ? -strides4[i2] : strides4[i2];
771
+ return (abs_str1_i1 > abs_str1_i2) ||
772
+ ((abs_str1_i1 == abs_str1_i2) &&
773
+ ((abs_str2_i1 > abs_str2_i2) ||
774
+ ((abs_str2_i1 == abs_str2_i2) &&
775
+ ((abs_str3_i1 > abs_str3_i2) ||
776
+ ((abs_str3_i1 == abs_str3_i2) &&
777
+ ((abs_str4_i1 > abs_str4_i2) ||
778
+ ((abs_str4_i1 == abs_str4_i2) &&
779
+ (shape[i1] > shape[i2]))))))));
780
+ });
781
+
782
+ std::vector<ShapeTy> shape_w;
783
+ std::vector<StridesTy> strides1_w;
784
+ std::vector<StridesTy> strides2_w;
785
+ std::vector<StridesTy> strides3_w;
786
+ std::vector<StridesTy> strides4_w;
787
+
788
+ bool contractable = true ;
789
+ for (int i = 0 ; i < nd; ++i) {
790
+ auto p = pos[i];
791
+ auto sh_p = shape[p];
792
+ auto str1_p = strides1[p];
793
+ auto str2_p = strides2[p];
794
+ auto str3_p = strides3[p];
795
+ auto str4_p = strides4[p];
796
+ shape_w.push_back (sh_p);
797
+ if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 && str4_p <= 0 &&
798
+ std::min ({str1_p, str2_p, str3_p, str4_p}) < 0 )
799
+ {
800
+ disp1 += str1_p * (sh_p - 1 );
801
+ str1_p = -str1_p;
802
+ disp2 += str2_p * (sh_p - 1 );
803
+ str2_p = -str2_p;
804
+ disp3 += str3_p * (sh_p - 1 );
805
+ str3_p = -str3_p;
806
+ disp4 += str4_p * (sh_p - 1 );
807
+ str4_p = -str4_p;
808
+ }
809
+ if (str1_p < 0 || str2_p < 0 || str3_p < 0 || str4_p < 0 ) {
810
+ contractable = false ;
811
+ }
812
+ strides1_w.push_back (str1_p);
813
+ strides2_w.push_back (str2_p);
814
+ strides3_w.push_back (str3_p);
815
+ strides4_w.push_back (str4_p);
816
+ }
817
+ int nd_ = nd;
818
+ while (contractable) {
819
+ bool changed = false ;
820
+ for (int i = 0 ; i + 1 < nd_; ++i) {
821
+ StridesTy str1 = strides1_w[i + 1 ];
822
+ StridesTy str2 = strides2_w[i + 1 ];
823
+ StridesTy str3 = strides3_w[i + 1 ];
824
+ StridesTy str4 = strides4_w[i + 1 ];
825
+ StridesTy jump1 = strides1_w[i] - (shape_w[i + 1 ] - 1 ) * str1;
826
+ StridesTy jump2 = strides2_w[i] - (shape_w[i + 1 ] - 1 ) * str2;
827
+ StridesTy jump3 = strides3_w[i] - (shape_w[i + 1 ] - 1 ) * str3;
828
+ StridesTy jump4 = strides4_w[i] - (shape_w[i + 1 ] - 1 ) * str4;
829
+
830
+ if (jump1 == str1 && jump2 == str2 && jump3 == str3 &&
831
+ jump4 == str4) {
832
+ changed = true ;
833
+ shape_w[i] *= shape_w[i + 1 ];
834
+ for (int j = i; j < nd_; ++j) {
835
+ strides1_w[j] = strides1_w[j + 1 ];
836
+ }
837
+ for (int j = i; j < nd_; ++j) {
838
+ strides2_w[j] = strides2_w[j + 1 ];
839
+ }
840
+ for (int j = i; j < nd_; ++j) {
841
+ strides3_w[j] = strides3_w[j + 1 ];
842
+ }
843
+ for (int j = i; j < nd_; ++j) {
844
+ strides4_w[j] = strides4_w[j + 1 ];
845
+ }
846
+ for (int j = i + 1 ; j + 1 < nd_; ++j) {
847
+ shape_w[j] = shape_w[j + 1 ];
848
+ }
849
+ --nd_;
850
+ break ;
851
+ }
852
+ }
853
+ if (!changed)
854
+ break ;
855
+ }
856
+ for (int i = 0 ; i < nd_; ++i) {
857
+ shape[i] = shape_w[i];
858
+ }
859
+ for (int i = 0 ; i < nd_; ++i) {
860
+ strides1[i] = strides1_w[i];
861
+ }
862
+ for (int i = 0 ; i < nd_; ++i) {
863
+ strides2[i] = strides2_w[i];
864
+ }
865
+ for (int i = 0 ; i < nd_; ++i) {
866
+ strides3[i] = strides3_w[i];
867
+ }
868
+ for (int i = 0 ; i < nd_; ++i) {
869
+ strides4[i] = strides4_w[i];
870
+ }
871
+
872
+ return nd_;
873
+ }
874
+
875
+ template <typename T, class Error , typename vecT = std::vector<T>>
876
+ std::tuple<vecT, vecT, T, vecT, T, vecT, T, vecT, T>
877
+ contract_iter4 (vecT shape,
878
+ vecT strides1,
879
+ vecT strides2,
880
+ vecT strides3,
881
+ vecT strides4)
882
+ {
883
+ const size_t dim = shape.size ();
884
+ if (dim != strides1.size () || dim != strides2.size () ||
885
+ dim != strides3.size () || dim != strides4.size ())
886
+ {
887
+ throw Error (" Shape and strides must be of equal size." );
888
+ }
889
+ vecT out_shape = shape;
890
+ vecT out_strides1 = strides1;
891
+ vecT out_strides2 = strides2;
892
+ vecT out_strides3 = strides3;
893
+ vecT out_strides4 = strides4;
894
+ T disp1 (0 );
895
+ T disp2 (0 );
896
+ T disp3 (0 );
897
+ T disp4 (0 );
898
+
899
+ int nd = simplify_iteration_four_strides (
900
+ dim, out_shape.data (), out_strides1.data (), out_strides2.data (),
901
+ out_strides3.data (), out_strides4.data (), disp1, disp2, disp3, disp4);
902
+ out_shape.resize (nd);
903
+ out_strides1.resize (nd);
904
+ out_strides2.resize (nd);
905
+ out_strides3.resize (nd);
906
+ out_strides4.resize (nd);
907
+ return std::make_tuple (out_shape, out_strides1, disp1, out_strides2, disp2,
908
+ out_strides3, disp3, out_strides4, disp4);
909
+ }
910
+
719
911
} // namespace strides
720
912
} // namespace tensor
721
913
} // namespace dpctl
0 commit comments