Skip to content

Commit c056408

Browse files
authored
Merge pull request #11488 from wzamazon/coll_fix_in_place
coll/han,base: correctly handle MPI_IN_PLACE
2 parents 62c0738 + d9d6398 commit c056408

File tree

1 file changed

+59
-13
lines changed

1 file changed

+59
-13
lines changed

ompi/mca/coll/han/coll_han_allreduce.c

Lines changed: 59 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ mca_coll_han_allreduce_intra(const void *sbuf,
184184
mca_coll_task_t *t_next_seg = OBJ_NEW(mca_coll_task_t);
185185
/* Set up t_next_seg task arguments */
186186
t->cur_task = t_next_seg;
187-
t->sbuf = (char *) t->sbuf + extent * t->seg_count;
187+
t->sbuf = (t->sbuf == MPI_IN_PLACE) ? MPI_IN_PLACE : (char *) t->sbuf + extent * t->seg_count;
188188
t->rbuf = (char *) t->rbuf + extent * t->seg_count;
189189
t->cur_seg = t->cur_seg + 1;
190190
/* Init t_next_seg task */
@@ -262,11 +262,26 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
262262
if (t->cur_seg == t->num_segments - 2 && t->last_seg_count != t->seg_count) {
263263
tmp_count = t->last_seg_count;
264264
}
265-
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + extent * t->seg_count,
266-
(char *) t->rbuf + extent * t->seg_count, tmp_count,
267-
t->dtype, t->op, t->root_low_rank, t->low_comm,
268-
t->low_comm->c_coll->coll_reduce_module);
269265

266+
if (t->sbuf == MPI_IN_PLACE) {
267+
if (!t->noop) {
268+
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
269+
(char *) t->rbuf + extent * t->seg_count, tmp_count,
270+
t->dtype, t->op, t->root_low_rank, t->low_comm,
271+
t->low_comm->c_coll->coll_reduce_module);
272+
} else {
273+
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + extent * t->seg_count,
274+
NULL, tmp_count,
275+
t->dtype, t->op, t->root_low_rank, t->low_comm,
276+
t->low_comm->c_coll->coll_reduce_module);
277+
278+
}
279+
} else {
280+
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + extent * t->seg_count,
281+
(char *) t->rbuf + extent * t->seg_count, tmp_count,
282+
t->dtype, t->op, t->root_low_rank, t->low_comm,
283+
t->low_comm->c_coll->coll_reduce_module);
284+
}
270285
}
271286
if (!t->noop) {
272287
ompi_request_wait(&ireduce_req, MPI_STATUS_IGNORE);
@@ -321,10 +336,26 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
321336
if (t->cur_seg == t->num_segments - 3 && t->last_seg_count != t->seg_count) {
322337
tmp_count = t->last_seg_count;
323338
}
324-
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count,
325-
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
326-
t->dtype, t->op, t->root_low_rank, t->low_comm,
327-
t->low_comm->c_coll->coll_reduce_module);
339+
340+
if (t->sbuf == MPI_IN_PLACE) {
341+
if (!t->noop) {
342+
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
343+
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
344+
t->dtype, t->op, t->root_low_rank, t->low_comm,
345+
t->low_comm->c_coll->coll_reduce_module);
346+
} else {
347+
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 2 * extent * t->seg_count,
348+
NULL, tmp_count,
349+
t->dtype, t->op, t->root_low_rank, t->low_comm,
350+
t->low_comm->c_coll->coll_reduce_module);
351+
352+
}
353+
} else {
354+
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 2 * extent * t->seg_count,
355+
(char *) t->rbuf + 2 * extent * t->seg_count, tmp_count,
356+
t->dtype, t->op, t->root_low_rank, t->low_comm,
357+
t->low_comm->c_coll->coll_reduce_module);
358+
}
328359
}
329360
if (!t->noop && req_count > 0) {
330361
ompi_request_wait_all(req_count, reqs, MPI_STATUSES_IGNORE);
@@ -385,10 +416,25 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
385416
if (t->cur_seg == t->num_segments - 4 && t->last_seg_count != t->seg_count) {
386417
tmp_count = t->last_seg_count;
387418
}
388-
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 3 * extent * t->seg_count,
389-
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
390-
t->dtype, t->op, t->root_low_rank, t->low_comm,
391-
t->low_comm->c_coll->coll_reduce_module);
419+
420+
if (t->sbuf == MPI_IN_PLACE) {
421+
if (!t->noop) {
422+
t->low_comm->c_coll->coll_reduce(MPI_IN_PLACE,
423+
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
424+
t->dtype, t->op, t->root_low_rank, t->low_comm,
425+
t->low_comm->c_coll->coll_reduce_module);
426+
} else {
427+
t->low_comm->c_coll->coll_reduce((char *) t->rbuf + 3 * extent * t->seg_count,
428+
NULL, tmp_count,
429+
t->dtype, t->op, t->root_low_rank, t->low_comm,
430+
t->low_comm->c_coll->coll_reduce_module);
431+
}
432+
} else {
433+
t->low_comm->c_coll->coll_reduce((char *) t->sbuf + 3 * extent * t->seg_count,
434+
(char *) t->rbuf + 3 * extent * t->seg_count, tmp_count,
435+
t->dtype, t->op, t->root_low_rank, t->low_comm,
436+
t->low_comm->c_coll->coll_reduce_module);
437+
}
392438
}
393439
/* lb of cur_seg */
394440
if (t->cur_seg == t->num_segments - 1 && t->last_seg_count != t->seg_count) {

0 commit comments

Comments
 (0)