Skip to content

Commit 817e2f5

Browse files
author
Yuuichi Asahi
committed
Add all2all function and tests
1 parent b87725f commit 817e2f5

File tree

3 files changed

+170
-0
lines changed

3 files changed

+170
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// SPDX-FileCopyrightText: (C) The kokkos-fft development team, see COPYRIGHT.md file
2+
//
3+
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
4+
5+
#ifndef KOKKOSFFT_DISTRIBUTED_ALL2ALL_HPP
6+
#define KOKKOSFFT_DISTRIBUTED_ALL2ALL_HPP
7+
8+
#include <mpi.h>
9+
#include <Kokkos_Core.hpp>
10+
#include <Kokkos_Profiling_ScopedRegion.hpp>
11+
#include "KokkosFFT_Distributed_MPI_Types.hpp"
12+
13+
namespace KokkosFFT {
14+
namespace Distributed {
15+
namespace Impl {
16+
17+
template <typename ExecutionSpace, typename ViewType>
18+
struct All2All {
19+
static_assert(ViewType::rank() >= 2,
20+
"All2All: View rank must be larger than or equal to 2");
21+
using value_type = typename ViewType::non_const_value_type;
22+
using LayoutType = typename ViewType::array_layout;
23+
24+
ExecutionSpace m_exec_space;
25+
MPI_Comm m_comm;
26+
MPI_Datatype m_mpi_data_type;
27+
28+
All2All(const ViewType& send, const ViewType& recv,
29+
const MPI_Comm& comm = MPI_COMM_WORLD,
30+
const ExecutionSpace exec_space = ExecutionSpace())
31+
: m_exec_space(exec_space),
32+
m_comm(comm),
33+
m_mpi_data_type(MPIDataType<value_type>::type()) {
34+
// Compute the outermost dimension size
35+
int send_count = 0;
36+
if (std::is_same_v<LayoutType, Kokkos::LayoutLeft>) {
37+
send_count = send.size() / send.extent(ViewType::rank() - 1);
38+
} else {
39+
send_count = send.size() / send.extent(0);
40+
}
41+
42+
::MPI_Alltoall(send.data(), send_count, m_mpi_data_type, recv.data(),
43+
send_count, m_mpi_data_type, m_comm);
44+
}
45+
};
46+
47+
template <typename ExecutionSpace, typename ViewType>
48+
void all2all(const ExecutionSpace& exec_space, const ViewType& send,
49+
const ViewType& recv, const MPI_Comm& comm) {
50+
static_assert(ViewType::rank() >= 2,
51+
"all2all: View rank must be larger than or equal to 2");
52+
Kokkos::Profiling::ScopedRegion region("KokkosFFT::Distributed::all2all");
53+
All2All(send, recv, comm, exec_space);
54+
}
55+
56+
} // namespace Impl
57+
} // namespace Distributed
58+
} // namespace KokkosFFT
59+
60+
#endif

