Skip to content

Commit 0f54160

Browse files
authored
Merge pull request #11091 from Sergei-Lebedev/coll_ucc_gather_scatter_rs
coll/ucc: add support for gather(v), scatterv, reduce_scatter
2 parents (72e5e0d + f29301a), commit 0f54160

10 files changed

+1011 / -98 lines changed

ompi/mca/coll/ucc/Makefile.am

Lines changed: 22 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
#
33
#
44
# Copyright (c) 2021 Mellanox Technologies. All rights reserved.
5+
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
56
# $COPYRIGHT$
67
#
78
# Additional copyrights may follow
@@ -12,21 +13,27 @@
1213

1314
AM_CPPFLAGS = $(coll_ucc_CPPFLAGS)
1415

15-
coll_ucc_sources = \
16-
coll_ucc.h \
17-
coll_ucc_debug.h \
18-
coll_ucc_dtypes.h \
19-
coll_ucc_common.h \
20-
coll_ucc_module.c \
21-
coll_ucc_component.c \
22-
coll_ucc_barrier.c \
23-
coll_ucc_bcast.c \
24-
coll_ucc_allreduce.c \
25-
coll_ucc_reduce.c \
26-
coll_ucc_alltoall.c \
27-
coll_ucc_alltoallv.c \
28-
coll_ucc_allgather.c \
29-
coll_ucc_allgatherv.c
16+
coll_ucc_sources = \
17+
coll_ucc.h \
18+
coll_ucc_debug.h \
19+
coll_ucc_dtypes.h \
20+
coll_ucc_common.h \
21+
coll_ucc_module.c \
22+
coll_ucc_component.c \
23+
coll_ucc_allgather.c \
24+
coll_ucc_allgatherv.c \
25+
coll_ucc_allreduce.c \
26+
coll_ucc_alltoall.c \
27+
coll_ucc_alltoallv.c \
28+
coll_ucc_barrier.c \
29+
coll_ucc_bcast.c \
30+
coll_ucc_gather.c \
31+
coll_ucc_gatherv.c \
32+
coll_ucc_reduce.c \
33+
coll_ucc_reduce_scatter_block.c \
34+
coll_ucc_reduce_scatter.c \
35+
coll_ucc_scatter.c \
36+
coll_ucc_scatterv.c
3037

