Skip to content

Commit a0e9b01

Browse files
authored
Merge pull request #12288 from wenduwan/fix_alltoallv_dtype_size_check
Fix alltoallv dtype size check
2 parents adaaf47 + 737eefd commit a0e9b01

File tree

2 files changed

+56
-51
lines changed

2 files changed

+56
-51
lines changed

ompi/mca/coll/base/coll_base_alltoallv.c

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
201201
int line = -1, err = 0, rank, size, step = 0, sendto, recvfrom;
202202
size_t sdtype_size, rdtype_size;
203203
void *psnd, *prcv;
204+
ompi_request_t *req;
204205
ptrdiff_t sext, rext;
205206

206207
if (MPI_IN_PLACE == sbuf) {
@@ -217,16 +218,12 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
217218
ompi_datatype_type_size(sdtype, &sdtype_size);
218219
ompi_datatype_type_size(rdtype, &rdtype_size);
219220

220-
if (0 == sdtype_size || 0 == rdtype_size) {
221-
/* Nothing to exchange */
222-
return MPI_SUCCESS;
223-
}
224-
225221
ompi_datatype_type_extent(sdtype, &sext);
226222
ompi_datatype_type_extent(rdtype, &rext);
227223

228224
/* Perform pairwise exchange starting from 1 since local exchange is done */
229225
for (step = 0; step < size; step++) {
226+
req = MPI_REQUEST_NULL;
230227

231228
/* Determine sender and receiver for this step. */
232229
sendto = (rank + step) % size;
@@ -237,12 +234,31 @@ ompi_coll_base_alltoallv_intra_pairwise(const void *sbuf, const int *scounts, co
237234
prcv = (char*)rbuf + (ptrdiff_t)rdisps[recvfrom] * rext;
238235

239236
/* send and receive */
240-
err = ompi_coll_base_sendrecv( psnd, scounts[sendto], sdtype, sendto,
241-
MCA_COLL_BASE_TAG_ALLTOALLV,
242-
prcv, rcounts[recvfrom], rdtype, recvfrom,
243-
MCA_COLL_BASE_TAG_ALLTOALLV,
244-
comm, MPI_STATUS_IGNORE, rank);
245-
if (MPI_SUCCESS != err) { line = __LINE__; goto err_hndl; }
237+
if (0 < rcounts[recvfrom] && 0 < rdtype_size) {
238+
err = MCA_PML_CALL(irecv(prcv, rcounts[recvfrom], rdtype, recvfrom,
239+
MCA_COLL_BASE_TAG_ALLTOALLV, comm, &req));
240+
if (MPI_SUCCESS != err) {
241+
line = __LINE__;
242+
goto err_hndl;
243+
}
244+
}
245+
246+
if (0 < scounts[sendto] && 0 < sdtype_size) {
247+
err = MCA_PML_CALL(send(psnd, scounts[sendto], sdtype, sendto,
248+
MCA_COLL_BASE_TAG_ALLTOALLV, MCA_PML_BASE_SEND_STANDARD, comm));
249+
if (MPI_SUCCESS != err) {
250+
line = __LINE__;
251+
goto err_hndl;
252+
}
253+
}
254+
255+
if (MPI_REQUEST_NULL != req) {
256+
err = ompi_request_wait(&req, MPI_STATUS_IGNORE);
257+
if (MPI_SUCCESS != err) {
258+
line = __LINE__;
259+
goto err_hndl;
260+
}
261+
}
246262
}
247263

248264
return MPI_SUCCESS;
@@ -293,18 +309,13 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
293309
ompi_datatype_type_size(rdtype, &rdtype_size);
294310
ompi_datatype_type_size(sdtype, &sdtype_size);
295311

296-
if (0 == rdtype_size || 0 == sdtype_size) {
297-
/* Nothing to exchange */
298-
return MPI_SUCCESS;
299-
}
300-
301312
ompi_datatype_type_extent(sdtype, &sext);
302313
ompi_datatype_type_extent(rdtype, &rext);
303314

304315
/* Simple optimization - handle send to self first */
305316
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[rank] * sext;
306317
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[rank] * rext;
307-
if (0 < scounts[rank]) {
318+
if (0 < scounts[rank] && 0 < sdtype_size) {
308319
err = ompi_datatype_sndrcv(psnd, scounts[rank], sdtype,
309320
prcv, rcounts[rank], rdtype);
310321
if (MPI_SUCCESS != err) {
@@ -328,7 +339,7 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
328339
continue;
329340
}
330341

331-
if (0 < rcounts[i]) {
342+
if (0 < rcounts[i] && 0 < rdtype_size) {
332343
++nreqs;
333344
prcv = ((char *) rbuf) + (ptrdiff_t)rdisps[i] * rext;
334345
err = MCA_PML_CALL(irecv_init(prcv, rcounts[i], rdtype,
@@ -344,7 +355,7 @@ ompi_coll_base_alltoallv_intra_basic_linear(const void *sbuf, const int *scounts
344355
continue;
345356
}
346357

347-
if (0 < scounts[i]) {
358+
if (0 < scounts[i] && 0 < sdtype_size) {
348359
++nreqs;
349360
psnd = ((char *) sbuf) + (ptrdiff_t)sdisps[i] * sext;
350361
err = MCA_PML_CALL(isend_init(psnd, scounts[i], sdtype,

ompi/mca/coll/libnbc/nbc_ialltoallv.c

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@
2323

2424
static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
2525
const void *sendbuf, const int *sendcounts,
26-
const int *sdispls, MPI_Aint sndext, MPI_Datatype sendtype,
26+
const int *sdispls, MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
2727
void *recvbuf, const int *recvcounts,
28-
const int *rdispls, MPI_Aint rcvext, MPI_Datatype recvtype);
28+
const int *rdispls, MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size);
2929

3030
static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
3131
const void *sendbuf, const int *sendcounts, const int *sdispls,
32-
MPI_Aint sndext, MPI_Datatype sendtype,
32+
MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
3333
void *recvbuf, const int *recvcounts, const int *rdispls,
34-
MPI_Aint rcvext, MPI_Datatype recvtype);
34+
MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size);
3535

3636
static inline int a2av_sched_inplace(int rank, int p, NBC_Schedule *schedule,
3737
void *buf, const int *counts, const int *displs,
38-
MPI_Aint ext, MPI_Datatype type, ptrdiff_t gap);
38+
MPI_Aint ext, MPI_Datatype type, const size_t dtype_size, ptrdiff_t gap);
3939

4040
/* an alltoallv schedule can not be cached easily because the contents
4141
* of the recvcounts array may change, so a comparison of the address
@@ -104,19 +104,13 @@ static int nbc_alltoallv_init(const void* sendbuf, const int *sendcounts, const
104104
}
105105
}
106106

107-
if (0 == sdtype_size || 0 == rdtype_size) {
108-
/* Nothing to exchange */
109-
ompi_coll_base_nbc_reserve_tags(comm, 1);
110-
return nbc_get_noop_request(persistent, request);
111-
}
112-
113107
schedule = OBJ_NEW(NBC_Schedule);
114108
if (OPAL_UNLIKELY(NULL == schedule)) {
115109
free(tmpbuf);
116110
return OMPI_ERR_OUT_OF_RESOURCE;
117111
}
118112

119-
if (!inplace && 0 < sendcounts[rank]) {
113+
if (!inplace && 0 < sendcounts[rank] && 0 < sdtype_size) {
120114
rbuf = (char *) recvbuf + rdispls[rank] * rcvext;
121115
sbuf = (char *) sendbuf + sdispls[rank] * sndext;
122116
res = NBC_Sched_copy (sbuf, false, sendcounts[rank], sendtype,
@@ -128,12 +122,12 @@ static int nbc_alltoallv_init(const void* sendbuf, const int *sendcounts, const
128122
}
129123

130124
if (inplace) {
131-
res = a2av_sched_inplace(rank, p, schedule, recvbuf, recvcounts,
132-
rdispls, rcvext, recvtype, gap);
125+
res = a2av_sched_inplace(rank, p, schedule, recvbuf, recvcounts, rdispls, rcvext, recvtype,
126+
rdtype_size, gap);
133127
} else {
134128
res = a2av_sched_linear(rank, p, schedule,
135-
sendbuf, sendcounts, sdispls, sndext, sendtype,
136-
recvbuf, recvcounts, rdispls, rcvext, recvtype);
129+
sendbuf, sendcounts, sdispls, sndext, sendtype, sdtype_size,
130+
recvbuf, recvcounts, rdispls, rcvext, recvtype, rdtype_size);
137131
}
138132
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
139133
OBJ_RELEASE(schedule);
@@ -193,11 +187,6 @@ static int nbc_alltoallv_inter_init (const void* sendbuf, const int *sendcounts,
193187

194188
ompi_datatype_type_size(sendtype, &sdtype_size);
195189
ompi_datatype_type_size(recvtype, &rdtype_size);
196-
if (0 == sdtype_size || 0 == rdtype_size) {
197-
/* Nothing to exchange */
198-
ompi_coll_base_nbc_reserve_tags(comm, 1);
199-
return nbc_get_noop_request(persistent, request);
200-
}
201190

202191
res = ompi_datatype_type_extent(sendtype, &sndext);
203192
if (MPI_SUCCESS != res) {
@@ -220,7 +209,7 @@ static int nbc_alltoallv_inter_init (const void* sendbuf, const int *sendcounts,
220209

221210
for (int i = 0; i < rsize; i++) {
222211
/* post all sends */
223-
if (0 < sendcounts[i]) {
212+
if (0 < sendcounts[i] && 0 < sdtype_size) {
224213
char *sbuf = (char *) sendbuf + sdispls[i] * sndext;
225214
res = NBC_Sched_send (sbuf, false, sendcounts[i], sendtype, i, schedule, false);
226215
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -229,7 +218,7 @@ static int nbc_alltoallv_inter_init (const void* sendbuf, const int *sendcounts,
229218
}
230219
}
231220
/* post all receives */
232-
if (0 < recvcounts[i]) {
221+
if (0 < recvcounts[i] && 0 < rdtype_size) {
233222
char *rbuf = (char *) recvbuf + rdispls[i] * rcvext;
234223
res = NBC_Sched_recv (rbuf, false, recvcounts[i], recvtype, i, schedule, false);
235224
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -278,9 +267,9 @@ int ompi_coll_libnbc_ialltoallv_inter (const void* sendbuf, const int *sendcount
278267
__opal_attribute_unused__
279268
static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
280269
const void *sendbuf, const int *sendcounts, const int *sdispls,
281-
MPI_Aint sndext, MPI_Datatype sendtype,
270+
MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
282271
void *recvbuf, const int *recvcounts, const int *rdispls,
283-
MPI_Aint rcvext, MPI_Datatype recvtype) {
272+
MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size) {
284273
int res;
285274

286275
for (int i = 0 ; i < p ; ++i) {
@@ -289,7 +278,7 @@ static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
289278
}
290279

291280
/* post send */
292-
if (0 < sendcounts[i]) {
281+
if (0 < sendcounts[i] && 0 < sdtype_size) {
293282
char *sbuf = ((char *) sendbuf) + (sdispls[i] * sndext);
294283
res = NBC_Sched_send(sbuf, false, sendcounts[i], sendtype, i, schedule, false);
295284
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -298,7 +287,7 @@ static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
298287
}
299288

300289
/* post receive */
301-
if (0 < recvcounts[i]) {
290+
if (0 < recvcounts[i] && 0 < rdtype_size) {
302291
char *rbuf = ((char *) recvbuf) + (rdispls[i] * rcvext);
303292
res = NBC_Sched_recv(rbuf, false, recvcounts[i], recvtype, i, schedule, false);
304293
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -313,17 +302,17 @@ static inline int a2av_sched_linear(int rank, int p, NBC_Schedule *schedule,
313302
__opal_attribute_unused__
314303
static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
315304
const void *sendbuf, const int *sendcounts, const int *sdispls,
316-
MPI_Aint sndext, MPI_Datatype sendtype,
305+
MPI_Aint sndext, MPI_Datatype sendtype, const size_t sdtype_size,
317306
void *recvbuf, const int *recvcounts, const int *rdispls,
318-
MPI_Aint rcvext, MPI_Datatype recvtype) {
307+
MPI_Aint rcvext, MPI_Datatype recvtype, const size_t rdtype_size) {
319308
int res;
320309

321310
for (int i = 1 ; i < p ; ++i) {
322311
int sndpeer = (rank + i) % p;
323312
int rcvpeer = (rank + p - i) %p;
324313

325314
/* post send */
326-
if (0 < sendcounts[sndpeer]) {
315+
if (0 < sendcounts[sndpeer] && 0 < sdtype_size) {
327316
char *sbuf = ((char *) sendbuf) + (sdispls[sndpeer] * sndext);
328317
res = NBC_Sched_send(sbuf, false, sendcounts[sndpeer], sendtype, sndpeer, schedule, false);
329318
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -332,7 +321,7 @@ static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
332321
}
333322

334323
/* post receive */
335-
if (0 < recvcounts[rcvpeer]) {
324+
if (0 < recvcounts[rcvpeer] && 0 < rdtype_size) {
336325
char *rbuf = ((char *) recvbuf) + (rdispls[rcvpeer] * rcvext);
337326
res = NBC_Sched_recv(rbuf, false, recvcounts[rcvpeer], recvtype, rcvpeer, schedule, true);
338327
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) {
@@ -346,7 +335,7 @@ static inline int a2av_sched_pairwise(int rank, int p, NBC_Schedule *schedule,
346335

347336
static inline int a2av_sched_inplace(int rank, int p, NBC_Schedule *schedule,
348337
void *buf, const int *counts, const int *displs,
349-
MPI_Aint ext, MPI_Datatype type, ptrdiff_t gap) {
338+
MPI_Aint ext, MPI_Datatype type, const size_t dtype_size, ptrdiff_t gap) {
350339
int res;
351340

352341
for (int i = 1; i < (p+1)/2; i++) {
@@ -355,6 +344,11 @@ static inline int a2av_sched_inplace(int rank, int p, NBC_Schedule *schedule,
355344
char *sbuf = (char *) buf + displs[speer] * ext;
356345
char *rbuf = (char *) buf + displs[rpeer] * ext;
357346

347+
if (0 == dtype_size) {
348+
/* Nothing to exchange */
349+
return OMPI_SUCCESS;
350+
}
351+
358352
if (0 < counts[rpeer]) {
359353
res = NBC_Sched_copy (rbuf, false, counts[rpeer], type,
360354
(void *)(-gap), true, counts[rpeer], type,

0 commit comments

Comments
 (0)