Skip to content

Commit 6564419

Browse files
authored
Merge pull request #12337 from juntangc/allgather-bruck
mca/coll: adding bruck method with any fanout k for Allgather/Allreduce
2 parents 2f2bf58 + 0c4ff68 commit 6564419

8 files changed

+587
-293
lines changed

ompi/mca/coll/base/coll_base_allgather.c

Lines changed: 322 additions & 173 deletions
Large diffs are not rendered by default.

ompi/mca/coll/base/coll_base_allreduce.c

Lines changed: 75 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,133 +1247,99 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
12471247
return err;
12481248
}
12491249

1250-
/**
1251-
* A greedy algorithm to exchange data among processes in the communicator via
1252-
* an allgather pattern, followed by a local reduction on each process. This
1253-
* avoids the round trip in a rooted communication pattern, e.g. reduce on the
1254-
* root and then broadcast to peers.
1250+
/*
1251+
* ompi_coll_base_allreduce_intra_allgather_reduce
1252+
*
1253+
* Function: use allgather for allreduce operation
1254+
* Accepts: Same as MPI_Allreduce()
1255+
* Returns: MPI_SUCCESS or error code
12551256
*
1256-
* This algorithm supports both commutative and non-commutative MPI operations.
1257-
* For non-commutative operations the reduction is applied to the data in the
1258-
* same rank order, e.g. rank 0, rank 1, ... rank N, on each process.
1257+
* Description: Implements allgather based allreduce aimed to improve internode
1258+
* allreduce latency: this method takes advantage of the send and
1259+
* receive can happen at the same time; first step is allgather
1260+
* operation to allow all ranks to obtain the full dataset; the second
1261+
* step is to do reduction on all ranks to get allreduce result.
12591262
*
1260-
* This algorithm benefits inter-node allreduce over a high-latency network.
1261-
* Caution is needed on larger communicators(n) and data sizes(m), which will
1262-
* result in m*n^2 total traffic and potential network congestion.
1263+
* Limitations: This method is designed for small message sizes allreduce because it
1264+
* is not efficient in terms of network bandwidth comparing
1265+
* to gather/reduce/bcast type of approach.
12631266
*/
12641267
int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf, int count,
12651268
struct ompi_datatype_t *dtype,
12661269
struct ompi_op_t *op,
12671270
struct ompi_communicator_t *comm,
12681271
mca_coll_base_module_t *module)
12691272
{
1270-
char *send_buf = (void *) sbuf;
1271-
const int comm_size = ompi_comm_size(comm);
1272-
const int rank = ompi_comm_rank(comm);
1273-
int err = MPI_SUCCESS;
1274-
ompi_request_t **reqs;
1275-
1276-
if (sbuf == MPI_IN_PLACE) {
1277-
send_buf = rbuf;
1278-
}
1279-
1280-
/* Allocate a large-enough buffer to receive from everyone else */
1281-
char *tmp_buf = NULL, *tmp_buf_raw = NULL, *tmp_recv = NULL;
1282-
ptrdiff_t lb, extent, dsize, gap = 0;
1273+
int line = -1;
1274+
char *partial_buf = NULL;
1275+
char *partial_buf_start = NULL;
1276+
char *sendtmpbuf = NULL;
1277+
char *tmpsend = NULL;
1278+
char *tmpsend_start = NULL;
1279+
int err = OMPI_SUCCESS;
1280+
1281+
ptrdiff_t extent, lb;
12831282
ompi_datatype_get_extent(dtype, &lb, &extent);
1284-
dsize = opal_datatype_span(&dtype->super, count * comm_size, &gap);
1285-
tmp_buf_raw = (char *) malloc(dsize);
1286-
if (NULL == tmp_buf_raw) {
1287-
return OMPI_ERR_OUT_OF_RESOURCE;
1288-
}
12891283

1290-
tmp_buf = tmp_buf_raw - gap;
1291-
1292-
/* Requests for send to AND receive from everyone else */
1293-
int reqs_needed = (comm_size - 1) * 2;
1294-
reqs = ompi_coll_base_comm_get_reqs(module->base_data, reqs_needed);
1295-
1296-
const ptrdiff_t incr = extent * count;
1297-
1298-
/* Exchange data with peer processes, excluding self */
1299-
int req_index = 0, peer_rank = 0;
1300-
for (int i = 1; i < comm_size; ++i) {
1301-
/* Start at the next rank */
1302-
peer_rank = (rank + i) % comm_size;
1303-
1304-
/* Prepare for the next receive buffer */
1305-
if (0 == peer_rank && rbuf != send_buf) {
1306-
/* Optimization for Rank 0 - its data will always be placed at the beginning of local
1307-
* reduce output buffer.
1308-
*/
1309-
tmp_recv = rbuf;
1310-
} else {
1311-
tmp_recv = tmp_buf + (peer_rank * incr);
1312-
}
1313-
1314-
err = MCA_PML_CALL(irecv(tmp_recv, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
1315-
comm, &reqs[req_index++]));
1316-
if (MPI_SUCCESS != err) {
1317-
goto err_hndl;
1318-
}
1284+
int rank = ompi_comm_rank(comm);
1285+
int size = ompi_comm_size(comm);
13191286

1320-
err = MCA_PML_CALL(isend(send_buf, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE,
1321-
MCA_PML_BASE_SEND_STANDARD, comm, &reqs[req_index++]));
1322-
if (MPI_SUCCESS != err) {
1323-
goto err_hndl;
1324-
}
1287+
sendtmpbuf = (char*) sbuf;
1288+
if( sbuf == MPI_IN_PLACE ) {
1289+
sendtmpbuf = (char *)rbuf;
13251290
}
1326-
1327-
err = ompi_request_wait_all(req_index, reqs, MPI_STATUSES_IGNORE);
1328-
1329-
/**
1330-
* Prepare for local reduction by moving Rank 0's data to rbuf.
1331-
* Previously we tried to receive Rank 0's data in rbuf, but we need to handle
1332-
* the following special cases.
1333-
*/
1334-
if (0 != rank && rbuf == send_buf) {
1335-
/* For inplace reduction copy out the send_buf before moving Rank 0's data */
1336-
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) tmp_buf + (rank * incr),
1337-
send_buf);
1338-
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) tmp_buf);
1339-
} else if (0 == rank && rbuf != send_buf) {
1340-
/* For Rank 0 we need to copy the send_buf to rbuf manually */
1341-
ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, (char *) send_buf);
1291+
ptrdiff_t buf_size, gap = 0;
1292+
buf_size = opal_datatype_span(&dtype->super, (int64_t)count * size, &gap);
1293+
partial_buf = (char *) malloc(buf_size);
1294+
partial_buf_start = partial_buf - gap;
1295+
buf_size = opal_datatype_span(&dtype->super, (int64_t)count, &gap);
1296+
tmpsend = (char *) malloc(buf_size);
1297+
tmpsend_start = tmpsend - gap;
1298+
1299+
err = ompi_datatype_copy_content_same_ddt(dtype, count,
1300+
(char*)tmpsend_start,
1301+
(char*)sendtmpbuf);
1302+
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
1303+
1304+
// apply allgather data so that each rank has a full copy to do reduce (trade bandwidth for better latency)
1305+
err = comm->c_coll->coll_allgather(tmpsend_start, count, dtype,
1306+
partial_buf_start, count, dtype,
1307+
comm, comm->c_coll->coll_allgather_module);
1308+
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
1309+
1310+
for (int target = 1; target < size; target++) {
1311+
ompi_op_reduce(op,
1312+
partial_buf_start + (ptrdiff_t)target * count * extent,
1313+
partial_buf_start,
1314+
count,
1315+
dtype);
13421316
}
13431317

