Skip to content

Commit d6d175e

Browse files
authored
Merge pull request #11871 from wenduwan/allgather_reduce
coll/base,tuned: introduce allgather_reduce allreduce algorithm
2 parents 6aa55b8 + 5a7f814 commit d6d175e

File tree

3 files changed

+118
-0
lines changed

3 files changed

+118
-0
lines changed

ompi/mca/coll/base/coll_base_allreduce.c

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
* Copyright (c) 2018 Siberian State University of Telecommunications
1919
* and Information Science. All rights reserved.
2020
* Copyright (c) 2022 Cisco Systems, Inc. All rights reserved.
21+
* Copyright (c) Amazon.com, Inc. or its affiliates.
22+
* All rights reserved.
2123
* $COPYRIGHT$
2224
*
2325
* Additional copyrights may follow
@@ -1245,4 +1247,116 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
12451247
return err;
12461248
}
12471249

1250+
/**
1251+
* A greedy algorithm to exchange data among processes in the communicator via
1252+
* an allgather pattern, followed by a local reduction on each process. This
1253+
* avoids the round trip in a rooted communication pattern, e.g. reduce on the
1254+
* root and then broadcast to peers.
1255+
*
1256+
* This algorithm supports both commutative and non-commutative MPI operations.
1257+
* For non-commutative operations the reduction is applied to the data in the
1258+
* same rank order, e.g. rank 0, rank 1, ... rank N, on each process.
1259+
*
1260+
* This algorithm benefits inter-node allreduce over a high-latency network.
1261+
* Caution is needed on larger communicators(n) and data sizes(m), which will
1262+
* result in m*n^2 total traffic and potential network congestion.
1263+
*/
1264+
int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf, int count,
1265+
struct ompi_datatype_t *dtype,
1266+
struct ompi_op_t *op,
1267+
struct ompi_communicator_t *comm,
1268+
mca_coll_base_module_t *module)
1269+
{
1270+
char *send_buf = (void *) sbuf;
1271+
int comm_size = ompi_comm_size(comm);
1272+
int err = MPI_SUCCESS;
1273+
int rank = ompi_comm_rank(comm);
1274+
bool commutative = ompi_op_is_commute(op);
1275+
ompi_request_t **reqs;
1276+
1277+
if (sbuf == MPI_IN_PLACE) {
1278+
send_buf = rbuf;
1279+
}
1280+
1281+
/* Allocate a large-enough buffer to receive from everyone else */
1282+
char *tmp_buf = NULL, *tmp_buf_raw = NULL, *tmp_recv = NULL;
1283+
ptrdiff_t lb, extent, dsize, gap = 0;
1284+
ompi_datatype_get_extent(dtype, &lb, &extent);
1285+
dsize = opal_datatype_span(&dtype->super, count * comm_size, &gap);
1286+
tmp_buf_raw = (char *) malloc(dsize);
1287+
if (NULL == tmp_buf_raw) {
1288+
return OMPI_ERR_OUT_OF_RESOURCE;
1289+
}
1290+
1291+
if (commutative) {
1292+
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
1293+
}
1294+
1295+
tmp_buf = tmp_buf_raw - gap;
1296+
1297+
/* Requests for send to AND receive from everyone else */
1298+
int reqs_needed = (comm_size - 1) * 2;
1299+
reqs = ompi_coll_base_comm_get_reqs(module->base_data, reqs_needed);
1300+
1301+
ptrdiff_t incr = extent * count;
1302+
tmp_recv = (char *) tmp_buf;
1303+
1304+
/* Exchange data with peer processes */
1305+
int req_index = 0, peer_rank = 0;
1306+
for (int i = 1; i < comm_size; ++i) {
1307+
peer_rank = (rank + i) % comm_size;
1308+
tmp_recv = tmp_buf + (peer_rank * incr);
1309+
err = MCA_PML_CALL(irecv(tmp_recv, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
1310+
comm, &reqs[req_index++]));
1311+
if (MPI_SUCCESS != err) {
1312+
goto err_hndl;
1313+
}
1314+
1315+
err = MCA_PML_CALL(isend(send_buf, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
1316+
MCA_PML_BASE_SEND_STANDARD, comm, &reqs[req_index++]));
1317+
if (MPI_SUCCESS != err) {
1318+
goto err_hndl;
1319+
}
1320+
}
1321+
1322+
err = ompi_request_wait_all(req_index, reqs, MPI_STATUSES_IGNORE);
1323+
1324+
/* Prepare for local reduction */
1325+
peer_rank = 0;
1326+
if (!commutative) {
1327+
/* For non-commutative operations, ensure the reduction always starts from Rank 0's data */
1328+
memcpy(rbuf, 0 == rank ? send_buf : tmp_buf, incr);
1329+
peer_rank = 1;
1330+
}
1331+
1332+
char *inbuf;
1333+
for (; peer_rank < comm_size; peer_rank++) {
1334+
inbuf = rank == peer_rank ? send_buf : tmp_buf + (peer_rank * incr);
1335+
ompi_op_reduce(op, (void *) inbuf, rbuf, count, dtype);
1336+
}
1337+
1338+
err_hndl:
1339+
if (NULL != tmp_buf_raw)
1340+
free(tmp_buf_raw);
1341+
1342+
if (NULL != reqs) {
1343+
if (MPI_ERR_IN_STATUS == err) {
1344+
for (int i = 0; i < reqs_needed; i++) {
1345+
if (MPI_REQUEST_NULL == reqs[i])
1346+
continue;
1347+
if (MPI_ERR_PENDING == reqs[i]->req_status.MPI_ERROR)
1348+
continue;
1349+
if (MPI_SUCCESS != reqs[i]->req_status.MPI_ERROR) {
1350+
err = reqs[i]->req_status.MPI_ERROR;
1351+
break;
1352+
}
1353+
}
1354+
}
1355+
ompi_coll_base_free_reqs(reqs, reqs_needed);
1356+
}
1357+
1358+
/* All done */
1359+
return err;
1360+
}
1361+
12481362
/* copied function (with appropriate renaming) ends here */

