Skip to content

Commit 8039356

Browse files
author
Valentin Petrov
committed
OMPI/COLL: coll ucc component
Signed-off-by: Valentin Petrov <valentinp@nvidia.com>
1 parent 8b850d2 commit 8039356

15 files changed

+1580
-1
lines changed

config/ompi_check_ucc.m4

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
dnl -*- shell-script -*-
2+
dnl
3+
dnl Copyright (c) 2021 Mellanox Technologies. All rights reserved.
4+
dnl Copyright (c) 2013 Cisco Systems, Inc. All rights reserved.
5+
dnl Copyright (c) 2015 Research Organization for Information Science
6+
dnl and Technology (RIST). All rights reserved.
7+
dnl $COPYRIGHT$
8+
dnl
9+
dnl Additional copyrights may follow
10+
dnl
11+
dnl $HEADER$
12+
dnl
13+
14+
# OMPI_CHECK_UCC(prefix, [action-if-found], [action-if-not-found])
15+
# --------------------------------------------------------
16+
# check if ucc support can be found. sets prefix_{CPPFLAGS,
17+
# LDFLAGS, LIBS} as needed and runs action-if-found if there is
18+
# support, otherwise executes action-if-not-found
19+
AC_DEFUN([OMPI_CHECK_UCC],[
20+
OPAL_VAR_SCOPE_PUSH([ompi_check_ucc_dir ompi_check_ucc_libs ompi_check_ucc_happy CPPFLAGS_save LDFLAGS_save LIBS_save])
21+
22+
AC_ARG_WITH([ucc],
23+
[AC_HELP_STRING([--with-ucc(=DIR)],
24+
[Build UCC (Unified Collective Communication)])])
25+
26+
AS_IF([test "$with_ucc" != "no"],
27+
[ompi_check_ucc_libs=ucc
28+
AS_IF([test ! -z "$with_ucc" && test "$with_ucc" != "yes"],
29+
[ompi_check_ucc_dir=$with_ucc])
30+
31+
CPPFLAGS_save=$CPPFLAGS
32+
LDFLAGS_save=$LDFLAGS
33+
LIBS_save=$LIBS
34+
35+
OPAL_LOG_MSG([$1_CPPFLAGS : $$1_CPPFLAGS], 1)
36+
OPAL_LOG_MSG([$1_LDFLAGS : $$1_LDFLAGS], 1)
37+
OPAL_LOG_MSG([$1_LIBS : $$1_LIBS], 1)
38+
39+
OPAL_CHECK_PACKAGE([$1],
40+
[ucc/api/ucc.h],
41+
[$ompi_check_ucc_libs],
42+
[ucc_init_version],
43+
[],
44+
[$ompi_check_ucc_dir],
45+
[],
46+
[ompi_check_ucc_happy="yes"],
47+
[ompi_check_ucc_happy="no"])
48+
49+
AS_IF([test "$ompi_check_ucc_happy" = "yes"],
50+
[
51+
CPPFLAGS=$coll_ucc_CPPFLAGS
52+
LDFLAGS=$coll_ucc_LDFLAGS
53+
LIBS=$coll_ucc_LIBS
54+
AC_CHECK_FUNCS(ucc_comm_free, [], [])
55+
],
56+
[])
57+
58+
CPPFLAGS=$CPPFLAGS_save
59+
LDFLAGS=$LDFLAGS_save
60+
LIBS=$LIBS_save],
61+
[ompi_check_ucc_happy=no])
62+
63+
AS_IF([test "$ompi_check_ucc_happy" = "yes" && test "$enable_progress_threads" = "yes"],
64+
[AC_MSG_WARN([ucc driver does not currently support progress threads. Disabling UCC.])
65+
ompi_check_ucc_happy="no"])
66+
67+
AS_IF([test "$ompi_check_ucc_happy" = "yes"],
68+
[$2],
69+
[AS_IF([test ! -z "$with_ucc" && test "$with_ucc" != "no"],
70+
[AC_MSG_ERROR([UCC support requested but not found. Aborting])])
71+
$3])
72+
73+
OPAL_VAR_SCOPE_POP
74+
])

ompi/mca/coll/base/coll_tags.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@
5959
/* one extra reserved to avoid revoke for normal reqs, see request/req_ft.c*/
6060
#define MCA_COLL_BASE_TAG_FT_END (MCA_COLL_BASE_TAG_FT_BASE - 3)
6161

