Commit d38ac4c

Merge pull request #9956 from jjhursey/fix-large-payload-reduce-scatter-block
Fix base reduce_scatter_block for large payloads
2 parents: ee0d20c + 8167468
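
For context, the overflow path this commit fixes is hit whenever the per-rank block count multiplied by the communicator size exceeds INT_MAX. Below is a minimal reproducer sketch, not taken from the commit or its PR: the use of MPI_INT/MPI_SUM and the 600M-element block size are illustrative assumptions, buffers are left uninitialized for brevity, and each rank needs roughly 10 GB of memory for these values.

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv)
{
    MPI_Init(&argc, &argv);

    int size, rank;
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    /* 600M ints per block: on 4 or more ranks, rcount * size > INT_MAX,
     * so the internal reduction count no longer fits in an 'int'. */
    const int rcount = 600000000;
    int *sendbuf = malloc((size_t)rcount * size * sizeof(int));
    int *recvbuf = malloc((size_t)rcount * sizeof(int));
    if (NULL == sendbuf || NULL == recvbuf) {
        MPI_Abort(MPI_COMM_WORLD, 1);
    }

    /* Each rank receives the reduction of its own rcount-sized block */
    MPI_Reduce_scatter_block(sendbuf, recvbuf, rcount, MPI_INT,
                             MPI_SUM, MPI_COMM_WORLD);

    if (0 == rank) {
        printf("reduce_scatter_block completed\n");
    }

    free(sendbuf);
    free(recvbuf);
    MPI_Finalize();
    return 0;
}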

ompi/mca/coll/base/coll_base_reduce_scatter_block.c

Lines changed: 125 additions & 43 deletions
@@ -17,6 +17,7 @@
  * and Technology (RIST). All rights reserved.
  * Copyright (c) 2018 Siberian State University of Telecommunications
  *                    and Information Sciences. All rights reserved.
+ * Copyright (c) 2022 IBM Corporation. All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
                                                  struct ompi_communicator_t *comm,
                                                  mca_coll_base_module_t *module)
 {
-    int rank, size, count, err = OMPI_SUCCESS;
+    int rank, size, err = OMPI_SUCCESS;
+    size_t count;
     ptrdiff_t gap, span;
     char *recv_buf = NULL, *recv_buf_free = NULL;
 
@@ -67,40 +69,106 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
     size = ompi_comm_size(comm);
 
     /* short cut the trivial case */
-    count = rcount * size;
+    count = rcount * (size_t)size;
     if (0 == count) {
         return OMPI_SUCCESS;
     }
 
-    /* get datatype information */
-    span = opal_datatype_span(&dtype->super, count, &gap);
-
     /* Handle MPI_IN_PLACE */
     if (MPI_IN_PLACE == sbuf) {
         sbuf = rbuf;
     }
 
-    if (0 == rank) {
-        /* temporary receive buffer. See coll_basic_reduce.c for
-           details on sizing */
-        recv_buf_free = (char*) malloc(span);
-        if (NULL == recv_buf_free) {
-            err = OMPI_ERR_OUT_OF_RESOURCE;
-            goto cleanup;
+    /*
+     * For large payload (defined as a count greater than INT_MAX)
+     * to reduce the memory footprint on the root we segment the
+     * reductions per rank, then send to each rank.
+     *
+     * Additionally, sending the message in the coll_reduce() as
+     * "rcount*size" would exceed the 'int count' parameter in the
+     * coll_reduce() function. So another technique is required
+     * for count values that exceed INT_MAX.
+     */
+    if ( OPAL_UNLIKELY(count > INT_MAX) ) {
+        int i;
+        void *sbuf_ptr;
+
+        /* Get datatype information for an individual block */
+        span = opal_datatype_span(&dtype->super, rcount, &gap);
+
+        if (0 == rank) {
+            /* temporary receive buffer. See coll_basic_reduce.c for
+               details on sizing */
+            recv_buf_free = (char*) malloc(span);
+            if (NULL == recv_buf_free) {
+                err = OMPI_ERR_OUT_OF_RESOURCE;
+                goto cleanup;
+            }
+            recv_buf = recv_buf_free - gap;
+        }
+
+        for( i = 0; i < size; ++i ) {
+            /* Calculate the portion of the send buffer to reduce over */
+            sbuf_ptr = (char*)sbuf + span * (size_t)i;
+
+            /* Reduction for this peer */
+            err = comm->c_coll->coll_reduce(sbuf_ptr, recv_buf, rcount,
+                                            dtype, op, 0, comm,
+                                            comm->c_coll->coll_reduce_module);
+            if (MPI_SUCCESS != err) {
+                goto cleanup;
+            }
+
+            /* Send reduce results to this peer */
+            if (0 == rank ) {
+                if( i == rank ) {
+                    err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf, recv_buf);
+                } else {
+                    err = MCA_PML_CALL(send(recv_buf, rcount, dtype, i,
+                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
+                                            MCA_PML_BASE_SEND_STANDARD, comm));
+                }
+                if (MPI_SUCCESS != err) {
+                    goto cleanup;
+                }
+            }
+            else if( i == rank ) {
+                err = MCA_PML_CALL(recv(rbuf, rcount, dtype, 0,
+                                        MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
+                                        comm, MPI_STATUS_IGNORE));
+                if (MPI_SUCCESS != err) {
+                    goto cleanup;
+                }
+            }
         }
-        recv_buf = recv_buf_free - gap;
     }
+    else {
+        /* get datatype information */
+        span = opal_datatype_span(&dtype->super, count, &gap);
+
+        if (0 == rank) {
+            /* temporary receive buffer. See coll_basic_reduce.c for
+               details on sizing */
+            recv_buf_free = (char*) malloc(span);
+            if (NULL == recv_buf_free) {
+                err = OMPI_ERR_OUT_OF_RESOURCE;
+                goto cleanup;
+            }
+            recv_buf = recv_buf_free - gap;
+        }
 
-    /* reduction */
-    err =
-        comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
-                                  comm, comm->c_coll->coll_reduce_module);
+        /* reduction */
+        err =
+            comm->c_coll->coll_reduce(sbuf, recv_buf, (int)count, dtype, op, 0,
+                                      comm, comm->c_coll->coll_reduce_module);
+        if (MPI_SUCCESS != err) {
+            goto cleanup;
+        }
 
-    /* scatter */
-    if (MPI_SUCCESS == err) {
+        /* scatter */
         err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
-                                        rbuf, rcount, dtype, 0,
-                                        comm, comm->c_coll->coll_scatter_module);
+                                         rbuf, rcount, dtype, 0,
+                                         comm, comm->c_coll->coll_scatter_module);
     }
 
  cleanup:
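
The new large-payload branch above drives the existing coll_reduce and PML send/recv machinery one block at a time. The same idea, sketched at the plain MPI API level for readability (illustration only, not code from this commit: the helper name reduce_scatter_block_segmented, the MPI_DOUBLE/MPI_SUM operation, and root rank 0 are assumptions):

#include <mpi.h>
#include <stdlib.h>
#include <string.h>

/* Hypothetical helper, illustration only: reduce one per-rank block at a time
 * onto the root, then hand each block to its owner, so no single MPI_Reduce
 * count ever exceeds INT_MAX. */
static int reduce_scatter_block_segmented(const double *sendbuf, double *recvbuf,
                                          int rcount, MPI_Comm comm)
{
    int rank, size, err = MPI_SUCCESS;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &size);

    /* The root needs scratch space for one block only, not size * rcount */
    double *tmp = NULL;
    if (0 == rank) {
        tmp = malloc((size_t)rcount * sizeof(double));
        if (NULL == tmp) {
            return MPI_ERR_NO_MEM;
        }
    }

    for (int i = 0; i < size; ++i) {
        /* Reduce block i (rcount elements) from every rank onto the root */
        err = MPI_Reduce(sendbuf + (size_t)i * rcount, tmp, rcount,
                         MPI_DOUBLE, MPI_SUM, 0, comm);
        if (MPI_SUCCESS != err) {
            break;
        }

        if (0 == rank) {
            if (i == rank) {
                /* The root keeps its own block */
                memcpy(recvbuf, tmp, (size_t)rcount * sizeof(double));
            } else {
                /* Forward block i to its owner */
                err = MPI_Send(tmp, rcount, MPI_DOUBLE, i, 0, comm);
            }
        } else if (i == rank) {
            err = MPI_Recv(recvbuf, rcount, MPI_DOUBLE, 0, 0, comm,
                           MPI_STATUS_IGNORE);
        }
        if (MPI_SUCCESS != err) {
            break;
        }
    }

    free(tmp);
    return err;
}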
@@ -146,7 +214,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
     if (comm_size < 2)
         return MPI_SUCCESS;
 
-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
+    if( OPAL_UNLIKELY(totalcount > INT_MAX) ) {
+        /*
+         * Large payload collectives are not supported by this algorithm.
+         * The blocklens and displs calculations in the loop below
+         * will overflow an int data type.
+         * Fallback to the linear algorithm.
+         */
+        return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module);
+    }
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf_raw = malloc(span);
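
The guard above exists because comm_size * rcount can exceed INT_MAX even when each factor fits comfortably in an int. A self-contained illustration of the cast-and-check pattern used throughout this patch (not from the commit; the values are arbitrary and a 64-bit size_t on an LP64 platform is assumed):

#include <limits.h>
#include <stdio.h>

int main(void)
{
    int rcount = 600000000;   /* per-rank block count, fits in an int */
    int comm_size = 4;

    /* Casting one operand to size_t makes the multiplication 64-bit,
     * so the true total count is preserved instead of overflowing. */
    size_t totalcount = (size_t)comm_size * rcount;

    printf("totalcount = %zu (INT_MAX = %d)\n", totalcount, INT_MAX);
    if (totalcount > INT_MAX) {
        /* This is the condition under which the recursive-doubling variant
         * now falls back to the linear algorithm. */
        printf("exceeds INT_MAX: fall back to the linear algorithm\n");
    }
    return 0;
}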
@@ -347,7 +424,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
347424
return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype,
348425
op, comm, module);
349426
}
350-
totalcount = comm_size * rcount;
427+
428+
totalcount = comm_size * (size_t)rcount;
351429
ompi_datatype_type_extent(dtype, &extent);
352430
span = opal_datatype_span(&dtype->super, totalcount, &gap);
353431
tmpbuf_raw = malloc(span);
@@ -431,22 +509,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
          * have their result calculated by the process to their
          * right (rank + 1).
          */
-        int send_count = 0, recv_count = 0;
+        size_t send_count = 0, recv_count = 0;
         if (vrank < vpeer) {
             /* Send the right half of the buffer, recv the left half */
             send_index = recv_index + mask;
-            send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
-            recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
+            send_count = rcount * (size_t)ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
+            recv_count = rcount * (size_t)ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
         } else {
             /* Send the left half of the buffer, recv the right half */
             recv_index = send_index + mask;
-            send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
-            recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
+            send_count = rcount * (size_t)ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
+            recv_count = rcount * (size_t)ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
         }
-        ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
-                                     2 * recv_index : nprocs_rem + recv_index);
-        ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
-                                     2 * send_index : nprocs_rem + send_index);
+        ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
+                                             2 * recv_index : nprocs_rem + recv_index);
+        ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
+                                             2 * send_index : nprocs_rem + send_index);
         struct ompi_request_t *request = NULL;
 
         if (recv_count > 0) {
@@ -587,7 +665,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
587665
sbuf, rbuf, rcount, dtype, op, comm, module);
588666
}
589667

590-
totalcount = comm_size * rcount;
668+
totalcount = comm_size * (size_t)rcount;
591669
ompi_datatype_type_extent(dtype, &extent);
592670
span = opal_datatype_span(&dtype->super, totalcount, &gap);
593671
tmpbuf[0] = malloc(span);
@@ -677,13 +755,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
             /* Send the upper half of reduction buffer, recv the lower half */
             recv_index += nblocks;
         }
-        int send_count = rcount * ompi_range_sum(send_index,
-                                                 send_index + nblocks - 1, nprocs_rem - 1);
-        int recv_count = rcount * ompi_range_sum(recv_index,
-                                                 recv_index + nblocks - 1, nprocs_rem - 1);
-        ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+        size_t send_count = rcount *
+                            (size_t)ompi_range_sum(send_index,
+                                                   send_index + nblocks - 1,
+                                                   nprocs_rem - 1);
+        size_t recv_count = rcount *
+                            (size_t)ompi_range_sum(recv_index,
+                                                   recv_index + nblocks - 1,
+                                                   nprocs_rem - 1);
+        ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                      2 * send_index : nprocs_rem + send_index);
-        ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
+        ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
                                      2 * recv_index : nprocs_rem + recv_index);
 
         err = ompi_coll_base_sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count,
@@ -719,7 +801,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
719801
* Process has two blocks: for excluded process and own.
720802
* Send result to the excluded process.
721803
*/
722-
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
804+
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
723805
2 * send_index : nprocs_rem + send_index);
724806
err = MCA_PML_CALL(send(psend + (ptrdiff_t)sdispl * extent,
725807
rcount, dtype, peer - 1,
@@ -729,7 +811,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
729811
}
730812

731813
/* Send result to a remote process according to a mirror permutation */
732-
ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
814+
ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
733815
2 * send_index : nprocs_rem + send_index);
734816
/* If process has two blocks, then send the second block (own block) */
735817
if (vpeer < nprocs_rem)
@@ -821,7 +903,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
821903
if (rcount == 0 || comm_size < 2)
822904
return MPI_SUCCESS;
823905

824-
totalcount = comm_size * rcount;
906+
totalcount = comm_size * (size_t)rcount;
825907
ompi_datatype_type_extent(dtype, &extent);
826908
span = opal_datatype_span(&dtype->super, totalcount, &gap);
827909
tmpbuf[0] = malloc(span);
@@ -843,7 +925,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
843925
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
844926
}
845927

846-
int nblocks = totalcount, send_index = 0, recv_index = 0;
928+
size_t nblocks = totalcount, send_index = 0, recv_index = 0;
847929
for (int mask = 1; mask < comm_size; mask <<= 1) {
848930
int peer = rank ^ mask;
849931
nblocks /= 2;
