Skip to content

Commit f29301a

Browse files
coll/ucc: add scatter
Signed-off-by: Sergey Lebedev <sergeyle@nvidia.com>
1 parent 88d3a3b commit f29301a

File tree

5 files changed

+154
-3
lines changed

5 files changed

+154
-3
lines changed

ompi/mca/coll/ucc/Makefile.am

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ coll_ucc_sources = \
3232
coll_ucc_reduce.c \
3333
coll_ucc_reduce_scatter_block.c \
3434
coll_ucc_reduce_scatter.c \
35+
coll_ucc_scatter.c \
3536
coll_ucc_scatterv.c
3637

3738
# Make the output library in this directory, and name it either

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ BEGIN_C_DECLS
3131
UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV | \
3232
UCC_COLL_TYPE_GATHER | UCC_COLL_TYPE_GATHERV | \
3333
UCC_COLL_TYPE_REDUCE_SCATTER | UCC_COLL_TYPE_REDUCE_SCATTERV | \
34-
UCC_COLL_TYPE_SCATTERV)
34+
UCC_COLL_TYPE_SCATTERV | UCC_COLL_TYPE_SCATTER)
3535

3636
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather," \
3737
"allgatherv,reduce,gather,gatherv,reduce_scatter_block,"\
38-
"reduce_scatter,scatterv," \
38+
"reduce_scatter,scatterv,scatter," \
3939
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,"\
4040
"iallgatherv,ireduce,igather,igatherv,ireduce_scatter_block,"\
41-
"ireduce_scatter,iscatterv"
41+
"ireduce_scatter,iscatterv,iscatter"
4242

4343
typedef struct mca_coll_ucc_req {
4444
ompi_request_t super;
@@ -128,6 +128,10 @@ struct mca_coll_ucc_module_t {
128128
mca_coll_base_module_t* previous_scatterv_module;
129129
mca_coll_base_module_iscatterv_fn_t previous_iscatterv;
130130
mca_coll_base_module_t* previous_iscatterv_module;
131+
mca_coll_base_module_scatter_fn_t previous_scatter;
132+
mca_coll_base_module_t* previous_scatter_module;
133+
mca_coll_base_module_iscatter_fn_t previous_iscatter;
134+
mca_coll_base_module_t* previous_iscatter_module;
131135
};
132136
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
133137
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
@@ -288,5 +292,18 @@ int mca_coll_ucc_iscatterv(const void *sbuf, const int *scounts,
288292
ompi_request_t** request,
289293
mca_coll_base_module_t *module);
290294

295+
int mca_coll_ucc_scatter(const void *sbuf, int scount,
296+
struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
297+
struct ompi_datatype_t *rdtype, int root,
298+
struct ompi_communicator_t *comm,
299+
mca_coll_base_module_t *module);
300+
301+
int mca_coll_ucc_iscatter(const void *sbuf, int scount,
302+
struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
303+
struct ompi_datatype_t *rdtype, int root,
304+
struct ompi_communicator_t *comm,
305+
ompi_request_t** request,
306+
mca_coll_base_module_t *module);
307+
291308
END_C_DECLS
292309
#endif

ompi/mca/coll/ucc/coll_ucc_component.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str)
131131
return UCC_COLL_TYPE_REDUCE_SCATTERV;
132132
} else if (0 == strcasecmp(str, "scatterv")) {
133133
return UCC_COLL_TYPE_SCATTERV;
134+
} else if (0 == strcasecmp(str, "scatter")) {
135+
return UCC_COLL_TYPE_SCATTER;
134136
}
135137
UCC_ERROR("incorrect value for cts: %s, allowed: %s",
136138
str, COLL_UCC_CTS_STR);

ompi/mca/coll/ucc/coll_ucc_module.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module)
8383
ucc_module->previous_scatterv_module = NULL;
8484
ucc_module->previous_iscatterv = NULL;
8585
ucc_module->previous_iscatterv_module = NULL;
86+
ucc_module->previous_scatter = NULL;
87+
ucc_module->previous_scatter_module = NULL;
88+
ucc_module->previous_iscatter = NULL;
89+
ucc_module->previous_iscatter_module = NULL;
8690
}
8791

8892
static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module)
@@ -129,6 +133,8 @@ static void mca_coll_ucc_module_destruct(mca_coll_ucc_module_t *ucc_module)
129133
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_ireduce_scatter_module);
130134
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_scatterv_module);
131135
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iscatterv_module);
136+
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_scatter_module);
137+
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_iscatter_module);
132138
mca_coll_ucc_module_clear(ucc_module);
133139
}
134140

@@ -170,6 +176,8 @@ static int mca_coll_ucc_save_coll_handlers(mca_coll_ucc_module_t *ucc_module)
170176
SAVE_PREV_COLL_API(ireduce_scatter);
171177
SAVE_PREV_COLL_API(scatterv);
172178
SAVE_PREV_COLL_API(iscatterv);
179+
SAVE_PREV_COLL_API(scatter);
180+
SAVE_PREV_COLL_API(iscatter);
173181
return OMPI_SUCCESS;
174182
}
175183

