Skip to content

Commit 29bbf0b

Browse files
author
Valentin Petrov
committed
coll/hcoll: fixes fallback on rooted ops
Signed-off-by: Valentin Petrov <valentinp@nvidia.com> (cherry picked from commit 1055146)
1 parent b58ad5e commit 29bbf0b

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

ompi/mca/coll/hcoll/coll_hcoll_ops.c

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,19 @@ int mca_coll_hcoll_gather(const void *sbuf, int scount,
156156
struct ompi_datatype_t *rdtype,
157157
int root,
158158
struct ompi_communicator_t *comm,
159-
mca_coll_base_module_t *module){
159+
mca_coll_base_module_t *module)
160+
{
161+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
160162
dte_data_representation_t stype;
161163
dte_data_representation_t rtype;
162164
int rc;
165+
163166
HCOL_VERBOSE(20,"RUNNING HCOL GATHER");
164-
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
167+
168+
if (root != comm->c_my_rank) {
169+
rdtype = sdtype;
170+
}
171+
165172
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
166173
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
167174
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
@@ -368,13 +375,19 @@ int mca_coll_hcoll_gatherv(const void* sbuf, int scount,
368375
struct ompi_communicator_t *comm,
369376
mca_coll_base_module_t *module)
370377
{
378+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
371379
dte_data_representation_t stype;
372380
dte_data_representation_t rtype;
373381
int rc;
374382
HCOL_VERBOSE(20,"RUNNING HCOL GATHERV");
375-
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
383+
384+
if (root != comm->c_my_rank) {
385+
rdtype = sdtype;
386+
}
387+
376388
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
377389
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
390+
378391
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
379392
/*If we are here then datatype is not simple predefined datatype */
380393
/*In future we need to add more complex mapping to the dte_data_representation_t */
@@ -387,7 +400,9 @@ int mca_coll_hcoll_gatherv(const void* sbuf, int scount,
387400
comm, hcoll_module->previous_gatherv_module);
388401
return rc;
389402
}
390-
rc = hcoll_collectives.coll_gatherv((void *)sbuf, scount, stype, rbuf, (int *)rcounts, (int *)displs, rtype, root, hcoll_module->hcoll_context);
403+
rc = hcoll_collectives.coll_gatherv((void *)sbuf, scount, stype, rbuf,
404+
(int *)rcounts, (int *)displs, rtype,
405+
root, hcoll_module->hcoll_context);
391406
if (HCOLL_SUCCESS != rc){
392407
HCOL_VERBOSE(20,"RUNNING FALLBACK GATHERV");
393408
rc = hcoll_module->previous_gatherv(sbuf,scount,sdtype,
@@ -406,13 +421,20 @@ int mca_coll_hcoll_scatterv(const void* sbuf, const int *scounts, const int *dis
406421
struct ompi_communicator_t *comm,
407422
mca_coll_base_module_t *module)
408423
{
424+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
409425
dte_data_representation_t stype;
410426
dte_data_representation_t rtype;
411427
int rc;
428+
412429
HCOL_VERBOSE(20,"RUNNING HCOL SCATTERV");
413-
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
430+
431+
if (root != comm->c_my_rank) {
432+
sdtype = rdtype;
433+
}
434+
414435
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
415436
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
437+
416438
if (rbuf == MPI_IN_PLACE) {
417439
assert(root == comm->c_my_rank);
418440
rtype = stype;
@@ -693,13 +715,20 @@ int mca_coll_hcoll_igatherv(const void* sbuf, int scount,
693715
ompi_request_t ** request,
694716
mca_coll_base_module_t *module)
695717
{
718+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
696719
dte_data_representation_t stype;
697720
dte_data_representation_t rtype;
698721
int rc;
699722
void** rt_handle;
723+
700724
HCOL_VERBOSE(20,"RUNNING HCOL IGATHERV");
701-
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
725+
702726
rt_handle = (void**) request;
727+
728+
if (root != comm->c_my_rank) {
729+
rdtype = sdtype;
730+
}
731+
703732
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
704733
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
705734
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {

0 commit comments

Comments
 (0)