1344-
/* Now do local reduction - Rank 0's data is already in rbuf so start from Rank 1 */
1345-
char *inbuf = NULL;
1346-
for (peer_rank = 1; peer_rank < comm_size; peer_rank++) {
1347-
if (rank == peer_rank && rbuf != send_buf) {
1348-
inbuf = send_buf;
1349-
} else {
1350-
inbuf = tmp_buf + (peer_rank * incr);
1351-
}
1352-
ompi_op_reduce(op, (void *) inbuf, rbuf, count, dtype);
1353-
}
1318+
// move data to rbuf
1319+
err = ompi_datatype_copy_content_same_ddt(dtype, count,
1320+
(char*)rbuf,
1321+
(char*)partial_buf_start);
1322+
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
13541323

1355-
err_hndl:
1356-
if (NULL != tmp_buf_raw)
1357-
free(tmp_buf_raw);
1324+
if (NULL != partial_buf) free(partial_buf);
1325+
if (NULL != tmpsend) free(tmpsend);
1326+
return MPI_SUCCESS;
13581327

1359-
if (NULL != reqs) {
1360-
if (MPI_ERR_IN_STATUS == err) {
1361-
for (int i = 0; i < reqs_needed; i++) {
1362-
if (MPI_REQUEST_NULL == reqs[i])
1363-
continue;
1364-
if (MPI_ERR_PENDING == reqs[i]->req_status.MPI_ERROR)
1365-
continue;
1366-
if (MPI_SUCCESS != reqs[i]->req_status.MPI_ERROR) {
1367-
err = reqs[i]->req_status.MPI_ERROR;
1368-
break;
1369-
}
1370-
}
1371-
}
1372-
ompi_coll_base_free_reqs(reqs, reqs_needed);
1328+
err_hndl:
1329+
if (NULL != partial_buf) {
1330+
free(partial_buf);
1331+
partial_buf = NULL;
1332+
partial_buf_start = NULL;
13731333
}
1374-
1375-
/* All done */
1334+
if (NULL != tmpsend) {
1335+
free(tmpsend);
1336+
tmpsend = NULL;
1337+
tmpsend_start = NULL;
1338+
}
1339+
OPAL_OUTPUT((ompi_coll_base_framework.framework_output, "%s:%4d\tError occurred %d, rank %2d",
1340+
__FILE__, line, err, rank));
1341+
(void)line; // silence compiler warning
13761342
return err;
1377-
}
13781343

