|
18 | 18 | * Copyright (c) 2018 Siberian State University of Telecommunications
|
19 | 19 | * and Information Science. All rights reserved.
|
20 | 20 | * Copyright (c) 2022 Cisco Systems, Inc. All rights reserved.
|
| 21 | + * Copyright (c) Amazon.com, Inc. or its affiliates. |
| 22 | + * All rights reserved. |
21 | 23 | * $COPYRIGHT$
|
22 | 24 | *
|
23 | 25 | * Additional copyrights may follow
|
@@ -1245,4 +1247,116 @@ int ompi_coll_base_allreduce_intra_redscat_allgather(
|
1245 | 1247 | return err;
|
1246 | 1248 | }
|
1247 | 1249 |
|
| 1250 | +/** |
| 1251 | + * A greedy algorithm that exchanges data among the processes in the |
| 1252 | + * communicator with an allgather pattern, followed by a local reduction on |
| 1253 | + * each process. This avoids the round trip of a rooted communication pattern, |
| 1254 | + * e.g. a reduce to the root followed by a broadcast to all peers. |
| 1255 | + * |
| 1256 | + * The algorithm supports both commutative and non-commutative MPI operations. |
| 1257 | + * The local reduction combines the contributions in rank order, i.e. |
| 1258 | + * rank 0 op rank 1 op ... op rank N-1, on every process. |
| 1259 | + * |
| 1260 | + * This algorithm is well suited to inter-node allreduce over high-latency networks. |
| 1261 | + * Use caution with large communicators (n) and message sizes (m): the |
| 1262 | + * exchange generates O(m * n^2) total traffic and can congest the network. |
| 1263 | + */ |
| 1264 | +int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf, int count, |
| 1265 | + struct ompi_datatype_t *dtype, |
| 1266 | + struct ompi_op_t *op, |
| 1267 | + struct ompi_communicator_t *comm, |
| 1268 | + mca_coll_base_module_t *module) |
| 1269 | +{ |
| 1270 | + char *send_buf = (char *) sbuf; |
| 1271 | + int comm_size = ompi_comm_size(comm); |
| 1272 | + int err = MPI_SUCCESS; |
| 1273 | + int rank = ompi_comm_rank(comm); |
| 1274 | + int reqs_needed = (comm_size - 1) * 2; /* one send and one receive per peer */ |
| 1275 | + ompi_request_t **reqs = NULL; |
| 1276 | + |
| 1277 | + if (sbuf == MPI_IN_PLACE) { |
| 1278 | + send_buf = rbuf; |
| 1279 | + } |
| 1280 | + |
| 1281 | + /* Allocate a buffer large enough to hold one contribution from every rank */ |
| 1282 | + char *tmp_buf = NULL, *tmp_buf_raw = NULL, *tmp_recv = NULL; |
| 1283 | + ptrdiff_t lb, extent, dsize, gap = 0; |
| 1284 | + ompi_datatype_get_extent(dtype, &lb, &extent); |
| 1285 | + dsize = opal_datatype_span(&dtype->super, (size_t) count * comm_size, &gap); |
| 1286 | + tmp_buf_raw = (char *) malloc(dsize); |
| 1287 | + if (NULL == tmp_buf_raw) { |
| 1288 | + return OMPI_ERR_OUT_OF_RESOURCE; |
| 1289 | + } |
| 1290 | + |
| 1291 | + tmp_buf = tmp_buf_raw - gap; |
| 1292 | + ptrdiff_t incr = extent * count; |
| 1293 | + |
| 1294 | + /* Park this rank's contribution in its own slot of the gathered buffer so the |
| 1295 | + * local reduction can walk all slots in rank order; this also keeps MPI_IN_PLACE |
| 1296 | + * safe, since rbuf is then used only as the reduction accumulator. */ |
| 1297 | + ompi_datatype_copy_content_same_ddt(dtype, count, tmp_buf + rank * incr, send_buf); |
| 1298 | + |
| 1299 | + /* Requests for the sends to AND receives from everyone else */ |
| 1300 | + reqs = ompi_coll_base_comm_get_reqs(module->base_data, reqs_needed); |
| 1301 | + if (NULL == reqs && reqs_needed > 0) { err = OMPI_ERR_OUT_OF_RESOURCE; goto err_hndl; } |
| 1302 | + |
| 1303 | + |
| 1304 | + /* Exchange data with peer processes */ |
| 1305 | + int req_index = 0, peer_rank = 0; |
| 1306 | + for (int i = 1; i < comm_size; ++i) { |
| 1307 | + peer_rank = (rank + i) % comm_size; |
| 1308 | + tmp_recv = tmp_buf + (peer_rank * incr); |
| 1309 | + err = MCA_PML_CALL(irecv(tmp_recv, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE, |
| 1310 | + comm, &reqs[req_index++])); |
| 1311 | + if (MPI_SUCCESS != err) { |
| 1312 | + goto err_hndl; |
| 1313 | + } |
| 1314 | + |
| 1315 | + err = MCA_PML_CALL(isend(send_buf, count, dtype, peer_rank, MCA_COLL_BASE_TAG_ALLREDUCE, |
| 1316 | + MCA_PML_BASE_SEND_STANDARD, comm, &reqs[req_index++])); |
| 1317 | + if (MPI_SUCCESS != err) { |
| 1318 | + goto err_hndl; |
| 1319 | + } |
| 1320 | + } |
| 1321 | + |
| 1322 | + err = ompi_request_wait_all(req_index, reqs, MPI_STATUSES_IGNORE); |
| 1323 | + if (MPI_SUCCESS != err) { |
| 1324 | + goto err_hndl; |
| 1325 | + } |
| 1326 | + |
| 1327 | + /* Local reduction: every rank's contribution now sits in its own slot of tmp_buf. |
| 1328 | + * ompi_op_reduce(op, in, acc, ...) computes acc = in op acc, so seed the accumulator |
| 1329 | + * with the last rank's data and fold the remaining slots in decreasing rank order; |
| 1330 | + * the result then follows the rank order rank 0 op rank 1 op ... op rank N-1 |
| 1331 | + * required for non-commutative operations, with each contribution counted once. */ |
| 1332 | + ompi_datatype_copy_content_same_ddt(dtype, count, (char *) rbuf, |
| 1333 | + tmp_buf + (ptrdiff_t) (comm_size - 1) * incr); |
| 1334 | + for (peer_rank = comm_size - 2; peer_rank >= 0; peer_rank--) { |
| 1335 | + ompi_op_reduce(op, tmp_buf + (peer_rank * incr), rbuf, count, dtype); |
| 1336 | + } |
| 1337 | + |
| 1338 | +err_hndl: |
| 1339 | + if (NULL != tmp_buf_raw) |
| 1340 | + free(tmp_buf_raw); |
| 1341 | + |
| 1342 | + if (NULL != reqs) { |
| 1343 | + if (MPI_ERR_IN_STATUS == err) { |
| 1344 | + for (int i = 0; i < reqs_needed; i++) { |
| 1345 | + if (MPI_REQUEST_NULL == reqs[i]) |
| 1346 | + continue; |
| 1347 | + if (MPI_ERR_PENDING == reqs[i]->req_status.MPI_ERROR) |
| 1348 | + continue; |
| 1349 | + if (MPI_SUCCESS != reqs[i]->req_status.MPI_ERROR) { |
| 1350 | + err = reqs[i]->req_status.MPI_ERROR; |
| 1351 | + break; |
| 1352 | + } |
| 1353 | + } |
| 1354 | + } |
| 1355 | + ompi_coll_base_free_reqs(reqs, reqs_needed); |
| 1356 | + } |
| 1357 | + |
| 1358 | + /* All done */ |
| 1359 | + return err; |
| 1360 | +} |
| 1361 | + |
1248 | 1362 | /* copied function (with appropriate renaming) ends here */
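A quick worked example of the m*n^2 caution in the comment block above: with n = 64 ranks and m = 1 MiB of data per rank, every rank sends and receives 63 MiB, so a single allreduce moves roughly 64 * 63 MiB ≈ 3.9 GiB across the network, compared with about 2 * m ≈ 2 MiB sent per rank by a reduce-scatter/allgather based allreduce. The greedy exchange pays off for small messages on high-latency networks, where completing in a single round of point-to-point exchanges matters more than the extra bytes.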
|
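For readers who want to see the exchange-then-reduce pattern in isolation, here is a minimal, self-contained sketch written against the public MPI API. It only illustrates the pattern, not the Open MPI internal routine added in this commit: the file name, the fixed element count, and the plain double summation are stand-ins chosen for brevity, and the rank-order reduction is trivially correct here because addition is commutative.

    /* allgather_reduce_sketch.c -- illustrative only; build with: mpicc -o sketch allgather_reduce_sketch.c */
    #include <mpi.h>
    #include <stdio.h>
    #include <stdlib.h>

    int main(int argc, char **argv)
    {
        MPI_Init(&argc, &argv);

        int rank, size;
        MPI_Comm_rank(MPI_COMM_WORLD, &rank);
        MPI_Comm_size(MPI_COMM_WORLD, &size);

        const int count = 4;                       /* elements contributed per rank */
        double *mine      = malloc(count * sizeof(double));
        double *gathered  = malloc((size_t) size * count * sizeof(double));
        double *result    = malloc(count * sizeof(double));
        MPI_Request *reqs = malloc(2 * (size_t) (size - 1) * sizeof(MPI_Request));

        for (int i = 0; i < count; ++i) {
            mine[i] = rank + 1.0;                  /* rank r contributes the value r + 1 */
        }

        /* Greedy exchange: post a receive from and a send to every peer at once */
        int nreq = 0;
        for (int dist = 1; dist < size; ++dist) {
            int peer = (rank + dist) % size;
            MPI_Irecv(gathered + (size_t) peer * count, count, MPI_DOUBLE, peer, 0,
                      MPI_COMM_WORLD, &reqs[nreq++]);
            MPI_Isend(mine, count, MPI_DOUBLE, peer, 0, MPI_COMM_WORLD, &reqs[nreq++]);
        }
        MPI_Waitall(nreq, reqs, MPI_STATUSES_IGNORE);

        /* Local reduction in rank order: rank 0, rank 1, ..., rank size-1 */
        for (int i = 0; i < count; ++i) {
            result[i] = 0.0;
        }
        for (int peer = 0; peer < size; ++peer) {
            const double *contrib = (peer == rank) ? mine : gathered + (size_t) peer * count;
            for (int i = 0; i < count; ++i) {
                result[i] += contrib[i];
            }
        }

        if (0 == rank) {
            /* Sum of 1..size, replicated across all elements */
            printf("result[0] = %g (expected %g)\n", result[0], size * (size + 1) / 2.0);
        }

        free(mine);
        free(gathered);
        free(result);
        free(reqs);
        MPI_Finalize();
        return 0;
    }

If the new routine is wired into the coll/tuned selection logic (not shown in this diff), it could presumably be chosen at run time through the existing --mca coll_tuned_use_dynamic_rules 1 and --mca coll_tuned_allreduce_algorithm parameters, using whatever algorithm number the tuned component assigns to it.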