Skip to content

Commit 5a7f814

Browse files
committed
coll/base,tuned: introduce allgather_reduce allreduce algorithm
This patch introduces a new allreduce algorithm implemented as an allgather followed by local reduction. The change is motivated by the longer latency of tcp/EFA traffic. Current allreduce algorithms require a round trip to and from a selected root process. This algorithm avoids the round trip over network and therefore reduces total latency. However, this communication pattern is not scalable for large communicators, and should only be used for inter-node allreduce. Co-authored-by: Matt Koop <mkoop@amazon.com> Co-authored-by: Wenduo Wang <wenduwan@amazon.com> Signed-off-by: Wenduo Wang <wenduwan@amazon.com>
1 parent 8514e71 commit 5a7f814

File tree

3 files changed

+118
-0
lines changed

3 files changed

+118
-0
lines changed

ompi/mca/coll/base/coll_base_allreduce.c

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
* Copyright (c) 2018 Siberian State University of Telecommunications
1919
* and Information Science. All rights reserved.
2020
* Copyright (c) 2022 Cisco Systems, Inc. All rights reserved.
21+
* Copyright (c) Amazon.com, Inc. or its affiliates.
22+
* All rights reserved.
2123
* $COPYRIGHT$
2224
*
2325
* Additional copyrights may follow
@@ -1245,4 +1247,116 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
12451247
return err;
12461248
}
12471249

1250+
/**
1251+
* A greedy algorithm to exchange data among processes in the communicator via
1252+
* an allgather pattern, followed by a local reduction on each process. This
1253+
* avoids the round trip in a rooted communication pattern, e.g. reduce on the
1254+
* root and then broadcast to peers.
1255+
*
1256+
* This algorithm supports both commutative and non-commutative MPI operations.
1257+
* For non-commutative operations the reduction is applied to the data in the
1258+
* same rank order, e.g. rank 0, rank 1, ... rank N, on each process.
1259+
*
1260+
* This algorithm benefits inter-node allreduce over a high-latency network.
1261+
* Caution is needed on larger communicators(n) and data sizes(m), which will
1262+
* result in m*n^2 total traffic and potential network congestion.
1263+
*/
1264+
int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf, int count,
1265+
struct ompi_datatype_t *dtype,
1266+
struct ompi_op_t *op,
1267+
struct ompi_communicator_t *comm,
1268+
mca_coll_base_module_t *module)
1269+
{
1270+
char *send_buf = (void *) sbuf;
1271+
int comm_size = ompi_comm_size(comm);
1272+
int err = MPI_SUCCESS;
1273+
int rank = ompi_comm_rank(comm);
1274+
bool commutative = ompi_op_is_commute(op);
1275+
ompi_request_t **reqs;
1276+
1277+
if (sbuf == MPI_IN_PLACE) {
1278+
send_buf = rbuf;
1279+
}
1280+
1281+
/* Allocate a large-enough buffer to receive from everyone else */
1282+
char *tmp_buf = NULL, *tmp_buf_raw = NULL, *tmp_recv = NULL;
1283+
ptrdiff_t lb, extent, dsize, gap = 0;
1284+
ompi_datatype_get_extent(dtype, &lb, &extent);
1285+
dsize = opal_datatype_span(&dtype->super, count * comm_size, &gap);
1286+
tmp_buf_raw = (char *) malloc(dsize);
1287+
if (NULL == tmp_buf_raw) {
1288+
return OMPI_ERR_OUT_OF_RESOURCE;
1289+
}
1290+
1291+
if (commutative) {
1292+
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
1293+
}
1294+
1295+
tmp_buf = tmp_buf_raw - gap;
1296+
1297+
/* Requests for send to AND receive from everyone else */
1298+
int reqs_needed = (comm_size - 1) * 2;
1299+
reqs = ompi_coll_base_comm_get_reqs(module->base_data, reqs_needed);
1300+
1301+
ptrdiff_t incr = extent * count;
1302+
tmp_recv = (char *) tmp_buf;
1303+
1304+
/* Exchange data with peer processes */
1305+
int req_index = 0, peer_rank = 0;
1306+
for (int i = 1; i < comm_size; ++i) {
1307+
peer_rank = (rank + i) % comm_size;
1308+
tmp_recv = tmp_buf + (peer_rank * incr);
1309+
err = MCA_PML_CALL(irecv(tmp_recv, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
1310+
comm, &reqs[req_index++]));
1311+
if (MPI_SUCCESS != err) {
1312+
goto err_hndl;
1313+
}
1314+
1315+
err = MCA_PML_CALL(isend(send_buf, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
1316+
MCA_PML_BASE_SEND_STANDARD, comm, &reqs[req_index++]));
1317+
if (MPI_SUCCESS != err) {
1318+
goto err_hndl;
1319+
}
1320+
}
1321+
1322+
err = ompi_request_wait_all(req_index, reqs, MPI_STATUSES_IGNORE);
1323+
1324+
/* Prepare for local reduction */
1325+
peer_rank = 0;
1326+
if (!commutative) {
1327+
/* For non-commutative operations, ensure the reduction always starts from Rank 0's data */
1328+
memcpy(rbuf, 0 == rank ? send_buf : tmp_buf, incr);
1329+
peer_rank = 1;
1330+
}
1331+
1332+
char *inbuf;
1333+
for (; peer_rank < comm_size; peer_rank++) {
1334+
inbuf = rank == peer_rank ? send_buf : tmp_buf + (peer_rank * incr);
1335+
ompi_op_reduce(op, (void *) inbuf, rbuf, count, dtype);
1336+
}
1337+
1338+
err_hndl:
1339+
if (NULL != tmp_buf_raw)
1340+
free(tmp_buf_raw);
1341+
1342+
if (NULL != reqs) {
1343+
if (MPI_ERR_IN_STATUS == err) {
1344+
for (int i = 0; i < reqs_needed; i++) {
1345+
if (MPI_REQUEST_NULL == reqs[i])
1346+
continue;
1347+
if (MPI_ERR_PENDING == reqs[i]->req_status.MPI_ERROR)
1348+
continue;
1349+
if (MPI_SUCCESS != reqs[i]->req_status.MPI_ERROR) {
1350+
err = reqs[i]->req_status.MPI_ERROR;
1351+
break;
1352+
}
1353+
}
1354+
}
1355+
ompi_coll_base_free_reqs(reqs, reqs_needed);
1356+
}
1357+
1358+
/* All done */
1359+
return err;
1360+
}
1361+
12481362
/* copied function (with appropriate renaming) ends here */