ompi/mca/coll/base/coll_base_functions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ int ompi_coll_base_allreduce_intra_ring(ALLREDUCE_ARGS);
210210
int ompi_coll_base_allreduce_intra_ring_segmented(ALLREDUCE_ARGS, uint32_t segsize);
211211
int ompi_coll_base_allreduce_intra_basic_linear(ALLREDUCE_ARGS);
212212
int ompi_coll_base_allreduce_intra_redscat_allgather(ALLREDUCE_ARGS);
213+
int ompi_coll_base_allreduce_intra_allgather_reduce(ALLREDUCE_ARGS);
213214

214215
/* AlltoAll */
215216
int ompi_coll_base_alltoall_intra_pairwise(ALLTOALL_ARGS);

ompi/mca/coll/tuned/coll_tuned_allreduce_decision.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ static const mca_base_var_enum_value_t allreduce_algorithms[] = {
4242
{4, "ring"},
4343
{5, "segmented_ring"},
4444
{6, "rabenseifner"},
45+
{7, "allgather_reduce"},
4546
{0, NULL}
4647
};
4748

@@ -146,6 +147,8 @@ int ompi_coll_tuned_allreduce_intra_do_this(const void *sbuf, void *rbuf, int co
146147
return ompi_coll_base_allreduce_intra_ring_segmented(sbuf, rbuf, count, dtype, op, comm, module, segsize);
147148
case (6):
148149
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm, module);
150+
case (7):
151+
return ompi_coll_base_allreduce_intra_allgather_reduce(sbuf, rbuf, count, dtype, op, comm, module);
149152
} /* switch */
150153
OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:allreduce_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
151154
algorithm, ompi_coll_tuned_forced_max_algorithms[ALLREDUCE]));

0 commit comments

Comments
 (0)