Skip to content

Commit e3323e0

Browse files
authored
Merge pull request #12258 from wenduwan/allreduce_gather_commutative
coll/base: correctly handle non-commutative ops in allgather-reduce allreduce
2 parents 3ae723f + 314f43b commit e3323e0

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

ompi/mca/coll/base/coll_base_allreduce.c

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,10 +1268,9 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
12681268
mca_coll_base_module_t *module)
12691269
{
12701270
char *send_buf = (void *) sbuf;
1271-
int comm_size = ompi_comm_size(comm);
1271+
const int comm_size = ompi_comm_size(comm);
1272+
const int rank = ompi_comm_rank(comm);
12721273
int err = MPI_SUCCESS;
1273-
int rank = ompi_comm_rank(comm);
1274-
bool commutative = ompi_op_is_commute(op);
12751274
ompi_request_t **reqs;
12761275

12771276
if (sbuf == MPI_IN_PLACE) {
@@ -1288,24 +1287,30 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
12881287
return OMPI_ERR_OUT_OF_RESOURCE;
12891288
}
12901289

1291-
if (commutative) {
1292-
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
1293-
}
1294-
12951290
tmp_buf = tmp_buf_raw - gap;
12961291

12971292
/* Requests for send to AND receive from everyone else */
12981293
int reqs_needed = (comm_size - 1) * 2;
12991294
reqs = ompi_coll_base_comm_get_reqs(module->base_data, reqs_needed);
13001295

1301-
ptrdiff_t incr = extent * count;
1302-
tmp_recv = (char *) tmp_buf;
1296+
const ptrdiff_t incr = extent * count;
13031297

1304-
/* Exchange data with peer processes */
1298+
/* Exchange data with peer processes, excluding self */
13051299
int req_index = 0, peer_rank = 0;
13061300
for (int i = 1; i < comm_size; ++i) {
1301+
/* Start at the next rank */
13071302
peer_rank = (rank + i) % comm_size;
1308-
tmp_recv = tmp_buf + (peer_rank * incr);
1303+
1304+
/* Prepare for the next receive buffer */
1305+
if (0 == peer_rank && rbuf != send_buf) {
1306+
/* Optimization for Rank 0 - its data will always be placed at the beginning of the local
1307+
* reduce output buffer.
1308+
*/
1309+
tmp_recv = rbuf;
1310+
} else {
1311+
tmp_recv = tmp_buf + (peer_rank * incr);
1312+
}
1313+
13091314
err = MCA_PML_CALL(irecv(tmp_recv, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
13101315
comm, &reqs[req_index++]));
13111316
if (MPI_SUCCESS != err) {
@@ -1321,17 +1326,29 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
13211326

13221327
err = ompi_request_wait_all(req_index, reqs, MPI_STATUSES_IGNORE);
13231328

1324-
/* Prepare for local reduction */
1325-
peer_rank = 0;
1326-
if (!commutative) {
1327-
/* For non-commutative operations, ensure the reduction always starts from Rank 0's data */
1328-
memcpy(rbuf, 0 == rank ? send_buf : tmp_buf, incr);
1329-
peer_rank = 1;
1329+
/**
1330+
* Prepare for local reduction by moving Rank 0's data to rbuf.
1331+
* Where possible, Rank 0's data was already received directly into rbuf above, but we need to handle
1332+
* the following special cases.
1333+
*/
1334+
if (0 != rank && rbuf == send_buf) {
1335+
/* For in-place reduction, copy out the send_buf before moving Rank 0's data */
1336+
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) tmp_buf + (rank * incr),
1337+
send_buf);
1338+
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) tmp_buf);
1339+
} else if (0 == rank && rbuf != send_buf) {
1340+
/* For Rank 0 we need to copy the send_buf to rbuf manually */
1341+
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
13301342
}
13311343

1332-
char *inbuf;
1333-
for (; peer_rank < comm_size; peer_rank++) {
1334-
inbuf = rank == peer_rank ? send_buf : tmp_buf + (peer_rank * incr);
1344+
/* Now do local reduction - Rank 0's data is already in rbuf so start from Rank 1 */
1345+
char *inbuf = NULL;
1346+
for (peer_rank = 1; peer_rank < comm_size; peer_rank++) {
1347+
if (rank == peer_rank && rbuf != send_buf) {
1348+
inbuf = send_buf;
1349+
} else {
1350+
inbuf = tmp_buf + (peer_rank * incr);
1351+
}
13351352
ompi_op_reduce(op, (void *) inbuf, rbuf, count, dtype);
13361353
}
13371354

0 commit comments

Comments
 (0)