ompi/mca/coll/base/coll_base_functions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ int ompi_coll_base_allreduce_intra_ring(ALLREDUCE_ARGS);
210210
int ompi_coll_base_allreduce_intra_ring_segmented(ALLREDUCE_ARGS, uint32_t segsize);
211211
int ompi_coll_base_allreduce_intra_basic_linear(ALLREDUCE_ARGS);
212212
int ompi_coll_base_allreduce_intra_redscat_allgather(ALLREDUCE_ARGS);
213+
int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);
213214

214215
/* AlltoAll */
215216
int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);

ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ static const mca_base_var_enum_value_t allreduce_algorithms[] = {
4242
{4, "ring"},
4343
{5, "segmented_ring"},
4444
{6, "rabenseifner"},
45+
{7, "allgather_reduce"},
4546
{0, NULL}
4647
};
4748

@@ -146,6 +147,8 @@ int ompi_coll_tuned_allreduce_intra_do_this(const void *sbuf, void *rbuf, int co
146147
return ompi_coll_base_allreduce_intra_ring_segmented(sbuf, rbuf, count, dtype, op, comm, module, segsize);
147148
case (6):
148149
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, module);
150+
case (7):
151+
return ompi_coll_base_allreduce_intra_allgather_reduce(sbuf, rbuf, count, dtype, op, comm, module);
149152
} /* switch */
150153
OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:allreduce_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
151154
algorithm, ompi_coll_tuned_forced_max_algorithms[ALLREDUCE]));

0 commit comments

Comments
 (0)