Skip to content

Commit da9b92a

Browse files
authored
Merge pull request #9152 from lappazos/Add_Allgather
OMPI/COLL: Add allgather, reduce coll ucc
2 parents 8e51a72 + 31dc76e commit da9b92a

14 files changed

+500
-71
lines changed

ompi/mca/coll/ucc/Makefile.am

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ coll_ucc_sources = \
2121
coll_ucc_barrier.c \
2222
coll_ucc_bcast.c \
2323
coll_ucc_allreduce.c \
24+
coll_ucc_reduce.c \
2425
coll_ucc_alltoall.c \
25-
coll_ucc_alltoallv.c
26+
coll_ucc_alltoallv.c \
27+
coll_ucc_allgather.c \
28+
coll_ucc_allgatherv.c
2629

2730
# Make the output library in this directory, and name it either
2831
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ BEGIN_C_DECLS
2626

2727
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
2828
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
29-
UCC_COLL_TYPE_ALLTOALLV)
29+
UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \
30+
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV)
3031

31-
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv," \
32-
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv"
32+
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce," \
33+
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce"
3334

3435
typedef struct mca_coll_ucc_req {
3536
ompi_request_t super;
@@ -63,30 +64,42 @@ OMPI_MODULE_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
6364
* UCC enabled communicator
6465
*/
6566
struct mca_coll_ucc_module_t {
66-
mca_coll_base_module_t super;
67-
ompi_communicator_t* comm;
68-
int rank;
69-
ucc_team_h ucc_team;
70-
mca_coll_base_module_allreduce_fn_t previous_allreduce;
71-
mca_coll_base_module_t* previous_allreduce_module;
72-
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
73-
mca_coll_base_module_t* previous_iallreduce_module;
74-
mca_coll_base_module_barrier_fn_t previous_barrier;
75-
mca_coll_base_module_t* previous_barrier_module;
76-
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
77-
mca_coll_base_module_t* previous_ibarrier_module;
78-
mca_coll_base_module_bcast_fn_t previous_bcast;
79-
mca_coll_base_module_t* previous_bcast_module;
80-
mca_coll_base_module_ibcast_fn_t previous_ibcast;
81-
mca_coll_base_module_t* previous_ibcast_module;
82-
mca_coll_base_module_alltoall_fn_t previous_alltoall;
83-
mca_coll_base_module_t* previous_alltoall_module;
84-
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
85-
mca_coll_base_module_t* previous_ialltoall_module;
86-
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
87-
mca_coll_base_module_t* previous_alltoallv_module;
88-
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
89-
mca_coll_base_module_t* previous_ialltoallv_module;
67+
mca_coll_base_module_t super;
68+
ompi_communicator_t* comm;
69+
int rank;
70+
ucc_team_h ucc_team;
71+
mca_coll_base_module_allreduce_fn_t previous_allreduce;
72+
mca_coll_base_module_t* previous_allreduce_module;
73+
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
74+
mca_coll_base_module_t* previous_iallreduce_module;
75+
mca_coll_base_module_reduce_fn_t previous_reduce;
76+
mca_coll_base_module_t* previous_reduce_module;
77+
mca_coll_base_module_ireduce_fn_t previous_ireduce;
78+
mca_coll_base_module_t* previous_ireduce_module;
79+
mca_coll_base_module_barrier_fn_t previous_barrier;
80+
mca_coll_base_module_t* previous_barrier_module;
81+
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
82+
mca_coll_base_module_t* previous_ibarrier_module;
83+
mca_coll_base_module_bcast_fn_t previous_bcast;
84+
mca_coll_base_module_t* previous_bcast_module;
85+
mca_coll_base_module_ibcast_fn_t previous_ibcast;
86+
mca_coll_base_module_t* previous_ibcast_module;
87+
mca_coll_base_module_alltoall_fn_t previous_alltoall;
88+
mca_coll_base_module_t* previous_alltoall_module;
89+
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
90+
mca_coll_base_module_t* previous_ialltoall_module;
91+
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
92+
mca_coll_base_module_t* previous_alltoallv_module;
93+
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
94+
mca_coll_base_module_t* previous_ialltoallv_module;
95+
mca_coll_base_module_allgather_fn_t previous_allgather;
96+
mca_coll_base_module_t* previous_allgather_module;
97+
mca_coll_base_module_iallgather_fn_t previous_iallgather;
98+
mca_coll_base_module_t* previous_iallgather_module;
99+
mca_coll_base_module_allgatherv_fn_t previous_allgatherv;
100+
mca_coll_base_module_t* previous_allgatherv_module;
101+
mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv;
102+
mca_coll_base_module_t* previous_iallgatherv_module;
90103
};
91104
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
92105
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
@@ -105,6 +118,17 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
105118
ompi_request_t** request,
106119
mca_coll_base_module_t *module);
107120