3138
# Make the output library in this directory, and name it either
3239
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 152 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
/**
22
Copyright (c) 2021 Mellanox Technologies. All rights reserved.
3+
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
34
$COPYRIGHT$
45
56
Additional copyrights may follow
@@ -24,13 +25,20 @@
2425

2526
BEGIN_C_DECLS
2627

27-
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
28-
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
29-
UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \
30-
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV)
28+
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
29+
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
30+
UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \
31+
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV | \
32+
UCC_COLL_TYPE_GATHER | UCC_COLL_TYPE_GATHERV | \
33+
UCC_COLL_TYPE_REDUCE_SCATTER | UCC_COLL_TYPE_REDUCE_SCATTERV | \
34+
UCC_COLL_TYPE_SCATTERV | UCC_COLL_TYPE_SCATTER)
3135

32-
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce," \
33-
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce"
36+
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather," \
37+
"allgatherv,reduce,gather,gatherv,reduce_scatter_block,"\
38+
"reduce_scatter,scatterv,scatter," \
39+
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,"\
40+
"iallgatherv,ireduce,igather,igatherv,ireduce_scatter_block,"\
41+
"ireduce_scatter,iscatterv,iscatter"
3442

3543
typedef struct mca_coll_ucc_req {
3644
ompi_request_t super;
@@ -64,42 +72,66 @@ OMPI_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
6472
* UCC enabled communicator
6573
*/
6674
struct mca_coll_ucc_module_t {
67-
mca_coll_base_module_t super;
68-
ompi_communicator_t* comm;
69-
int rank;
70-
ucc_team_h ucc_team;
71-
mca_coll_base_module_allreduce_fn_t previous_allreduce;
72-
mca_coll_base_module_t* previous_allreduce_module;
73-
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
74-
mca_coll_base_module_t* previous_iallreduce_module;
75-
mca_coll_base_module_reduce_fn_t previous_reduce;
76-
mca_coll_base_module_t* previous_reduce_module;
77-
mca_coll_base_module_ireduce_fn_t previous_ireduce;
78-
mca_coll_base_module_t* previous_ireduce_module;
79-
mca_coll_base_module_barrier_fn_t previous_barrier;
80-
mca_coll_base_module_t* previous_barrier_module;
81-
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
82-
mca_coll_base_module_t* previous_ibarrier_module;
83-
mca_coll_base_module_bcast_fn_t previous_bcast;
84-
mca_coll_base_module_t* previous_bcast_module;
85-
mca_coll_base_module_ibcast_fn_t previous_ibcast;
86-
mca_coll_base_module_t* previous_ibcast_module;
87-
mca_coll_base_module_alltoall_fn_t previous_alltoall;
88-
mca_coll_base_module_t* previous_alltoall_module;
89-
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
90-
mca_coll_base_module_t* previous_ialltoall_module;
91-
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
92-
mca_coll_base_module_t* previous_alltoallv_module;
93-
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
94-
mca_coll_base_module_t* previous_ialltoallv_module;
95-
mca_coll_base_module_allgather_fn_t previous_allgather;
96-
mca_coll_base_module_t* previous_allgather_module;
97-
mca_coll_base_module_iallgather_fn_t previous_iallgather;
98-
mca_coll_base_module_t* previous_iallgather_module;
99-
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
100-
mca_coll_base_module_t* previous_allgatherv_module;
101-
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
102-
mca_coll_base_module_t* previous_iallgatherv_module;
75+
mca_coll_base_module_t super;
76+
ompi_communicator_t* comm;
77+
int rank;
78+
ucc_team_h ucc_team;
79+
mca_coll_base_module_allreduce_fn_t previous_allreduce;
80+
mca_coll_base_module_t* previous_allreduce_module;
81+
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
82+
mca_coll_base_module_t* previous_iallreduce_module;
83+
mca_coll_base_module_reduce_fn_t previous_reduce;
84+
mca_coll_base_module_t* previous_reduce_module;
85+
mca_coll_base_module_ireduce_fn_t previous_ireduce;
86+
mca_coll_base_module_t* previous_ireduce_module;
87+
mca_coll_base_module_barrier_fn_t previous_barrier;
88+
mca_coll_base_module_t* previous_barrier_module;
89+
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
90+
mca_coll_base_module_t* previous_ibarrier_module;
91+
mca_coll_base_module_bcast_fn_t previous_bcast;
92+
mca_coll_base_module_t* previous_bcast_module;
93+
mca_coll_base_module_ibcast_fn_t previous_ibcast;
94+
mca_coll_base_module_t* previous_ibcast_module;
95+
mca_coll_base_module_alltoall_fn_t previous_alltoall;
96+
mca_coll_base_module_t* previous_alltoall_module;
97+
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
98+
mca_coll_base_module_t* previous_ialltoall_module;
99+
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
100+
mca_coll_base_module_t* previous_alltoallv_module;
101+
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
102+
mca_coll_base_module_t* previous_ialltoallv_module;
103+
mca_coll_base_module_allgather_fn_t previous_allgather;
104+
mca_coll_base_module_t* previous_allgather_module;
105+
mca_coll_base_module_iallgather_fn_t previous_iallgather;
106+
mca_coll_base_module_t* previous_iallgather_module;
107+
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
108+
mca_coll_base_module_t* previous_allgatherv_module;
109+
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
110+
mca_coll_base_module_t* previous_iallgatherv_module;
111+
mca_coll_base_module_gather_fn_t previous_gather;
112+
mca_coll_base_module_t* previous_gather_module;
113+
mca_coll_base_module_igather_fn_t previous_igather;
114+
mca_coll_base_module_t* previous_igather_module;
115+
mca_coll_base_module_gatherv_fn_t previous_gatherv;
116+
mca_coll_base_module_t* previous_gatherv_module;
117+
mca_coll_base_module_igatherv_fn_t previous_igatherv;
118+
mca_coll_base_module_t* previous_igatherv_module;
119+
mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block;
120+
mca_coll_base_module_t* previous_reduce_scatter_block_module;
121+
mca_coll_base_module_ireduce_scatter_block_fn_t previous_ireduce_scatter_block;
122+
mca_coll_base_module_t* previous_ireduce_scatter_block_module;
123+
mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter;
124+
mca_coll_base_module_t* previous_reduce_scatter_module;
125+
mca_coll_base_module_ireduce_scatter_fn_t previous_ireduce_scatter;
126+
mca_coll_base_module_t* previous_ireduce_scatter_module;
127+
mca_coll_base_module_scatterv_fn_t previous_scatterv;
128+
mca_coll_base_module_t* previous_scatterv_module;
129+
mca_coll_base_module_iscatterv_fn_t previous_iscatterv;
130+
mca_coll_base_module_t* previous_iscatterv_module;
131+
mca_coll_base_module_scatter_fn_t previous_scatter;
132+
mca_coll_base_module_t* previous_scatter_module;
133+
mca_coll_base_module_iscatter_fn_t previous_iscatter;
134+
mca_coll_base_module_t* previous_iscatter_module;
103135
};
104136
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
105137
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
@@ -195,5 +227,83 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount, struct ompi_datatype_
195227
ompi_request_t** request,
196228
mca_coll_base_module_t *module);
197229