62-
#define MCA_COLL_BASE_TAG_STATIC_END MCA_COLL_BASE_TAG_FT_END
62+
#define MCA_COLL_BASE_TAG_UCC (MCA_COLL_BASE_TAG_FT_END - 1)
63+
64+
#define MCA_COLL_BASE_TAG_STATIC_END (MCA_COLL_BASE_TAG_UCC - 1)
6365

6466

6567

ompi/mca/coll/ucc/Makefile.am

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# -*- shell-script -*-
2+
#
3+
#
4+
# Copyright (c) 2021 Mellanox Technologies. All rights reserved.
5+
# $COPYRIGHT$
6+
#
7+
# Additional copyrights may follow
8+
#
9+
# $HEADER$
10+
#
11+
#
12+
13+
AM_CPPFLAGS = $(coll_ucc_CPPFLAGS)
14+
15+
coll_ucc_sources = \
16+
coll_ucc.h \
17+
coll_ucc_debug.h \
18+
coll_ucc_dtypes.h \
19+
coll_ucc_module.c \
20+
coll_ucc_component.c \
21+
coll_ucc_barrier.c \
22+
coll_ucc_bcast.c \
23+
coll_ucc_allreduce.c \
24+
coll_ucc_alltoall.c \
25+
coll_ucc_alltoallv.c
26+
27+
# Make the output library in this directory, and name it either
28+
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la
29+
# (for static builds).
30+
31+
if MCA_BUILD_ompi_coll_ucc_DSO
32+
component_noinst =
33+
component_install = mca_coll_ucc.la
34+
else
35+
component_noinst = libmca_coll_ucc.la
36+
component_install =
37+
endif
38+
39+
mcacomponentdir = $(ompilibdir)
40+
mcacomponent_LTLIBRARIES = $(component_install)
41+
mca_coll_ucc_la_SOURCES = $(coll_ucc_sources)
42+
mca_coll_ucc_la_LIBADD = $(top_builddir)/ompi/lib@OMPI_LIBMPI_NAME@.la \
43+
$(coll_ucc_LIBS)
44+
mca_coll_ucc_la_LDFLAGS = -module -avoid-version $(coll_ucc_LDFLAGS)
45+
46+
noinst_LTLIBRARIES = $(component_noinst)
47+
libmca_coll_ucc_la_SOURCES = $(coll_ucc_sources)
48+
libmca_coll_ucc_la_LIBADD = $(coll_ucc_LIBS)
49+
libmca_coll_ucc_la_LDFLAGS = -module -avoid-version $(coll_ucc_LDFLAGS)

