Skip to content

Commit 2469f6c

Browse files
author
valentin petrov
authored
Merge pull request #6708 from vspetrov/master
Coll/hcoll: adding scatterv interface
2 parents 88503f0 + 6ea920e commit 2469f6c

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

ompi/mca/coll/hcoll/coll_hcoll.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ struct mca_coll_hcoll_module_t {
138138
mca_coll_base_module_t *previous_gather_module;
139139
mca_coll_base_module_gatherv_fn_t previous_gatherv;
140140
mca_coll_base_module_t *previous_gatherv_module;
141+
mca_coll_base_module_scatterv_fn_t previous_scatterv;
142+
mca_coll_base_module_t *previous_scatterv_module;
141143
mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter;
142144
mca_coll_base_module_t *previous_reduce_scatter_module;
143145
mca_coll_base_module_ibcast_fn_t previous_ibcast;
@@ -241,6 +243,15 @@ int mca_coll_hcoll_gatherv(const void* sbuf, int scount,
241243
struct ompi_communicator_t *comm,
242244
mca_coll_base_module_t *module);
243245

246+
247+
int mca_coll_hcoll_scatterv(const void* sbuf, const int *scounts, const int *displs,
248+
struct ompi_datatype_t *sdtype,
249+
void* rbuf, int rcount,
250+
struct ompi_datatype_t *rdtype,
251+
int root,
252+
struct ompi_communicator_t *comm,
253+
mca_coll_base_module_t *module);
254+
244255
int mca_coll_hcoll_ibarrier(struct ompi_communicator_t *comm,
245256
ompi_request_t** request,
246257
mca_coll_base_module_t *module);

ompi/mca/coll/hcoll/coll_hcoll_module.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module)
4545
hcoll_module->previous_allgatherv = NULL;
4646
hcoll_module->previous_gather = NULL;
4747
hcoll_module->previous_gatherv = NULL;
48+
hcoll_module->previous_scatterv = NULL;
4849
hcoll_module->previous_alltoall = NULL;
4950
hcoll_module->previous_alltoallv = NULL;
5051
hcoll_module->previous_alltoallw = NULL;
@@ -68,6 +69,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module)
6869
hcoll_module->previous_allgatherv_module = NULL;
6970
hcoll_module->previous_gather_module = NULL;
7071
hcoll_module->previous_gatherv_module = NULL;
72+
hcoll_module->previous_scatterv_module = NULL;
7173
hcoll_module->previous_alltoall_module = NULL;
7274
hcoll_module->previous_alltoallv_module = NULL;
7375
hcoll_module->previous_alltoallw_module = NULL;
@@ -120,6 +122,7 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module
120122
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module);
121123
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module);
122124
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module);
125+
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_scatterv_module);
123126
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoall_module);
124127
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_alltoallv_module);
125128
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_module);
@@ -174,6 +177,7 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu
174177
HCOL_SAVE_PREV_COLL_API(allgather);
175178
HCOL_SAVE_PREV_COLL_API(allgatherv);
176179
HCOL_SAVE_PREV_COLL_API(gatherv);
180+
HCOL_SAVE_PREV_COLL_API(scatterv);
177181
HCOL_SAVE_PREV_COLL_API(alltoall);
178182
HCOL_SAVE_PREV_COLL_API(alltoallv);
179183

@@ -392,6 +396,7 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
392396
hcoll_module->super.coll_alltoall = hcoll_collectives.coll_alltoall ? mca_coll_hcoll_alltoall : NULL;
393397
hcoll_module->super.coll_alltoallv = hcoll_collectives.coll_alltoallv ? mca_coll_hcoll_alltoallv : NULL;
394398
hcoll_module->super.coll_gatherv = hcoll_collectives.coll_gatherv ? mca_coll_hcoll_gatherv : NULL;
399+
hcoll_module->super.coll_scatterv = hcoll_collectives.coll_scatterv ? mca_coll_hcoll_scatterv : NULL;
395400
hcoll_module->super.coll_reduce = hcoll_collectives.coll_reduce ? mca_coll_hcoll_reduce : NULL;
396401
hcoll_module->super.coll_ibarrier = hcoll_collectives.coll_ibarrier ? mca_coll_hcoll_ibarrier : NULL;
397402
hcoll_module->super.coll_ibcast = hcoll_collectives.coll_ibcast ? mca_coll_hcoll_ibcast : NULL;

ompi/mca/coll/hcoll/coll_hcoll_ops.c

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,43 @@ int mca_coll_hcoll_gatherv(const void* sbuf, int scount,
397397

398398
}
399399

400+
int mca_coll_hcoll_scatterv(const void* sbuf, const int *scounts, const int *displs,
401+
struct ompi_datatype_t *sdtype,
402+
void* rbuf, int rcount,
403+
struct ompi_datatype_t *rdtype,
404+
int root,
405+
struct ompi_communicator_t *comm,
406+
mca_coll_base_module_t *module)
407+
{
408+
dte_data_representation_t stype;
409+
dte_data_representation_t rtype;
410+
int rc;
411+
HCOL_VERBOSE(20,"RUNNING HCOL SCATTERV");
412+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
413+
stype = ompi_dtype_2_hcoll_dtype(sdtype, NO_DERIVED);
414+
rtype = ompi_dtype_2_hcoll_dtype(rdtype, NO_DERIVED);
415+
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(stype) || HCOL_DTE_IS_ZERO(rtype))) {
416+
/*If we are here then datatype is not simple predefined datatype */
417+
/*In future we need to add more complex mapping to the dte_data_representation_t */
418+
/* Now use fallback */
419+
HCOL_VERBOSE(20,"Ompi_datatype is not supported: sdtype = %s, rdtype = %s; calling fallback scatterv;",
420+
sdtype->super.name,
421+
rdtype->super.name);
422+
rc = hcoll_module->previous_scatterv(sbuf, scounts, displs, sdtype,
423+
rbuf, rcount, rdtype, root,
424+
comm, hcoll_module->previous_scatterv_module);
425+
return rc;
426+
}
427+
rc = hcoll_collectives.coll_scatterv((void *)sbuf, (int *)scounts, (int *)displs, stype, rbuf, rcount, rtype, root, hcoll_module->hcoll_context);
428+
if (HCOLL_SUCCESS != rc){
429+
HCOL_VERBOSE(20,"RUNNING FALLBACK SCATTERV");
430+
rc = hcoll_module->previous_scatterv(sbuf, scounts, displs, sdtype,
431+
rbuf, rcount, rdtype, root,
432+
comm, hcoll_module->previous_scatterv_module);
433+
}
434+
return rc;
435+
}
436+
400437
int mca_coll_hcoll_ibarrier(struct ompi_communicator_t *comm,
401438
ompi_request_t ** request,
402439
mca_coll_base_module_t *module)

0 commit comments

Comments
 (0)