- 
                Notifications
    You must be signed in to change notification settings 
- Fork 9
Introduce a helper to make a reference for transposed view #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce a helper to make a reference for transposed view #349
Conversation
        
          
                common/unit_test/Test_Transpose.cpp
              
                Outdated
          
        
      | for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | ||
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | ||
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | ||
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | ||
| for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) { | ||
| for (std::size_t i7 = 0; i7 < h_x.extent(7); i7++) { | ||
| std::array<std::size_t, 8> dst_indices{i0, i1, i2, i3, | ||
| i4, i5, i6, i7}; | ||
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | ||
| dst_i1 = dst_indices.at(map.at(1)), | ||
| dst_i2 = dst_indices.at(map.at(2)), | ||
| dst_i3 = dst_indices.at(map.at(3)), | ||
| dst_i4 = dst_indices.at(map.at(4)), | ||
| dst_i5 = dst_indices.at(map.at(5)), | ||
| dst_i6 = dst_indices.at(map.at(6)), | ||
| dst_i7 = dst_indices.at(map.at(7)); | ||
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | ||
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | ||
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) && | ||
| dst_i6 < h_xT.extent(6) && dst_i7 < h_xT.extent(7)) { | ||
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5, | ||
| dst_i6, dst_i7) = | ||
| h_x(i0, i1, i2, i3, i4, i5, i6, i7); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | |
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | |
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | |
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | |
| for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) { | |
| for (std::size_t i7 = 0; i7 < h_x.extent(7); i7++) { | |
| std::array<std::size_t, 8> dst_indices{i0, i1, i2, i3, | |
| i4, i5, i6, i7}; | |
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | |
| dst_i1 = dst_indices.at(map.at(1)), | |
| dst_i2 = dst_indices.at(map.at(2)), | |
| dst_i3 = dst_indices.at(map.at(3)), | |
| dst_i4 = dst_indices.at(map.at(4)), | |
| dst_i5 = dst_indices.at(map.at(5)), | |
| dst_i6 = dst_indices.at(map.at(6)), | |
| dst_i7 = dst_indices.at(map.at(7)); | |
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | |
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | |
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) && | |
| dst_i6 < h_xT.extent(6) && dst_i7 < h_xT.extent(7)) { | |
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5, | |
| dst_i6, dst_i7) = | |
| h_x(i0, i1, i2, i3, i4, i5, i6, i7); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| std::array<std::size_t, ViewType1::rank> idx; | |
| std::array<std::size_t, ViewType1::rank> dst; | |
| for (idx[0] = 0; idx[0] < h_x.extent(0); idx[0]++) { | |
| for (idx[1] = 0; idx[1] < h_x.extent(1); idx[1]++) { | |
| for (idx[2] = 0; idx[2] < h_x.extent(2); idx[2]++) { | |
| for (idx[3] = 0; idx[3] < h_x.extent(3); idx[3]++) { | |
| for (idx[4] = 0; idx[4] < h_x.extent(4); idx[4]++) { | |
| for (idx[5] = 0; idx[5] < h_x.extent(5); idx[5]++) { | |
| for (idx[6] = 0; idx[6] < h_x.extent(6); idx[6]++) { | |
| for (idx[7] = 0; idx[7] < h_x.extent(7); idx[7]++) { | |
| bool in_bound = true; | |
| for(std::size_t i = 0; i < ViewType1::rank; ++i) { | |
| dst[i] = idx.at(map.at(i)); | |
| in_bound &= dst[i] < hxT.extent(i); | |
| } | |
| if (in_bound) { | |
| h_xT(dst[0], dst[1], dst[2], dst[3], dst[4], dst[5], dst[6], dst[7]) = | |
| h_x(idx[0], idx[1], idx[2], idx[3], idx[4], idx[5], idx[6], idx[7]); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | 
This wasn't tested but it should work.
I wonder if there is a way of avoiding all the constexpr if(ViewType1::rank == 8) and have a single loop for all dimension?
Maybe by writing a recursive template?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comments.
I do not think it will work. The number of indices in accessors must be identical to the rank of the View.
I wonder if there is a way of avoiding all the constexpr if(ViewType1::rank == 8) and have a single loop for all dimension?
Maybe by writing a recursive template?
It is doable. We have already achieved this with https://github.com/kokkos/kokkos-fft/blob/cf844f58c320442b7d28ada26d42266839d11ec9/common/src/KokkosFFT_transpose.hpp#L198-#L239 and https://github.com/yasahi-hpc/distributed-FFT-for-kokkos/blob/57f35d25bc2efc4994f33a4042294261b6a68dba/distributed/src/KokkosFFT_Distributed_Utils.hpp#L550-#599
However, I would like to make this reference generator as simple as possible. In addition, I would like to make it in a different way than I did in the helper functions internally.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or maybe you can always use the 8 rank loop along with the access method that let you use 8 args to access a View however many dimensions it has.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have never tried access method.
So, something like this
    for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) {
      for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) {
        for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) {
          for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) {
            for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) {
              for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) {
                for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) {
                  for (std::size_t i7 = 0; i7 < h_x.extent(7); i7++) {
                    std::array<std::size_t, 8> src{i0, i1, i2, i3, i4, i5, i6, i7};
                    std::array<std::size_t, 8> dst = src;
                    bool in_bound = true;
                    for(std::size_t i = 0; i < ViewType1::rank; ++i) {
                      dst.at(i) = src.at(map.at(i));
                      in_bound &= dst.at(i) < hxT.extent(i);
                    }
                    if (in_bound) {
                      h_xT.access(dst[0], dst[1], dst[2], dst[3], dst[4], dst[5], dst[6], dst[7]) =
                          h_x.access(src[0], src[1], src[2], src[3], src[4], src[5], src[6], src[7]);
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had never used it before either, a shame there isn't an overload taking std::array, it would make the call much cleaner.
I wouldn't define the arrays inside the inner loop to avoid paying the construction and initialization cost at every step.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had never used it before either, a shame there isn't an overload taking std::array, it would make the call much cleaner.
We can consider a overload taking Kokkos::Array
I wouldn't define the arrays inside the inner loop to avoid paying the construction and initialization cost at every step.
It is a small test function. These costs are negligible
| template <typename ViewType1, typename ViewType2, std::size_t DIM> | ||
| void make_transposed(const ViewType1& x, const ViewType2& xT, | ||
| const KokkosFFT::axis_type<DIM>& map) { | ||
| static_assert(ViewType1::rank() == DIM && ViewType2::rank() == DIM, | ||
| "make_transposed: Rank of Views must be equal to Rank of " | ||
| "transpose axes."); | ||
| auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x); | ||
| auto h_xT = Kokkos::create_mirror_view(xT); | ||
|  | ||
| if constexpr (ViewType1::rank == 2) { | ||
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| h_xT(i1, i0) = h_x(i0, i1); | ||
| } | ||
| } | ||
| } else if constexpr (ViewType1::rank == 3) { | ||
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | ||
| std::array<std::size_t, 3> dst_indices{i0, i1, i2}; | ||
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | ||
| dst_i1 = dst_indices.at(map.at(1)), | ||
| dst_i2 = dst_indices.at(map.at(2)); | ||
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | ||
| dst_i2 < h_xT.extent(2)) { | ||
| h_xT(dst_i0, dst_i1, dst_i2) = h_x(i0, i1, i2); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } else if constexpr (ViewType1::rank == 4) { | ||
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | ||
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | ||
| std::array<std::size_t, 4> dst_indices{i0, i1, i2, i3}; | ||
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | ||
| dst_i1 = dst_indices.at(map.at(1)), | ||
| dst_i2 = dst_indices.at(map.at(2)), | ||
| dst_i3 = dst_indices.at(map.at(3)); | ||
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | ||
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3)) { | ||
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3) = h_x(i0, i1, i2, i3); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } else if constexpr (ViewType1::rank == 5) { | ||
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | ||
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | ||
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | ||
| std::array<std::size_t, 5> dst_indices{i0, i1, i2, i3, i4}; | ||
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | ||
| dst_i1 = dst_indices.at(map.at(1)), | ||
| dst_i2 = dst_indices.at(map.at(2)), | ||
| dst_i3 = dst_indices.at(map.at(3)), | ||
| dst_i4 = dst_indices.at(map.at(4)); | ||
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | ||
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | ||
| dst_i4 < h_xT.extent(4)) { | ||
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4) = | ||
| h_x(i0, i1, i2, i3, i4); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } else if constexpr (ViewType1::rank == 6) { | ||
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | ||
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | ||
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | ||
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | ||
| std::array<std::size_t, 6> dst_indices{i0, i1, i2, i3, i4, i5}; | ||
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | ||
| dst_i1 = dst_indices.at(map.at(1)), | ||
| dst_i2 = dst_indices.at(map.at(2)), | ||
| dst_i3 = dst_indices.at(map.at(3)), | ||
| dst_i4 = dst_indices.at(map.at(4)), | ||
| dst_i5 = dst_indices.at(map.at(5)); | ||
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | ||
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | ||
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5)) { | ||
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5) = | ||
| h_x(i0, i1, i2, i3, i4, i5); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } else if constexpr (ViewType1::rank == 7) { | ||
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | ||
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | ||
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | ||
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | ||
| for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) { | ||
| std::array<std::size_t, 7> dst_indices{i0, i1, i2, i3, | ||
| i4, i5, i6}; | ||
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | ||
| dst_i1 = dst_indices.at(map.at(1)), | ||
| dst_i2 = dst_indices.at(map.at(2)), | ||
| dst_i3 = dst_indices.at(map.at(3)), | ||
| dst_i4 = dst_indices.at(map.at(4)), | ||
| dst_i5 = dst_indices.at(map.at(5)), | ||
| dst_i6 = dst_indices.at(map.at(6)); | ||
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | ||
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | ||
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) && | ||
| dst_i6 < h_xT.extent(6)) { | ||
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5, | ||
| dst_i6) = h_x(i0, i1, i2, i3, i4, i5, i6); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } else if constexpr (ViewType1::rank == 8) { | ||
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | ||
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | ||
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | ||
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | ||
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | ||
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | ||
| for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) { | ||
| for (std::size_t i7 = 0; i7 < h_x.extent(7); i7++) { | ||
| std::array<std::size_t, 8> dst_indices{i0, i1, i2, i3, | ||
| i4, i5, i6, i7}; | ||
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | ||
| dst_i1 = dst_indices.at(map.at(1)), | ||
| dst_i2 = dst_indices.at(map.at(2)), | ||
| dst_i3 = dst_indices.at(map.at(3)), | ||
| dst_i4 = dst_indices.at(map.at(4)), | ||
| dst_i5 = dst_indices.at(map.at(5)), | ||
| dst_i6 = dst_indices.at(map.at(6)), | ||
| dst_i7 = dst_indices.at(map.at(7)); | ||
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | ||
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | ||
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) && | ||
| dst_i6 < h_xT.extent(6) && dst_i7 < h_xT.extent(7)) { | ||
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5, | ||
| dst_i6, dst_i7) = | ||
| h_x(i0, i1, i2, i3, i4, i5, i6, i7); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Kokkos::deep_copy(xT, h_xT); | ||
| } | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| template <typename ViewType1, typename ViewType2, std::size_t DIM> | |
| void make_transposed(const ViewType1& x, const ViewType2& xT, | |
| const KokkosFFT::axis_type<DIM>& map) { | |
| static_assert(ViewType1::rank() == DIM && ViewType2::rank() == DIM, | |
| "make_transposed: Rank of Views must be equal to Rank of " | |
| "transpose axes."); | |
| auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x); | |
| auto h_xT = Kokkos::create_mirror_view(xT); | |
| if constexpr (ViewType1::rank == 2) { | |
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| h_xT(i1, i0) = h_x(i0, i1); | |
| } | |
| } | |
| } else if constexpr (ViewType1::rank == 3) { | |
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | |
| std::array<std::size_t, 3> dst_indices{i0, i1, i2}; | |
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | |
| dst_i1 = dst_indices.at(map.at(1)), | |
| dst_i2 = dst_indices.at(map.at(2)); | |
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | |
| dst_i2 < h_xT.extent(2)) { | |
| h_xT(dst_i0, dst_i1, dst_i2) = h_x(i0, i1, i2); | |
| } | |
| } | |
| } | |
| } | |
| } else if constexpr (ViewType1::rank == 4) { | |
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | |
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | |
| std::array<std::size_t, 4> dst_indices{i0, i1, i2, i3}; | |
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | |
| dst_i1 = dst_indices.at(map.at(1)), | |
| dst_i2 = dst_indices.at(map.at(2)), | |
| dst_i3 = dst_indices.at(map.at(3)); | |
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | |
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3)) { | |
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3) = h_x(i0, i1, i2, i3); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } else if constexpr (ViewType1::rank == 5) { | |
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | |
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | |
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | |
| std::array<std::size_t, 5> dst_indices{i0, i1, i2, i3, i4}; | |
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | |
| dst_i1 = dst_indices.at(map.at(1)), | |
| dst_i2 = dst_indices.at(map.at(2)), | |
| dst_i3 = dst_indices.at(map.at(3)), | |
| dst_i4 = dst_indices.at(map.at(4)); | |
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | |
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | |
| dst_i4 < h_xT.extent(4)) { | |
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4) = | |
| h_x(i0, i1, i2, i3, i4); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } else if constexpr (ViewType1::rank == 6) { | |
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | |
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | |
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | |
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | |
| std::array<std::size_t, 6> dst_indices{i0, i1, i2, i3, i4, i5}; | |
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | |
| dst_i1 = dst_indices.at(map.at(1)), | |
| dst_i2 = dst_indices.at(map.at(2)), | |
| dst_i3 = dst_indices.at(map.at(3)), | |
| dst_i4 = dst_indices.at(map.at(4)), | |
| dst_i5 = dst_indices.at(map.at(5)); | |
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | |
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | |
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5)) { | |
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5) = | |
| h_x(i0, i1, i2, i3, i4, i5); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } else if constexpr (ViewType1::rank == 7) { | |
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | |
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | |
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | |
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | |
| for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) { | |
| std::array<std::size_t, 7> dst_indices{i0, i1, i2, i3, | |
| i4, i5, i6}; | |
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | |
| dst_i1 = dst_indices.at(map.at(1)), | |
| dst_i2 = dst_indices.at(map.at(2)), | |
| dst_i3 = dst_indices.at(map.at(3)), | |
| dst_i4 = dst_indices.at(map.at(4)), | |
| dst_i5 = dst_indices.at(map.at(5)), | |
| dst_i6 = dst_indices.at(map.at(6)); | |
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | |
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | |
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) && | |
| dst_i6 < h_xT.extent(6)) { | |
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5, | |
| dst_i6) = h_x(i0, i1, i2, i3, i4, i5, i6); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } else if constexpr (ViewType1::rank == 8) { | |
| for (std::size_t i0 = 0; i0 < h_x.extent(0); i0++) { | |
| for (std::size_t i1 = 0; i1 < h_x.extent(1); i1++) { | |
| for (std::size_t i2 = 0; i2 < h_x.extent(2); i2++) { | |
| for (std::size_t i3 = 0; i3 < h_x.extent(3); i3++) { | |
| for (std::size_t i4 = 0; i4 < h_x.extent(4); i4++) { | |
| for (std::size_t i5 = 0; i5 < h_x.extent(5); i5++) { | |
| for (std::size_t i6 = 0; i6 < h_x.extent(6); i6++) { | |
| for (std::size_t i7 = 0; i7 < h_x.extent(7); i7++) { | |
| std::array<std::size_t, 8> dst_indices{i0, i1, i2, i3, | |
| i4, i5, i6, i7}; | |
| std::size_t dst_i0 = dst_indices.at(map.at(0)), | |
| dst_i1 = dst_indices.at(map.at(1)), | |
| dst_i2 = dst_indices.at(map.at(2)), | |
| dst_i3 = dst_indices.at(map.at(3)), | |
| dst_i4 = dst_indices.at(map.at(4)), | |
| dst_i5 = dst_indices.at(map.at(5)), | |
| dst_i6 = dst_indices.at(map.at(6)), | |
| dst_i7 = dst_indices.at(map.at(7)); | |
| if (dst_i0 < h_xT.extent(0) && dst_i1 < h_xT.extent(1) && | |
| dst_i2 < h_xT.extent(2) && dst_i3 < h_xT.extent(3) && | |
| dst_i4 < h_xT.extent(4) && dst_i5 < h_xT.extent(5) && | |
| dst_i6 < h_xT.extent(6) && dst_i7 < h_xT.extent(7)) { | |
| h_xT(dst_i0, dst_i1, dst_i2, dst_i3, dst_i4, dst_i5, | |
| dst_i6, dst_i7) = | |
| h_x(i0, i1, i2, i3, i4, i5, i6, i7); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| Kokkos::deep_copy(xT, h_xT); | |
| } | |
| template <typename ViewType1, typename ViewType2, std::size_t DIM> | |
| void make_transposed(const ViewType1& x, const ViewType2& xT, | |
| const KokkosFFT::axis_type<DIM>& map) { | |
| static_assert(ViewType1::rank() == DIM && ViewType2::rank() == DIM, | |
| "make_transposed: Rank of Views must be equal to Rank of " | |
| "transpose axes."); | |
| auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, x); | |
| auto h_xT = Kokkos::create_mirror_view(xT); | |
| std::array<std::size_t, 8> idx{0}; | |
| std::array<std::size_t, 8> dst{0}; | |
| for (idx[0] = 0; idx[0] < h_x.extent(0); idx[0]++) { | |
| for (idx[1] = 0; idx[1] < h_x.extent(1); idx[1]++) { | |
| for (idx[2] = 0; idx[2] < h_x.extent(2); idx[2]++) { | |
| for (idx[3] = 0; idx[3] < h_x.extent(3); idx[3]++) { | |
| for (idx[4] = 0; idx[4] < h_x.extent(4); idx[4]++) { | |
| for (idx[5] = 0; idx[5] < h_x.extent(5); idx[5]++) { | |
| for (idx[6] = 0; idx[6] < h_x.extent(6); idx[6]++) { | |
| for (idx[7] = 0; idx[7] < h_x.extent(7); idx[7]++) { | |
| bool in_bound = true; | |
| for(std::size_t i = 0; i < ViewType1::rank; ++i) { | |
| dst[i] = idx.at(map.at(i)); | |
| in_bound &= dst[i] < hxT.extent(i); | |
| } | |
| if (in_bound) { | |
| // if i > ViewType1::rank: | |
| // - dst[i] is 0 since we haven't touched it in the previous loop | |
| // - idx[i] is also 0 because h_x.extent(i) is 1 | |
| // => We respect `access` constraints. | |
| h_xT.access(dst[0], dst[1], dst[2], dst[3], dst[4], dst[5], dst[6], dst[7]) = | |
| h_x.access(idx[0], idx[1], idx[2], idx[3], idx[4], idx[5], idx[6], idx[7]); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| Kokkos::deep_copy(xT, h_xT); | |
| } | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good.
I do not like to manipulate idx inside the for loop update.
| Thank you for the review. | 
This PR aims at cleaning up the unit-tests of transpose helper.
As a preparation to #345
InViewandOutVieware different forTranspose3Don 4D View