@@ -558,6 +566,7 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority)
558566
SET_COLL_PTR(ucc_module, REDUCE_SCATTER, reduce_scatter_block);
559567
SET_COLL_PTR(ucc_module, REDUCE_SCATTERV, reduce_scatter);
560568
SET_COLL_PTR(ucc_module, SCATTERV, scatterv);
569+
SET_COLL_PTR(ucc_module, SCATTER, scatter);
561570
return &ucc_module->super;
562571
}
563572

ompi/mca/coll/ucc/coll_ucc_scatter.c

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/**
2+
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
3+
* Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
4+
* $COPYRIGHT$
5+
*
6+
* Additional copyrights may follow
7+
*
8+
*/
9+
10+
#include "coll_ucc_common.h"
11+
12+
static inline
13+
ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, int scount,
14+
struct ompi_datatype_t *sdtype,
15+
void *rbuf, int rcount,
16+
struct ompi_datatype_t *rdtype, int root,
17+
mca_coll_ucc_module_t *ucc_module,
18+
ucc_coll_req_h *req,
19+
mca_coll_ucc_req_t *coll_req)
20+
{
21+
ucc_datatype_t ucc_sdt, ucc_rdt;
22+
int comm_rank = ompi_comm_rank(ucc_module->comm);
23+
int comm_size = ompi_comm_size(ucc_module->comm);
24+
25+
ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype);
26+
if (comm_rank == root) {
27+
ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype);
28+
if ((COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ||
29+
(MPI_IN_PLACE != rbuf && COLL_UCC_DT_UNSUPPORTED == ucc_rdt)) {
30+
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
31+
(COLL_UCC_DT_UNSUPPORTED == ucc_sdt) ?
32+
sdtype->super.name : rdtype->super.name);
33+
goto fallback;
34+
}
35+
} else {
36+
if (COLL_UCC_DT_UNSUPPORTED == ucc_rdt) {
37+
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
38+
rdtype->super.name);
39+
goto fallback;
40+
}
41+
}
42+
43+
ucc_coll_args_t coll = {
44+
.mask = 0,
45+
.coll_type = UCC_COLL_TYPE_SCATTER,
46+
.root = root,
47+
.src.info = {
48+
.buffer = (void*)sbuf,
49+
.count = ((size_t)scount) * comm_size,
50+
.datatype = ucc_sdt,
51+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
52+
},
53+
.dst.info = {
54+
.buffer = (void*)rbuf,
55+
.count = rcount,
56+
.datatype = ucc_rdt,
57+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
58+
},
59+
};
60+
61+
if (MPI_IN_PLACE == rbuf) {
62+
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
63+
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
64+
}
65+
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
66+
return UCC_OK;
67+
fallback:
68+
return UCC_ERR_NOT_SUPPORTED;
69+
}
70+
71+
int mca_coll_ucc_scatter(const void *sbuf, int scount,
72+
struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
73+
struct ompi_datatype_t *rdtype, int root,
74+
struct ompi_communicator_t *comm,
75+
mca_coll_base_module_t *module)
76+
{
77+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
78+
ucc_coll_req_h req;
79+
80+
UCC_VERBOSE(3, "running ucc scatter");
81+
COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount,
82+
rdtype, root, ucc_module, &req,
83+
NULL));
84+
COLL_UCC_POST_AND_CHECK(req);
85+
COLL_UCC_CHECK(coll_ucc_req_wait(req));
86+
return OMPI_SUCCESS;
87+
fallback:
88+
UCC_VERBOSE(3, "running fallback scatter");
89+
return ucc_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount,
90+
rdtype, root, comm,
91+
ucc_module->previous_scatter_module);
92+
93+
}
94+
95+
int mca_coll_ucc_iscatter(const void *sbuf, int scount,
96+
struct ompi_datatype_t *sdtype, void *rbuf, int rcount,
97+
struct ompi_datatype_t *rdtype, int root,
98+
struct ompi_communicator_t *comm,
99+
ompi_request_t** request,
100+
mca_coll_base_module_t *module)
101+
{
102+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
103+
ucc_coll_req_h req;
104+
mca_coll_ucc_req_t *coll_req = NULL;
105+
106+
UCC_VERBOSE(3, "running ucc iscatter");
107+
COLL_UCC_GET_REQ(coll_req);
108+
COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount,
109+
rdtype, root, ucc_module, &req,
110+
coll_req));
111+
COLL_UCC_POST_AND_CHECK(req);
112+
*request = &coll_req->super;
113+
return OMPI_SUCCESS;
114+
fallback:
115+
UCC_VERBOSE(3, "running fallback iscatter");
116+
if (coll_req) {
117+
mca_coll_ucc_req_free((ompi_request_t **)&coll_req);
118+
}
119+
return ucc_module->previous_iscatter(sbuf, scount, sdtype, rbuf, rcount,
120+
rdtype, root, comm, request,
121+
ucc_module->previous_iscatter_module);
122+
}

0 commit comments

Comments
 (0)