1344+
}
13791345
/* copied function (with appropriate renaming) ends here */

ompi/mca/coll/base/coll_base_functions.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,14 @@ typedef enum COLLTYPE {
187187
BEGIN_C_DECLS
188188

189189
/* All Gather */
190-
int ompi_coll_base_allgather_intra_bruck(ALLGATHER_ARGS);
191190
int ompi_coll_base_allgather_intra_recursivedoubling(ALLGATHER_ARGS);
192191
int ompi_coll_base_allgather_intra_sparbit(ALLGATHER_ARGS);
193192
int ompi_coll_base_allgather_intra_ring(ALLGATHER_ARGS);
194193
int ompi_coll_base_allgather_intra_neighborexchange(ALLGATHER_ARGS);
195194
int ompi_coll_base_allgather_intra_basic_linear(ALLGATHER_ARGS);
196195
int ompi_coll_base_allgather_intra_two_procs(ALLGATHER_ARGS);
196+
int ompi_coll_base_allgather_intra_k_bruck(ALLGATHER_ARGS, int radix);
197+
int ompi_coll_base_allgather_direct_messaging(ALLGATHER_ARGS);
197198

198199
/* All GatherV */
199200
int ompi_coll_base_allgatherv_intra_bruck(ALLGATHERV_ARGS);
@@ -274,6 +275,7 @@ int ompi_coll_base_reduce_intra_binary(REDUCE_ARGS, uint32_t segsize, int max_ou
274275
int ompi_coll_base_reduce_intra_binomial(REDUCE_ARGS, uint32_t segsize, int max_outstanding_reqs );
275276
int ompi_coll_base_reduce_intra_in_order_binary(REDUCE_ARGS, uint32_t segsize, int max_outstanding_reqs );
276277
int ompi_coll_base_reduce_intra_redscat_gather(REDUCE_ARGS);
278+
int ompi_coll_base_reduce_intra_knomial(REDUCE_ARGS, uint32_t segsize, int max_outstanding_reqs, int radix);
277279

278280
/* Reduce_scatter */
279281
int ompi_coll_base_reduce_scatter_intra_nonoverlapping(REDUCESCATTER_ARGS);

0 commit comments

Comments
 (0)