17
17
* and Technology (RIST). All rights reserved.
18
18
* Copyright (c) 2018 Siberian State University of Telecommunications
19
19
* and Information Sciences. All rights reserved.
20
+ * Copyright (c) 2022 IBM Corporation. All rights reserved.
20
21
* $COPYRIGHT$
21
22
*
22
23
* Additional copyrights may follow
@@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
58
59
struct ompi_communicator_t * comm ,
59
60
mca_coll_base_module_t * module )
60
61
{
61
- int rank , size , count , err = OMPI_SUCCESS ;
62
+ int rank , size , err = OMPI_SUCCESS ;
63
+ size_t count ;
62
64
ptrdiff_t gap , span ;
63
65
char * recv_buf = NULL , * recv_buf_free = NULL ;
64
66
@@ -67,40 +69,106 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
67
69
size = ompi_comm_size (comm );
68
70
69
71
/* short cut the trivial case */
70
- count = rcount * size ;
72
+ count = rcount * ( size_t ) size ;
71
73
if (0 == count ) {
72
74
return OMPI_SUCCESS ;
73
75
}
74
76
75
- /* get datatype information */
76
- span = opal_datatype_span (& dtype -> super , count , & gap );
77
-
78
77
/* Handle MPI_IN_PLACE */
79
78
if (MPI_IN_PLACE == sbuf ) {
80
79
sbuf = rbuf ;
81
80
}
82
81
83
- if (0 == rank ) {
84
- /* temporary receive buffer. See coll_basic_reduce.c for
85
- details on sizing */
86
- recv_buf_free = (char * ) malloc (span );
87
- if (NULL == recv_buf_free ) {
88
- err = OMPI_ERR_OUT_OF_RESOURCE ;
89
- goto cleanup ;
82
+ /*
83
+ * For large payloads (defined as a count greater than INT_MAX),
84
+ * to reduce the memory footprint on the root, we segment the
85
+ * reductions per rank, then send each result to its rank.
86
+ *
87
+ * Additionally, passing "rcount*size" as the count to
88
+ * coll_reduce() would overflow the 'int count' parameter of
89
+ * that function. So another technique is required
90
+ * for count values that exceed INT_MAX.
91
+ */
92
+ if ( OPAL_UNLIKELY (count > INT_MAX ) ) {
93
+ int i ;
94
+ void * sbuf_ptr ;
95
+
96
+ /* Get datatype information for an individual block */
97
+ span = opal_datatype_span (& dtype -> super , rcount , & gap );
98
+
99
+ if (0 == rank ) {
100
+ /* temporary receive buffer. See coll_basic_reduce.c for
101
+ details on sizing */
102
+ recv_buf_free = (char * ) malloc (span );
103
+ if (NULL == recv_buf_free ) {
104
+ err = OMPI_ERR_OUT_OF_RESOURCE ;
105
+ goto cleanup ;
106
+ }
107
+ recv_buf = recv_buf_free - gap ;
108
+ }
109
+
110
+ for ( i = 0 ; i < size ; ++ i ) {
111
+ /* Calculate the portion of the send buffer to reduce over */
112
+ sbuf_ptr = (char * )sbuf + span * (size_t )i ;
113
+
114
+ /* Reduction for this peer */
115
+ err = comm -> c_coll -> coll_reduce (sbuf_ptr , recv_buf , rcount ,
116
+ dtype , op , 0 , comm ,
117
+ comm -> c_coll -> coll_reduce_module );
118
+ if (MPI_SUCCESS != err ) {
119
+ goto cleanup ;
120
+ }
121
+
122
+ /* Send reduce results to this peer */
123
+ if (0 == rank ) {
124
+ if ( i == rank ) {
125
+ err = ompi_datatype_copy_content_same_ddt (dtype , rcount , rbuf , recv_buf );
126
+ } else {
127
+ err = MCA_PML_CALL (send (recv_buf , rcount , dtype , i ,
128
+ MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK ,
129
+ MCA_PML_BASE_SEND_STANDARD , comm ));
130
+ }
131
+ if (MPI_SUCCESS != err ) {
132
+ goto cleanup ;
133
+ }
134
+ }
135
+ else if ( i == rank ) {
136
+ err = MCA_PML_CALL (recv (rbuf , rcount , dtype , 0 ,
137
+ MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK ,
138
+ comm , MPI_STATUS_IGNORE ));
139
+ if (MPI_SUCCESS != err ) {
140
+ goto cleanup ;
141
+ }
142
+ }
90
143
}
91
- recv_buf = recv_buf_free - gap ;
92
144
}
145
+ else {
146
+ /* get datatype information */
147
+ span = opal_datatype_span (& dtype -> super , count , & gap );
148
+
149
+ if (0 == rank ) {
150
+ /* temporary receive buffer. See coll_basic_reduce.c for
151
+ details on sizing */
152
+ recv_buf_free = (char * ) malloc (span );
153
+ if (NULL == recv_buf_free ) {
154
+ err = OMPI_ERR_OUT_OF_RESOURCE ;
155
+ goto cleanup ;
156
+ }
157
+ recv_buf = recv_buf_free - gap ;
158
+ }
93
159
94
- /* reduction */
95
- err =
96
- comm -> c_coll -> coll_reduce (sbuf , recv_buf , count , dtype , op , 0 ,
97
- comm , comm -> c_coll -> coll_reduce_module );
160
+ /* reduction */
161
+ err =
162
+ comm -> c_coll -> coll_reduce (sbuf , recv_buf , (int )count , dtype , op , 0 ,
163
+ comm , comm -> c_coll -> coll_reduce_module );
164
+ if (MPI_SUCCESS != err ) {
165
+ goto cleanup ;
166
+ }
98
167
99
- /* scatter */
100
- if (MPI_SUCCESS == err ) {
168
+ /* scatter */
101
169
err = comm -> c_coll -> coll_scatter (recv_buf , rcount , dtype ,
102
- rbuf , rcount , dtype , 0 ,
103
- comm , comm -> c_coll -> coll_scatter_module );
170
+ rbuf , rcount , dtype , 0 ,
171
+ comm , comm -> c_coll -> coll_scatter_module );
104
172
}
105
173
106
174
cleanup :
@@ -146,7 +214,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
146
214
if (comm_size < 2 )
147
215
return MPI_SUCCESS ;
148
216
149
- totalcount = comm_size * rcount ;
217
+ totalcount = comm_size * (size_t )rcount ;
218
+ if ( OPAL_UNLIKELY (totalcount > INT_MAX ) ) {
219
+ /*
220
+ * Large payload collectives are not supported by this algorithm.
221
+ * The blocklens and displs calculations in the loop below
222
+ * will overflow an int data type.
223
+ * Fall back to the linear algorithm.
224
+ */
225
+ return ompi_coll_base_reduce_scatter_block_basic_linear (sbuf , rbuf , rcount , dtype , op , comm , module );
226
+ }
150
227
ompi_datatype_type_extent (dtype , & extent );
151
228
span = opal_datatype_span (& dtype -> super , totalcount , & gap );
152
229
tmpbuf_raw = malloc (span );
@@ -347,7 +424,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
347
424
return ompi_coll_base_reduce_scatter_block_basic_linear (sbuf , rbuf , rcount , dtype ,
348
425
op , comm , module );
349
426
}
350
- totalcount = comm_size * rcount ;
427
+
428
+ totalcount = comm_size * (size_t )rcount ;
351
429
ompi_datatype_type_extent (dtype , & extent );
352
430
span = opal_datatype_span (& dtype -> super , totalcount , & gap );
353
431
tmpbuf_raw = malloc (span );
@@ -431,22 +509,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
431
509
* have their result calculated by the process to their
432
510
* right (rank + 1).
433
511
*/
434
- int send_count = 0 , recv_count = 0 ;
512
+ size_t send_count = 0 , recv_count = 0 ;
435
513
if (vrank < vpeer ) {
436
514
/* Send the right half of the buffer, recv the left half */
437
515
send_index = recv_index + mask ;
438
- send_count = rcount * ompi_range_sum (send_index , last_index - 1 , nprocs_rem - 1 );
439
- recv_count = rcount * ompi_range_sum (recv_index , send_index - 1 , nprocs_rem - 1 );
516
+ send_count = rcount * ( size_t ) ompi_range_sum (send_index , last_index - 1 , nprocs_rem - 1 );
517
+ recv_count = rcount * ( size_t ) ompi_range_sum (recv_index , send_index - 1 , nprocs_rem - 1 );
440
518
} else {
441
519
/* Send the left half of the buffer, recv the right half */
442
520
recv_index = send_index + mask ;
443
- send_count = rcount * ompi_range_sum (send_index , recv_index - 1 , nprocs_rem - 1 );
444
- recv_count = rcount * ompi_range_sum (recv_index , last_index - 1 , nprocs_rem - 1 );
521
+ send_count = rcount * ( size_t ) ompi_range_sum (send_index , recv_index - 1 , nprocs_rem - 1 );
522
+ recv_count = rcount * ( size_t ) ompi_range_sum (recv_index , last_index - 1 , nprocs_rem - 1 );
445
523
}
446
- ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1 ) ?
447
- 2 * recv_index : nprocs_rem + recv_index );
448
- ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1 ) ?
449
- 2 * send_index : nprocs_rem + send_index );
524
+ ptrdiff_t rdispl = rcount * (size_t )( (recv_index <= nprocs_rem - 1 ) ?
525
+ 2 * recv_index : nprocs_rem + recv_index );
526
+ ptrdiff_t sdispl = rcount * (size_t )( (send_index <= nprocs_rem - 1 ) ?
527
+ 2 * send_index : nprocs_rem + send_index );
450
528
struct ompi_request_t * request = NULL ;
451
529
452
530
if (recv_count > 0 ) {
@@ -587,7 +665,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
587
665
sbuf , rbuf , rcount , dtype , op , comm , module );
588
666
}
589
667
590
- totalcount = comm_size * rcount ;
668
+ totalcount = comm_size * ( size_t ) rcount ;
591
669
ompi_datatype_type_extent (dtype , & extent );
592
670
span = opal_datatype_span (& dtype -> super , totalcount , & gap );
593
671
tmpbuf [0 ] = malloc (span );
@@ -677,13 +755,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
677
755
/* Send the upper half of reduction buffer, recv the lower half */
678
756
recv_index += nblocks ;
679
757
}
680
- int send_count = rcount * ompi_range_sum (send_index ,
681
- send_index + nblocks - 1 , nprocs_rem - 1 );
682
- int recv_count = rcount * ompi_range_sum (recv_index ,
683
- recv_index + nblocks - 1 , nprocs_rem - 1 );
684
- ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1 ) ?
758
+ size_t send_count = rcount *
759
+ (size_t )ompi_range_sum (send_index ,
760
+ send_index + nblocks - 1 ,
761
+ nprocs_rem - 1 );
762
+ size_t recv_count = rcount *
763
+ (size_t )ompi_range_sum (recv_index ,
764
+ recv_index + nblocks - 1 ,
765
+ nprocs_rem - 1 );
766
+ ptrdiff_t sdispl = rcount * (size_t )((send_index <= nprocs_rem - 1 ) ?
685
767
2 * send_index : nprocs_rem + send_index );
686
- ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1 ) ?
768
+ ptrdiff_t rdispl = rcount * (size_t )( (recv_index <= nprocs_rem - 1 ) ?
687
769
2 * recv_index : nprocs_rem + recv_index );
688
770
689
771
err = ompi_coll_base_sendrecv (psend + (ptrdiff_t )sdispl * extent , send_count ,
@@ -719,7 +801,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
719
801
* Process has two blocks: for excluded process and own.
720
802
* Send result to the excluded process.
721
803
*/
722
- ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1 ) ?
804
+ ptrdiff_t sdispl = rcount * (size_t )( (send_index <= nprocs_rem - 1 ) ?
723
805
2 * send_index : nprocs_rem + send_index );
724
806
err = MCA_PML_CALL (send (psend + (ptrdiff_t )sdispl * extent ,
725
807
rcount , dtype , peer - 1 ,
@@ -729,7 +811,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
729
811
}
730
812
731
813
/* Send result to a remote process according to a mirror permutation */
732
- ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1 ) ?
814
+ ptrdiff_t sdispl = rcount * (size_t )( (send_index <= nprocs_rem - 1 ) ?
733
815
2 * send_index : nprocs_rem + send_index );
734
816
/* If process has two blocks, then send the second block (own block) */
735
817
if (vpeer < nprocs_rem )
@@ -821,7 +903,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
821
903
if (rcount == 0 || comm_size < 2 )
822
904
return MPI_SUCCESS ;
823
905
824
- totalcount = comm_size * rcount ;
906
+ totalcount = comm_size * ( size_t ) rcount ;
825
907
ompi_datatype_type_extent (dtype , & extent );
826
908
span = opal_datatype_span (& dtype -> super , totalcount , & gap );
827
909
tmpbuf [0 ] = malloc (span );
@@ -843,7 +925,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
843
925
if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
844
926
}
845
927
846
- int nblocks = totalcount , send_index = 0 , recv_index = 0 ;
928
+ size_t nblocks = totalcount , send_index = 0 , recv_index = 0 ;
847
929
for (int mask = 1 ; mask < comm_size ; mask <<= 1 ) {
848
930
int peer = rank ^ mask ;
849
931
nblocks /= 2 ;
0 commit comments