distributed/unit_test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
set(SOURCES
66
Test_Main.cpp
7+
Test_All2All.cpp
78
Test_MPI_Types.cpp
89
)
910

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// SPDX-FileCopyrightText: (C) The kokkos-fft development team, see COPYRIGHT.md file
2+
//
3+
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
4+
5+
#include <gtest/gtest.h>
6+
#include <Kokkos_Core.hpp>
7+
#include "KokkosFFT_Distributed_All2All.hpp"
8+
9+
namespace {
10+
using execution_space = Kokkos::DefaultExecutionSpace;
11+
using test_types = ::testing::Types<std::pair<float, Kokkos::LayoutLeft>,
12+
std::pair<float, Kokkos::LayoutRight>,
13+
std::pair<double, Kokkos::LayoutLeft>,
14+
std::pair<double, Kokkos::LayoutRight>>;
15+
16+
// Basically the same fixtures, used for labeling tests
17+
template <typename T>
18+
struct TestAll2All : public ::testing::Test {
19+
using float_type = typename T::first_type;
20+
using layout_type = typename T::second_type;
21+
22+
int m_rank = 0;
23+
int m_nprocs = 1;
24+
25+
virtual void SetUp() {
26+
::MPI_Comm_rank(MPI_COMM_WORLD, &m_rank);
27+
::MPI_Comm_size(MPI_COMM_WORLD, &m_nprocs);
28+
}
29+
};
30+
31+
template <typename T, typename LayoutType>
32+
void test_all2all_view2D(int rank, int nprocs) {
33+
using View3DType = Kokkos::View<T***, LayoutType, execution_space>;
34+
35+
const std::size_t n0 = 16, n1 = 15;
36+
const std::size_t n0_local = ((n0 - 1) / nprocs) + 1;
37+
const std::size_t n1_local = ((n1 - 1) / nprocs) + 1;
38+
39+
int n0_buffer = 0, n1_buffer = 0, n2_buffer = 0;
40+
if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft>) {
41+
n0_buffer = n0_local;
42+
n1_buffer = n1_local;
43+
n2_buffer = nprocs;
44+
} else {
45+
n0_buffer = nprocs;
46+
n1_buffer = n0_local;
47+
n2_buffer = n1_local;
48+
}
49+
50+
View3DType send("send", n0_buffer, n1_buffer, n2_buffer),
51+
recv("recv", n0_buffer, n1_buffer, n2_buffer),
52+
ref("ref", n0_buffer, n1_buffer, n2_buffer);
53+
54+
auto h_send = Kokkos::create_mirror_view(send);
55+
auto h_ref = Kokkos::create_mirror_view(ref);
56+
57+
for (std::size_t i2 = 0; i2 < send.extent(2); i2++) {
58+
for (std::size_t i1 = 0; i1 < send.extent(1); i1++) {
59+
for (std::size_t i0 = 0; i0 < send.extent(0); i0++) {
60+
if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft>) {
61+
T value =
62+
static_cast<T>(rank * send.size() + i0 + i1 * send.extent(0) +
63+
i2 * send.extent(0) * send.extent(1));
64+
T value_T =
65+
static_cast<T>(i2 * send.size() + i0 + i1 * send.extent(0) +
66+
rank * send.extent(0) * send.extent(1));
67+
h_send(i0, i1, i2) = value;
68+
h_ref(i0, i1, i2) = value_T;
69+
} else {
70+
T value =
71+
static_cast<T>(rank * send.size() + i2 + i1 * send.extent(2) +
72+
i0 * send.extent(2) * send.extent(1));
73+
T value_T =
74+
static_cast<T>(i0 * send.size() + i2 + i1 * send.extent(2) +
75+
rank * send.extent(2) * send.extent(1));
76+
h_send(i0, i1, i2) = value;
77+
h_ref(i0, i1, i2) = value_T;
78+
}
79+
}
80+
}
81+
}
82+
83+
Kokkos::deep_copy(send, h_send);
84+
Kokkos::deep_copy(ref, h_ref);
85+
86+
execution_space exec;
87+
KokkosFFT::Distributed::Impl::all2all(exec, send, recv, MPI_COMM_WORLD);
88+
auto h_recv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), recv);
89+
90+
T epsilon = std::numeric_limits<T>::epsilon() * 100;
91+
for (std::size_t i2 = 0; i2 < send.extent(2); i2++) {
92+
for (std::size_t i1 = 0; i1 < send.extent(1); i1++) {
93+
for (std::size_t i0 = 0; i0 < send.extent(0); i0++) {
94+
auto diff = Kokkos::abs(h_recv(i0, i1, i2) - h_ref(i0, i1, i2));
95+
EXPECT_LE(diff, epsilon);
96+
}
97+
}
98+
}
99+
}
100+
101+
} // namespace
102+
103+
// Instantiate the TestAll2All fixture once for every entry of test_types.
TYPED_TEST_SUITE(TestAll2All, test_types);
104+
105+
// Runs the all2all check with the value type and layout of the current
// typed-test instantiation, using the rank and process count stored in the
// fixture.
TYPED_TEST(TestAll2All, View2D) {
  test_all2all_view2D<typename TestFixture::float_type,
                      typename TestFixture::layout_type>(this->m_rank,
                                                         this->m_nprocs);
}

0 commit comments

Comments
 (0)