@@ -1268,10 +1268,9 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
1268
1268
mca_coll_base_module_t * module )
1269
1269
{
1270
1270
char * send_buf = (void * ) sbuf ;
1271
- int comm_size = ompi_comm_size (comm );
1271
+ const int comm_size = ompi_comm_size (comm );
1272
+ const int rank = ompi_comm_rank (comm );
1272
1273
int err = MPI_SUCCESS ;
1273
- int rank = ompi_comm_rank (comm );
1274
- bool commutative = ompi_op_is_commute (op );
1275
1274
ompi_request_t * * reqs ;
1276
1275
1277
1276
if (sbuf == MPI_IN_PLACE ) {
@@ -1288,24 +1287,30 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
1288
1287
return OMPI_ERR_OUT_OF_RESOURCE ;
1289
1288
}
1290
1289
1291
- if (commutative ) {
1292
- ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) rbuf , (char * ) send_buf );
1293
- }
1294
-
1295
1290
tmp_buf = tmp_buf_raw - gap ;
1296
1291
1297
1292
/* Requests for send to AND receive from everyone else */
1298
1293
int reqs_needed = (comm_size - 1 ) * 2 ;
1299
1294
reqs = ompi_coll_base_comm_get_reqs (module -> base_data , reqs_needed );
1300
1295
1301
- ptrdiff_t incr = extent * count ;
1302
- tmp_recv = (char * ) tmp_buf ;
1296
+ const ptrdiff_t incr = extent * count ;
1303
1297
1304
- /* Exchange data with peer processes */
1298
+ /* Exchange data with peer processes, excluding self */
1305
1299
int req_index = 0 , peer_rank = 0 ;
1306
1300
for (int i = 1 ; i < comm_size ; ++ i ) {
1301
+ /* Start at the next rank */
1307
1302
peer_rank = (rank + i ) % comm_size ;
1308
- tmp_recv = tmp_buf + (peer_rank * incr );
1303
+
1304
+ /* Prepare for the next receive buffer */
1305
+ if (0 == peer_rank && rbuf != send_buf ) {
1306
+ /* Optimization for Rank 0 - its data will always be placed at the beginning of local
1307
+ * reduce output buffer.
1308
+ */
1309
+ tmp_recv = rbuf ;
1310
+ } else {
1311
+ tmp_recv = tmp_buf + (peer_rank * incr );
1312
+ }
1313
+
1309
1314
err = MCA_PML_CALL (irecv (tmp_recv , count , dtype , peer_rank , MCA_COLL_BASE_TAG_ALLREDUCE ,
1310
1315
comm , & reqs [req_index ++ ]));
1311
1316
if (MPI_SUCCESS != err ) {
@@ -1321,17 +1326,29 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
1321
1326
1322
1327
err = ompi_request_wait_all (req_index , reqs , MPI_STATUSES_IGNORE );
1323
1328
1324
- /* Prepare for local reduction */
1325
- peer_rank = 0 ;
1326
- if (!commutative ) {
1327
- /* For non-commutative operations, ensure the reduction always starts from Rank 0's data */
1328
- memcpy (rbuf , 0 == rank ? send_buf : tmp_buf , incr );
1329
- peer_rank = 1 ;
1329
+ /**
1330
+ * Prepare for local reduction by moving Rank 0's data to rbuf.
1331
+ * Previously we tried to receive Rank 0's data in rbuf, but we need to handle
1332
+ * the following special cases.
1333
+ */
1334
+ if (0 != rank && rbuf == send_buf ) {
1335
+ /* For inplace reduction copy out the send_buf before moving Rank 0's data */
1336
+ ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) tmp_buf + (rank * incr ),
1337
+ send_buf );
1338
+ ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) rbuf , (char * ) tmp_buf );
1339
+ } else if (0 == rank && rbuf != send_buf ) {
1340
+ /* For Rank 0 we need to copy the send_buf to rbuf manually */
1341
+ ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) rbuf , (char * ) send_buf );
1330
1342
}
1331
1343
1332
- char * inbuf ;
1333
- for (; peer_rank < comm_size ; peer_rank ++ ) {
1334
- inbuf = rank == peer_rank ? send_buf : tmp_buf + (peer_rank * incr );
1344
+ /* Now do local reduction - Rank 0's data is already in rbuf so start from Rank 1 */
1345
+ char * inbuf = NULL ;
1346
+ for (peer_rank = 1 ; peer_rank < comm_size ; peer_rank ++ ) {
1347
+ if (rank == peer_rank && rbuf != send_buf ) {
1348
+ inbuf = send_buf ;
1349
+ } else {
1350
+ inbuf = tmp_buf + (peer_rank * incr );
1351
+ }
1335
1352
ompi_op_reduce (op , (void * ) inbuf , rbuf , count , dtype );
1336
1353
}
1337
1354
0 commit comments