Skip to content

Commit 1d54071

Browse files
author
Valentin Petrov
committed
coll/hcoll: reduce_scatter(block) interface
Signed-off-by: Valentin Petrov <valentinp@mellanox.com>
1 parent 868eee3 commit 1d54071

File tree

3 files changed

+110
-0
lines changed

3 files changed

+110
-0
lines changed

ompi/mca/coll/hcoll/coll_hcoll.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ struct mca_coll_hcoll_module_t {
141141
mca_coll_base_module_t *previous_scatterv_module;
142142
mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter;
143143
mca_coll_base_module_t *previous_reduce_scatter_module;
144+
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
145+
mca_coll_base_module_t *previous_reduce_scatter_block_module;
144146
mca_coll_base_module_ibcast_fn_t previous_ibcast;
145147
mca_coll_base_module_t *previous_ibcast_module;
146148
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
@@ -211,6 +213,18 @@ int mca_coll_hcoll_allreduce(const void *sbuf, void *rbuf, int count,
211213
struct ompi_communicator_t *comm,
212214
mca_coll_base_module_t *module);
213215

216+
#if HCOLL_API > HCOLL_VERSION(4,5)
217+
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
218+
struct ompi_datatype_t *dtype,
219+
struct ompi_op_t *op,
220+
struct ompi_communicator_t *comm,
221+
mca_coll_base_module_t *module);
222+
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
223+
struct ompi_datatype_t *dtype,
224+
struct ompi_op_t *op,
225+
struct ompi_communicator_t *comm,
226+
mca_coll_base_module_t *module);
227+
#endif
214228
int mca_coll_hcoll_reduce(const void *sbuf, void *rbuf, int count,
215229
struct ompi_datatype_t *dtype,
216230
struct ompi_op_t *op,

ompi/mca/coll/hcoll/coll_hcoll_module.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ static void mca_coll_hcoll_module_clear(mca_coll_hcoll_module_t *hcoll_module)
5151
hcoll_module->previous_alltoallw = NULL;
5252
hcoll_module->previous_reduce = NULL;
5353
hcoll_module->previous_reduce_scatter = NULL;
54+
hcoll_module->previous_reduce_scatter_block = NULL;
5455
hcoll_module->previous_ibarrier = NULL;
5556
hcoll_module->previous_ibcast = NULL;
5657
hcoll_module->previous_iallreduce = NULL;
@@ -119,6 +120,8 @@ static void mca_coll_hcoll_module_destruct(mca_coll_hcoll_module_t *hcoll_module
119120
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_barrier_module);
120121
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_bcast_module);
121122
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allreduce_module);
123+
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_block_module);
124+
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_reduce_scatter_module);
122125
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgather_module);
123126
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_allgatherv_module);
124127
OBJ_RELEASE_IF_NOT_NULL(hcoll_module->previous_gatherv_module);
@@ -173,6 +176,8 @@ static int mca_coll_hcoll_save_coll_handlers(mca_coll_hcoll_module_t *hcoll_modu
173176
HCOL_SAVE_PREV_COLL_API(barrier);
174177
HCOL_SAVE_PREV_COLL_API(bcast);
175178
HCOL_SAVE_PREV_COLL_API(allreduce);
179+
HCOL_SAVE_PREV_COLL_API(reduce_scatter_block);
180+
HCOL_SAVE_PREV_COLL_API(reduce_scatter);
176181
HCOL_SAVE_PREV_COLL_API(reduce);
177182
HCOL_SAVE_PREV_COLL_API(allgather);
178183
HCOL_SAVE_PREV_COLL_API(allgatherv);
@@ -419,6 +424,12 @@ mca_coll_hcoll_comm_query(struct ompi_communicator_t *comm, int *priority)
419424
hcoll_module->super.coll_ialltoallv = hcoll_collectives.coll_ialltoallv ? mca_coll_hcoll_ialltoallv : NULL;
420425
#else
421426
hcoll_module->super.coll_ialltoallv = NULL;
427+
#endif
428+
#if HCOLL_API > HCOLL_VERSION(4,5)
429+
hcoll_module->super.coll_reduce_scatter_block = hcoll_collectives.coll_reduce_scatter_block ?
430+
mca_coll_hcoll_reduce_scatter_block : NULL;
431+
hcoll_module->super.coll_reduce_scatter = hcoll_collectives.coll_reduce_scatter ?
432+
mca_coll_hcoll_reduce_scatter : NULL;
422433
#endif
423434
*priority = cm->hcoll_priority;
424435
module = &hcoll_module->super;

ompi/mca/coll/hcoll/coll_hcoll_ops.c

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,3 +760,88 @@ int mca_coll_hcoll_ialltoallv(const void *sbuf, int *scounts, int *sdisps,
760760
return rc;
761761
}
762762
#endif
763+
764+
#if HCOLL_API > HCOLL_VERSION(4,5)
765+
int mca_coll_hcoll_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
766+
struct ompi_datatype_t *dtype,
767+
struct ompi_op_t *op,
768+
struct ompi_communicator_t *comm,
769+
mca_coll_base_module_t *module) {
770+
dte_data_representation_t Dtype;
771+
hcoll_dte_op_t *Op;
772+
int rc;
773+
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER BLOCK");
774+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
775+
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
776+
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
777+
/*If we are here then datatype is not simple predefined datatype */
778+
/*In future we need to add more complex mapping to the dte_data_representation_t */
779+
/* Now use fallback */
780+
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
781+
dtype->super.name);
782+
goto fallback;
783+
}
784+
785+
Op = ompi_op_2_hcolrte_op(op);
786+
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
787+
/*If we are here then datatype is not simple predefined datatype */
788+
/*In future we need to add more complex mapping to the dte_data_representation_t */
789+
/* Now use fallback */
790+
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
791+
op->o_name);
792+
goto fallback;
793+
}
794+
795+
rc = hcoll_collectives.coll_reduce_scatter_block((void *)sbuf,rbuf,rcount,Dtype,Op,hcoll_module->hcoll_context);
796+
if (HCOLL_SUCCESS != rc){
797+
fallback:
798+
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
799+
rc = hcoll_module->previous_reduce_scatter_block(sbuf,rbuf,
800+
rcount,dtype,op,
801+
comm, hcoll_module->previous_allreduce_module);
802+
}
803+
return rc;
804+
}
805+
806+
int mca_coll_hcoll_reduce_scatter(const void *sbuf, void *rbuf, const int* rcounts,
807+
struct ompi_datatype_t *dtype,
808+
struct ompi_op_t *op,
809+
struct ompi_communicator_t *comm,
810+
mca_coll_base_module_t *module) {
811+
dte_data_representation_t Dtype;
812+
hcoll_dte_op_t *Op;
813+
int rc;
814+
HCOL_VERBOSE(20,"RUNNING HCOL REDUCE SCATTER");
815+
mca_coll_hcoll_module_t *hcoll_module = (mca_coll_hcoll_module_t*)module;
816+
Dtype = ompi_dtype_2_hcoll_dtype(dtype, NO_DERIVED);
817+
if (OPAL_UNLIKELY(HCOL_DTE_IS_ZERO(Dtype))){
818+
/*If we are here then datatype is not simple predefined datatype */
819+
/*In future we need to add more complex mapping to the dte_data_representation_t */
820+
/* Now use fallback */
821+
HCOL_VERBOSE(20,"Ompi_datatype is not supported: dtype = %s; calling fallback allreduce;",
822+
dtype->super.name);
823+
goto fallback;
824+
}
825+
826+
Op = ompi_op_2_hcolrte_op(op);
827+
if (OPAL_UNLIKELY(HCOL_DTE_OP_NULL == Op->id)){
828+
/*If we are here then datatype is not simple predefined datatype */
829+
/*In future we need to add more complex mapping to the dte_data_representation_t */
830+
/* Now use fallback */
831+
HCOL_VERBOSE(20,"ompi_op_t is not supported: op = %s; calling fallback allreduce;",
832+
op->o_name);
833+
goto fallback;
834+
}
835+
836+
rc = hcoll_collectives.coll_reduce_scatter((void*)sbuf, rbuf, (int*)rcounts,
837+
Dtype, Op, hcoll_module->hcoll_context);
838+
if (HCOLL_SUCCESS != rc){
839+
fallback:
840+
HCOL_VERBOSE(20,"RUNNING FALLBACK ALLREDUCE");
841+
rc = hcoll_module->previous_reduce_scatter(sbuf,rbuf,
842+
rcounts,dtype,op,
843+
comm, hcoll_module->previous_allreduce_module);
844+
}
845+
return rc;
846+
}
847+
#endif

0 commit comments

Comments
 (0)