121+
int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, int count,
122+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
123+
int root, struct ompi_communicator_t *comm,
124+
struct mca_coll_base_module_2_4_0_t *module);
125+
126+
int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, int count,
127+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
128+
int root, struct ompi_communicator_t *comm,
129+
ompi_request_t** request,
130+
struct mca_coll_base_module_2_4_0_t *module);
131+
108132
int mca_coll_ucc_barrier(struct ompi_communicator_t *comm,
109133
mca_coll_base_module_t *module);
110134

@@ -146,5 +170,30 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts, const int *sdi
146170
struct ompi_communicator_t *comm,
147171
ompi_request_t** request,
148172
mca_coll_base_module_t *module);
173+
174+
int mca_coll_ucc_allgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
175+
void* rbuf, int rcount, struct ompi_datatype_t *rdtype,
176+
struct ompi_communicator_t *comm,
177+
mca_coll_base_module_t *module);
178+
179+
int mca_coll_ucc_iallgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
180+
void* rbuf, int rcount, struct ompi_datatype_t *rdtype,
181+
struct ompi_communicator_t *comm,
182+
ompi_request_t** request,
183+
mca_coll_base_module_t *module);
184+
185+
int mca_coll_ucc_allgatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
186+
void* rbuf, const int *rcounts, const int *rdisps,
187+
struct ompi_datatype_t *rdtype,
188+
struct ompi_communicator_t *comm,
189+
mca_coll_base_module_t *module);
190+
191+
int mca_coll_ucc_iallgatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
192+
void* rbuf, const int *rcounts, const int *rdisps,
193+
struct ompi_datatype_t *rdtype,
194+
struct ompi_communicator_t *comm,
195+
ompi_request_t** request,
196+
mca_coll_base_module_t *module);
197+
149198
END_C_DECLS
150199
#endif
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
2+
/**
3+
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
4+
* $COPYRIGHT$
5+
*
6+
* Additional copyrights may follow
7+
*
8+
*/
9+
10+
#include "coll_ucc_common.h"
11+
12+
static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
13+
void* rbuf, int rcount, struct ompi_datatype_t *rdtype,
14+
mca_coll_ucc_module_t *ucc_module,
15+
ucc_coll_req_h *req,
16+
mca_coll_ucc_req_t *coll_req)
17+
{
18+
ucc_datatype_t ucc_sdt, ucc_rdt;
19+
int comm_size = ompi_comm_size(ucc_module->comm);
20+
21+
if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount) ||
22+
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
23+
goto fallback;
24+
}
25+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
26+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
27+
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
28+
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
29+
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
30+
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
31+
sdtype->super.name : rdtype->super.name);
32+
goto fallback;
33+
}
34+
35+
ucc_coll_args_t coll = {
36+
.mask = 0,
37+
.coll_type = UCC_COLL_TYPE_ALLGATHER,
38+
.src.info = {
39+
.buffer = (void*)sbuf,
40+
.count = scount,
41+
.datatype = ucc_sdt,
42+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
43+
},
44+
.dst.info = {
45+
.buffer = (void*)rbuf,
46+
.count = rcount * comm_size,
47+
.datatype = ucc_rdt,
48+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
49+
}
50+
};
51+
52+
if (MPI_IN_PLACE == sbuf) {
53+
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
54+
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
55+
}
56+
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
57+
return UCC_OK;
58+
fallback:
59+
return UCC_ERR_NOT_SUPPORTED;
60+
}
61+
62+
int mca_coll_ucc_allgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
63+
void* rbuf, int rcount, struct ompi_datatype_t *rdtype,
64+
struct ompi_communicator_t *comm,
65+
mca_coll_base_module_t *module)
66+
{
67+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
68+
ucc_coll_req_h req;
69+
70+
UCC_VERBOSE(3, "running ucc allgather");
71+
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
72+
rbuf, rcount, rdtype,
73+
ucc_module, &req, NULL));
74+
COLL_UCC_POST_AND_CHECK(req);
75+
COLL_UCC_CHECK(coll_ucc_req_wait(req));
76+
return OMPI_SUCCESS;
77+
fallback:
78+
UCC_VERBOSE(3, "running fallback allgather");
79+
return ucc_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
80+
comm, ucc_module->previous_allgather_module);
81+
}
82+
83+
int mca_coll_ucc_iallgather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
84+
void* rbuf, int rcount, struct ompi_datatype_t *rdtype,
85+
struct ompi_communicator_t *comm,
86+
ompi_request_t** request,
87+
mca_coll_base_module_t *module)
88+
{
89+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
90+
ucc_coll_req_h req;
91+
mca_coll_ucc_req_t *coll_req = NULL;
92+
93+
UCC_VERBOSE(3, "running ucc iallgather");
94+
COLL_UCC_GET_REQ(coll_req);
95+
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
96+
rbuf, rcount, rdtype,
97+
ucc_module, &req, coll_req));
98+
COLL_UCC_POST_AND_CHECK(req);
99+
*request = &coll_req->super;
100+
return OMPI_SUCCESS;
101+
fallback:
102+
UCC_VERBOSE(3, "running fallback iallgather");
103+
if (coll_req) {
104+
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
105+
}
106+
return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
107+
comm, request, ucc_module->previous_iallgather_module);
108+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
2+
/**
3+
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
4+
* $COPYRIGHT$
5+
*
6+
* Additional copyrights may follow
7+
*
8+
*/
9+
10+
#include "coll_ucc_common.h"
11+
12+
static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, int scount,
13+
struct ompi_datatype_t *sdtype,
14+
void* rbuf, const int *rcounts, const int *rdisps,
15+
struct ompi_datatype_t *rdtype,
16+
mca_coll_ucc_module_t *ucc_module,
17+
ucc_coll_req_h *req,
18+
mca_coll_ucc_req_t *coll_req)
19+
{
20+
ucc_datatype_t ucc_sdt, ucc_rdt;
21+
22+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
23+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
24+
if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt ||
25+
COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
26+
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
27+
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
28+
sdtype->super.name : rdtype->super.name);
29+
goto fallback;
30+
}
31+
32+
ucc_coll_args_t coll = {
33+
.mask = 0,
34+
.coll_type = UCC_COLL_TYPE_ALLGATHERV,
35+
.src.info = {
36+
.buffer = (void*)sbuf,
37+
.count = scount,
38+
.datatype = ucc_sdt,
39+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
40+
},
41+
.dst.info_v = {
42+
.buffer = (void*)rbuf,
43+
.counts = (ucc_count_t*)rcounts,
44+
.displacements = (ucc_aint_t*)rdisps,
45+
.datatype = ucc_rdt,
46+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
47+
}
48+
};
49+
50+
if (MPI_IN_PLACE == sbuf) {
51+
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
52+
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
53+
}
54+
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
55+
return UCC_OK;
56+
fallback:
57+
return UCC_ERR_NOT_SUPPORTED;
58+
}
59+
60+
int mca_coll_ucc_allgatherv(const void *sbuf, int scount,
61+
struct ompi_datatype_t *sdtype,
62+
void* rbuf, const int *rcounts, const int *rdisps,
63+
struct ompi_datatype_t *rdtype,
64+
struct ompi_communicator_t *comm,
65+
mca_coll_base_module_t *module)
66+
{
67+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
68+
ucc_coll_req_h req;
69+
70+
UCC_VERBOSE(3, "running ucc allgatherv");
71+
72+
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
73+
rbuf, rcounts, rdisps, rdtype,
74+
ucc_module, &req, NULL));
75+
COLL_UCC_POST_AND_CHECK(req);
76+
COLL_UCC_CHECK(coll_ucc_req_wait(req));
77+
return OMPI_SUCCESS;
78+
fallback:
79+
UCC_VERBOSE(3, "running fallback allgatherv");
80+
return ucc_module->previous_allgatherv(sbuf, scount, sdtype,
81+
rbuf, rcounts, rdisps, rdtype,
82+
comm, ucc_module->previous_allgatherv_module);
83+
}
84+
85+
int mca_coll_ucc_iallgatherv(const void *sbuf, int scount,
86+
struct ompi_datatype_t *sdtype,
87+
void* rbuf, const int *rcounts, const int *rdisps,
88+
struct ompi_datatype_t *rdtype,
89+
struct ompi_communicator_t *comm,
90+
ompi_request_t** request,
91+
mca_coll_base_module_t *module)
92+
{
93+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
94+
ucc_coll_req_h req;
95+
mca_coll_ucc_req_t *coll_req = NULL;
96+
97+
UCC_VERBOSE(3, "running ucc iallgatherv");
98+
COLL_UCC_GET_REQ(coll_req);
99+
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
100+
rbuf, rcounts, rdisps, rdtype,
101+
ucc_module, &req, coll_req));
102+
COLL_UCC_POST_AND_CHECK(req);
103+
*request = &coll_req->super;
104+
return OMPI_SUCCESS;
105+
fallback:
106+
UCC_VERBOSE(3, "running fallback iallgatherv");
107+
if (coll_req) {
108+
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
109+
}
110+
return ucc_module->previous_iallgatherv(sbuf, scount, sdtype,
111+
rbuf, rcounts, rdisps, rdtype,
112+
comm, request, ucc_module->previous_iallgatherv_module);
113+
}

