|
| 1 | +// SPDX-FileCopyrightText: (C) The kokkos-fft development team, see COPYRIGHT.md file |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception |
| 4 | + |
| 5 | +#include <gtest/gtest.h> |
| 6 | +#include <Kokkos_Core.hpp> |
| 7 | +#include "KokkosFFT_Distributed_All2All.hpp" |
| 8 | + |
| 9 | +namespace { |
| 10 | +using execution_space = Kokkos::DefaultExecutionSpace; |
| 11 | +using test_types = ::testing::Types<std::pair<float, Kokkos::LayoutLeft>, |
| 12 | + std::pair<float, Kokkos::LayoutRight>, |
| 13 | + std::pair<double, Kokkos::LayoutLeft>, |
| 14 | + std::pair<double, Kokkos::LayoutRight>>; |
| 15 | + |
| 16 | +// Basically the same fixtures, used for labeling tests |
| 17 | +template <typename T> |
| 18 | +struct TestAll2All : public ::testing::Test { |
| 19 | + using float_type = typename T::first_type; |
| 20 | + using layout_type = typename T::second_type; |
| 21 | + |
| 22 | + int m_rank = 0; |
| 23 | + int m_nprocs = 1; |
| 24 | + |
| 25 | + virtual void SetUp() { |
| 26 | + ::MPI_Comm_rank(MPI_COMM_WORLD, &m_rank); |
| 27 | + ::MPI_Comm_size(MPI_COMM_WORLD, &m_nprocs); |
| 28 | + } |
| 29 | +}; |
| 30 | + |
| 31 | +template <typename T, typename LayoutType> |
| 32 | +void test_all2all_view2D(int rank, int nprocs) { |
| 33 | + using View3DType = Kokkos::View<T***, LayoutType, execution_space>; |
| 34 | + |
| 35 | + const std::size_t n0 = 16, n1 = 15; |
| 36 | + const std::size_t n0_local = ((n0 - 1) / nprocs) + 1; |
| 37 | + const std::size_t n1_local = ((n1 - 1) / nprocs) + 1; |
| 38 | + |
| 39 | + int n0_buffer = 0, n1_buffer = 0, n2_buffer = 0; |
| 40 | + if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft>) { |
| 41 | + n0_buffer = n0_local; |
| 42 | + n1_buffer = n1_local; |
| 43 | + n2_buffer = nprocs; |
| 44 | + } else { |
| 45 | + n0_buffer = nprocs; |
| 46 | + n1_buffer = n0_local; |
| 47 | + n2_buffer = n1_local; |
| 48 | + } |
| 49 | + |
| 50 | + View3DType send("send", n0_buffer, n1_buffer, n2_buffer), |
| 51 | + recv("recv", n0_buffer, n1_buffer, n2_buffer), |
| 52 | + ref("ref", n0_buffer, n1_buffer, n2_buffer); |
| 53 | + |
| 54 | + auto h_send = Kokkos::create_mirror_view(send); |
| 55 | + auto h_ref = Kokkos::create_mirror_view(ref); |
| 56 | + |
| 57 | + for (std::size_t i2 = 0; i2 < send.extent(2); i2++) { |
| 58 | + for (std::size_t i1 = 0; i1 < send.extent(1); i1++) { |
| 59 | + for (std::size_t i0 = 0; i0 < send.extent(0); i0++) { |
| 60 | + if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft>) { |
| 61 | + T value = |
| 62 | + static_cast<T>(rank * send.size() + i0 + i1 * send.extent(0) + |
| 63 | + i2 * send.extent(0) * send.extent(1)); |
| 64 | + T value_T = |
| 65 | + static_cast<T>(i2 * send.size() + i0 + i1 * send.extent(0) + |
| 66 | + rank * send.extent(0) * send.extent(1)); |
| 67 | + h_send(i0, i1, i2) = value; |
| 68 | + h_ref(i0, i1, i2) = value_T; |
| 69 | + } else { |
| 70 | + T value = |
| 71 | + static_cast<T>(rank * send.size() + i2 + i1 * send.extent(2) + |
| 72 | + i0 * send.extent(2) * send.extent(1)); |
| 73 | + T value_T = |
| 74 | + static_cast<T>(i0 * send.size() + i2 + i1 * send.extent(2) + |
| 75 | + rank * send.extent(2) * send.extent(1)); |
| 76 | + h_send(i0, i1, i2) = value; |
| 77 | + h_ref(i0, i1, i2) = value_T; |
| 78 | + } |
| 79 | + } |
| 80 | + } |
| 81 | + } |
| 82 | + |
| 83 | + Kokkos::deep_copy(send, h_send); |
| 84 | + Kokkos::deep_copy(ref, h_ref); |
| 85 | + |
| 86 | + execution_space exec; |
| 87 | + KokkosFFT::Distributed::Impl::all2all(exec, send, recv, MPI_COMM_WORLD); |
| 88 | + auto h_recv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), recv); |
| 89 | + |
| 90 | + T epsilon = std::numeric_limits<T>::epsilon() * 100; |
| 91 | + for (std::size_t i2 = 0; i2 < send.extent(2); i2++) { |
| 92 | + for (std::size_t i1 = 0; i1 < send.extent(1); i1++) { |
| 93 | + for (std::size_t i0 = 0; i0 < send.extent(0); i0++) { |
| 94 | + auto diff = Kokkos::abs(h_recv(i0, i1, i2) - h_ref(i0, i1, i2)); |
| 95 | + EXPECT_LE(diff, epsilon); |
| 96 | + } |
| 97 | + } |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +} // namespace |
| 102 | + |
| 103 | +TYPED_TEST_SUITE(TestAll2All, test_types); |
| 104 | + |
| 105 | +TYPED_TEST(TestAll2All, View2D) { |
| 106 | + using float_type = typename TestFixture::float_type; |
| 107 | + using layout_type = typename TestFixture::layout_type; |
| 108 | + test_all2all_view2D<float_type, layout_type>(this->m_rank, this->m_nprocs); |
| 109 | +} |
0 commit comments