ompi/mca/coll/ucc/coll_ucc.h

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/**
2+
Copyright (c) 2021 Mellanox Technologies. All rights reserved.
3+
$COPYRIGHT$
4+
5+
Additional copyrights may follow
6+
7+
$HEADER$
8+
*/
9+
10+
#ifndef MCA_COLL_UCC_H
11+
#define MCA_COLL_UCC_H
12+
13+
#include "ompi_config.h"
14+
#include "mpi.h"
15+
#include "ompi/mca/mca.h"
16+
#include "opal/memoryhooks/memory.h"
17+
#include "opal/mca/memory/base/base.h"
18+
#include "ompi/mca/coll/coll.h"
19+
#include "ompi/communicator/communicator.h"
20+
#include "ompi/attribute/attribute.h"
21+
#include "ompi/op/op.h"
22+
#include "coll_ucc_debug.h"
23+
#include <ucc/api/ucc.h>
24+
25+
BEGIN_C_DECLS
26+
27+
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
28+
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
29+
UCC_COLL_TYPE_ALLTOALLV)
30+
31+
#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv," \
32+
"ibarrier,ibcast,iallreduce,ialltoall,ialltoallv"
33+
34+
typedef struct mca_coll_ucc_req {
35+
ompi_request_t super;
36+
ucc_coll_req_h ucc_req;
37+
} mca_coll_ucc_req_t;
38+
OBJ_CLASS_DECLARATION(mca_coll_ucc_req_t);
39+
40+
struct mca_coll_ucc_component_t {
41+
mca_coll_base_component_2_4_0_t super;
42+
int ucc_priority;
43+
int ucc_verbose;
44+
int ucc_enable;
45+
int ucc_np;
46+
char *cls;
47+
char *cts;
48+
const char *compiletime_version;
49+
const char *runtime_version;
50+
bool libucc_initialized;
51+
ucc_lib_h ucc_lib;
52+
ucc_lib_attr_t ucc_lib_attr;
53+
ucc_coll_type_t cts_requested;
54+
ucc_coll_type_t nb_cts_requested;
55+
ucc_context_h ucc_context;
56+
opal_free_list_t requests;
57+
};
58+
typedef struct mca_coll_ucc_component_t mca_coll_ucc_component_t;
59+
60+
OMPI_MODULE_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
61+
62+
/**
63+
* UCC enabled communicator
64+
*/
65+
struct mca_coll_ucc_module_t {
66+
mca_coll_base_module_t super;
67+
ompi_communicator_t* comm;
68+
int rank;
69+
ucc_team_h ucc_team;
70+
mca_coll_base_module_allreduce_fn_t previous_allreduce;
71+
mca_coll_base_module_t* previous_allreduce_module;
72+
mca_coll_base_module_iallreduce_fn_t previous_iallreduce;
73+
mca_coll_base_module_t* previous_iallreduce_module;
74+
mca_coll_base_module_barrier_fn_t previous_barrier;
75+
mca_coll_base_module_t* previous_barrier_module;
76+
mca_coll_base_module_ibarrier_fn_t previous_ibarrier;
77+
mca_coll_base_module_t* previous_ibarrier_module;
78+
mca_coll_base_module_bcast_fn_t previous_bcast;
79+
mca_coll_base_module_t* previous_bcast_module;
80+
mca_coll_base_module_ibcast_fn_t previous_ibcast;
81+
mca_coll_base_module_t* previous_ibcast_module;
82+
mca_coll_base_module_alltoall_fn_t previous_alltoall;
83+
mca_coll_base_module_t* previous_alltoall_module;
84+
mca_coll_base_module_ialltoall_fn_t previous_ialltoall;
85+
mca_coll_base_module_t* previous_ialltoall_module;
86+
mca_coll_base_module_alltoallv_fn_t previous_alltoallv;
87+
mca_coll_base_module_t* previous_alltoallv_module;
88+
mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv;
89+
mca_coll_base_module_t* previous_ialltoallv_module;
90+
};
91+
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
92+
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
93+
94+
int mca_coll_ucc_init_query(bool enable_progress_threads, bool enable_mpi_threads);
95+
mca_coll_base_module_t *mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority);
96+
97+
int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, int count,
98+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
99+
struct ompi_communicator_t *comm,
100+
mca_coll_base_module_t *module);
101+
102+
int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
103+
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
104+
struct ompi_communicator_t *comm,
105+
ompi_request_t** request,
106+
mca_coll_base_module_t *module);
107+
108+
int mca_coll_ucc_barrier(struct ompi_communicator_t *comm,
109+
mca_coll_base_module_t *module);
110+
111+
int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm,
112+
ompi_request_t** request,
113+
mca_coll_base_module_t *module);
114+
115+
int mca_coll_ucc_bcast(void *buf, int count, struct ompi_datatype_t *dtype,
116+
int root, struct ompi_communicator_t *comm,
117+
mca_coll_base_module_t *module);
118+
119+
int mca_coll_ucc_ibcast(void *buf, int count, struct ompi_datatype_t *dtype,
120+
int root, struct ompi_communicator_t *comm,
121+
ompi_request_t** request,
122+
mca_coll_base_module_t *module);
123+
124+
int mca_coll_ucc_alltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
125+
void* rbuf, int rcount, struct ompi_datatype_t *rdtype,
126+
struct ompi_communicator_t *comm,
127+
mca_coll_base_module_t *module);
128+
129+
int mca_coll_ucc_ialltoall(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
130+
void* rbuf, int rcount, struct ompi_datatype_t *rdtype,
131+
struct ompi_communicator_t *comm,
132+
ompi_request_t** request,
133+
mca_coll_base_module_t *module);
134+
135+
int mca_coll_ucc_alltoallv(const void *sbuf, const int *scounts, const int *sdips,
136+
struct ompi_datatype_t *sdtype,
137+
void* rbuf, const int *rcounts, const int *rdisps,
138+
struct ompi_datatype_t *rdtype,
139+
struct ompi_communicator_t *comm,
140+
mca_coll_base_module_t *module);
141+
142+
int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts, const int *sdips,
143+
struct ompi_datatype_t *sdtype,
144+
void* rbuf, const int *rcounts, const int *rdisps,
145+
struct ompi_datatype_t *rdtype,
146+
struct ompi_communicator_t *comm,
147+
ompi_request_t** request,
148+
mca_coll_base_module_t *module);
149+
END_C_DECLS
150+
#endif
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
2+
/**
3+
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
4+
* $COPYRIGHT$
5+
*
6+
* Additional copyrights may follow
7+
*
8+
*/
9+
10+
#include "coll_ucc_common.h"
11+
12+
static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, int count,
13+
struct ompi_datatype_t *dtype,
14+
struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module,
15+
ucc_coll_req_h *req,
16+
mca_coll_ucc_req_t *coll_req)
17+
{
18+
ucc_datatype_t ucc_dt;
19+
ucc_reduction_op_t ucc_op;
20+
21+
ucc_dt = ompi_dtype_to_ucc_dtype(dtype);
22+
ucc_op = ompi_op_to_ucc_op(op);
23+
if (OPAL_UNLIKELY(COLL_UCC_DT_UNSUPPORTED == ucc_dt)) {
24+
UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s",
25+
dtype->super.name);
26+
goto fallback;
27+
}
28+
if (OPAL_UNLIKELY(COLL_UCC_OP_UNSUPPORTED == ucc_op)) {
29+
UCC_VERBOSE(5, "ompi_op is not supported: op = %s",
30+
op->o_name);
31+
goto fallback;
32+
}
33+
ucc_coll_args_t coll = {
34+
.mask = UCC_COLL_ARGS_FIELD_PREDEFINED_REDUCTIONS,
35+
.coll_type = UCC_COLL_TYPE_ALLREDUCE,
36+
.src.info = {
37+
.buffer = (void*)sbuf,
38+
.count = count,
39+
.datatype = ucc_dt,
40+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
41+
},
42+
.dst.info = {
43+
.buffer = rbuf,
44+
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
45+
},
46+
.reduce = {
47+
.predefined_op = ucc_op,
48+
},
49+
};
50+
if (MPI_IN_PLACE == sbuf) {
51+
coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
52+
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
53+
}
54+
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
55+
return UCC_OK;
56+
fallback:
57+
return UCC_ERR_NOT_SUPPORTED;
58+
}
59+
60+
int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, int count,
61+
struct ompi_datatype_t *dtype,
62+
struct ompi_op_t *op, struct ompi_communicator_t *comm,
63+
mca_coll_base_module_t *module)
64+
{
65+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
66+
ucc_coll_req_h req;
67+
68+
UCC_VERBOSE(3, "running ucc allreduce");
69+
COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op,
70+
ucc_module, &req, NULL));
71+
COLL_UCC_CHECK(ucc_collective_post(req));
72+
COLL_UCC_CHECK(coll_ucc_req_wait(req));
73+
return OMPI_SUCCESS;
74+
fallback:
75+
UCC_VERBOSE(3, "running fallback allreduce");
76+
return ucc_module->previous_allreduce(sbuf, rbuf, count, dtype, op,
77+
comm, ucc_module->previous_allreduce_module);
78+
}
79+
80+
int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
81+
struct ompi_datatype_t *dtype,
82+
struct ompi_op_t *op, struct ompi_communicator_t *comm,
83+
ompi_request_t** request,
84+
mca_coll_base_module_t *module)
85+
{
86+
mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module;
87+
ucc_coll_req_h req;
88+
mca_coll_ucc_req_t *coll_req;
89+
90+
UCC_VERBOSE(3, "running ucc iallreduce");
91+
COLL_UCC_GET_REQ(coll_req);
92+
COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op,
93+
ucc_module, &req, coll_req));
94+
COLL_UCC_CHECK(ucc_collective_post(req));
95+
*request = &coll_req->super;
96+
return OMPI_SUCCESS;
97+
fallback:
98+
UCC_VERBOSE(3, "running fallback allreduce");
99+
return ucc_module->previous_iallreduce(sbuf, rbuf, count, dtype, op,
100+
comm, request, ucc_module->previous_iallreduce_module);
101+
}

0 commit comments

Comments
 (0)