ompi/mca/coll/ucc/coll_ucc_allreduce.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, int count,
6868
UCC_VERBOSE(3, "running ucc allreduce");
6969
COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op,
7070
ucc_module, &req, NULL));
71-
COLL_UCC_CHECK(ucc_collective_post(req));
71+
COLL_UCC_POST_AND_CHECK(req);
7272
COLL_UCC_CHECK(coll_ucc_req_wait(req));
7373
return OMPI_SUCCESS;
7474
fallback:
@@ -85,17 +85,20 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
8585
{
8686
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
8787
ucc_coll_req_h req;
88-
mca_coll_ucc_req_t *coll_req;
88+
mca_coll_ucc_req_t *coll_req = NULL;
8989

9090
UCC_VERBOSE(3, "running ucc iallreduce");
9191
COLL_UCC_GET_REQ(coll_req);
9292
COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op,
9393
ucc_module, &req, coll_req));
94-
COLL_UCC_CHECK(ucc_collective_post(req));
94+
COLL_UCC_POST_AND_CHECK(req);
9595
*request = &coll_req->super;
9696
return OMPI_SUCCESS;
9797
fallback:
98-
UCC_VERBOSE(3, "running fallback allreduce");
98+
UCC_VERBOSE(3, "running fallback iallreduce");
99+
if (coll_req) {
100+
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
101+
}
99102
return ucc_module->previous_iallreduce(sbuf, rbuf, count, dtype, op,
100103
comm, request, ucc_module->previous_iallreduce_module);
101104
}

0 commit comments

Comments
 (0)