230+
int mca_coll_ucc_gather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
231+
void *rbuf, int rcount, struct ompi_datatype_t *rdtype,
232+
int root, struct ompi_communicator_t *comm,
233+
mca_coll_base_module_t *module);
234+
235+
int mca_coll_ucc_igather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
236+
void *rbuf, int rcount, struct ompi_datatype_t *rdtype,
237+
int root, struct ompi_communicator_t *comm,
238+
ompi_request_t** request,
239+
mca_coll_base_module_t *module);
240+
241+
int mca_coll_ucc_gatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
242+
void *rbuf, const int *rcounts, const int *disps,
243+
struct ompi_datatype_t *rdtype, int root,
244+
struct ompi_communicator_t *comm,
245+
mca_coll_base_module_t *module);
246+
247+
int mca_coll_ucc_igatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
248+
void *rbuf, const int *rcounts, const int *disps,
249+
struct ompi_datatype_t *rdtype, int root,
250+
struct ompi_communicator_t *comm,
251+
ompi_request_t** request,
252+
mca_coll_base_module_t *module);
253+
254+
int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
255+
struct ompi_datatype_t *dtype,
256+
struct ompi_op_t *op,
257+
struct ompi_communicator_t *comm,
258+
mca_coll_base_module_t *module);
259+
260+
int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount,
261+
struct ompi_datatype_t *dtype,
262+
struct ompi_op_t *op,
263+
struct ompi_communicator_t *comm,
264+
ompi_request_t** request,
265+
mca_coll_base_module_t *module);
266+
267+
int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
268+
struct ompi_datatype_t *dtype,
269+
struct ompi_op_t *op,
270+
struct ompi_communicator_t *comm,
271+
mca_coll_base_module_t *module);
272+
273+
int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
274+
struct ompi_datatype_t *dtype,
275+
struct ompi_op_t *op,
276+
struct ompi_communicator_t *comm,
277+
ompi_request_t** request,
278+
mca_coll_base_module_t *module);
279+
280+
int mca_coll_ucc_scatterv(const void *sbuf, const int *scounts,
281+
const int *disps, struct ompi_datatype_t *sdtype,
282+
void *rbuf, int rcount,
283+
struct ompi_datatype_t *rdtype, int root,
284+
struct ompi_communicator_t *comm,
285+
mca_coll_base_module_t *module);
286+
287+
int mca_coll_ucc_iscatterv(const void *sbuf, const int *scounts,
288+
const int *disps, struct ompi_datatype_t *sdtype,
289+
void *rbuf, int rcount,
290+
struct ompi_datatype_t *rdtype, int root,
291+
struct ompi_communicator_t *comm,
292+
ompi_request_t** request,
293+
mca_coll_base_module_t *module);
294+
295+
int mca_coll_ucc_scatter(const void *sbuf, int scount,
296+
struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
297+
struct ompi_datatype_t *rdtype, int root,
298+
struct ompi_communicator_t *comm,
299+
mca_coll_base_module_t *module);
300+
301+
int mca_coll_ucc_iscatter(const void *sbuf, int scount,
302+
struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
303+
struct ompi_datatype_t *rdtype, int root,
304+
struct ompi_communicator_t *comm,
305+
ompi_request_t** request,
306+
mca_coll_base_module_t *module);
307+
198308
END_C_DECLS
199309
#endif

ompi/mca/coll/ucc/coll_ucc_component.c

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
/* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
22
/*
33
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
4+
* Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
45
* $COPYRIGHT$
56
*
67
* Additional copyrights may follow
@@ -120,6 +121,18 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str)
120121
return UCC_COLL_TYPE_ALLGATHERV;
121122
} else if (0 == strcasecmp(str, "reduce")) {
122123
return UCC_COLL_TYPE_REDUCE;
124+
} else if (0 == strcasecmp(str, "gather")) {
125+
return UCC_COLL_TYPE_GATHER;
126+
} else if (0 == strcasecmp(str, "gatherv")) {
127+
return UCC_COLL_TYPE_GATHERV;
128+
} else if (0 == strcasecmp(str, "reduce_scatter_block")) {
129+
return UCC_COLL_TYPE_REDUCE_SCATTER;
130+
} else if (0 == strcasecmp(str, "reduce_scatter")) {
131+
return UCC_COLL_TYPE_REDUCE_SCATTERV;
132+
} else if (0 == strcasecmp(str, "scatterv")) {
133+
return UCC_COLL_TYPE_SCATTERV;
134+
} else if (0 == strcasecmp(str, "scatter")) {
135+
return UCC_COLL_TYPE_SCATTER;
123136
}
124137
UCC_ERROR("incorrect value for cts: %s, allowed: %s",
125138
str, COLL_UCC_CTS_STR);

0 commit comments

Comments (0)