Skip to content

Commit f2dfbba

Browse files
committed
Fix collective initialization and finalization
Instead of allowing each collective module to present a list of the functions it provides, let each module register the functions it provides and save the context of the previous collective if it chooses to. There are two major benefits to this approach: - tighter memory management in the collective modules themselves. Each collective's enable and disable is called exactly once per communicator, to chain or unchain the module from the collective function pointers struct. The disable is called in the reverse order of the enable, allowing for proper chaining of collectives. - modules only install the functions they want. So instead of coll_select checking all the functions of all modules, each module can now selectively iterate over only the functions it provides. What is still broken is the ability of a particular collective module to unchain itself in the middle of the execution. Instead, a properly implemented module will have an enable/disable flag, and it should act as a passthrough if it chooses to deactivate. Signed-off-by: George Bosilca <bosilca@icl.utk.edu>
1 parent 49382c3 commit f2dfbba

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1734
-1421
lines changed

ompi/mca/coll/accelerator/Makefile.am

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
# Copyright (c) 2014 The University of Tennessee and The University
33
# of Tennessee Research Foundation. All rights
44
# reserved.
5-
# Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
5+
# Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
66
# Copyright (c) 2017 IBM Corporation. All rights reserved.
77
# $COPYRIGHT$
88
#
99
# Additional copyrights may follow
1010
#
1111
# $HEADER$
1212
#
13-
dist_ompidata_DATA = help-mpi-coll-accelerator.txt
1413

1514
sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
1615
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \

ompi/mca/coll/accelerator/coll_accelerator.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Copyright (c) 2014 The University of Tennessee and The University
33
* of Tennessee Research Foundation. All rights
44
* reserved.
5-
* Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
5+
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
66
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
77
* $COPYRIGHT$
88
*
@@ -38,9 +38,6 @@ mca_coll_base_module_t
3838
*mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
3939
int *priority);
4040

41-
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
42-
struct ompi_communicator_t *comm);
43-
4441
int
4542
mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count,
4643
struct ompi_datatype_t *dtype,

ompi/mca/coll/accelerator/coll_accelerator_module.c

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Copyright (c) 2014-2017 The University of Tennessee and The University
33
* of Tennessee Research Foundation. All rights
44
* reserved.
5-
* Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
5+
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
66
* Copyright (c) 2019 Research Organization for Information Science
77
* and Technology (RIST). All rights reserved.
88
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
@@ -32,30 +32,21 @@
3232
#include "ompi/mca/coll/base/base.h"
3333
#include "coll_accelerator.h"
3434

35+
static int
36+
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
37+
struct ompi_communicator_t *comm);
38+
static int
39+
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
40+
struct ompi_communicator_t *comm);
3541

3642
static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module)
3743
{
3844
memset(&(module->c_coll), 0, sizeof(module->c_coll));
3945
}
4046

41-
static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module)
42-
{
43-
OBJ_RELEASE(module->c_coll.coll_allreduce_module);
44-
OBJ_RELEASE(module->c_coll.coll_reduce_module);
45-
OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
46-
OBJ_RELEASE(module->c_coll.coll_scatter_module);
47-
/* If the exscan module is not NULL, then this was an
48-
intracommunicator, and therefore scan will have a module as
49-
well. */
50-
if (NULL != module->c_coll.coll_exscan_module) {
51-
OBJ_RELEASE(module->c_coll.coll_exscan_module);
52-
OBJ_RELEASE(module->c_coll.coll_scan_module);
53-
}
54-
}
55-
5647
OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t,
5748
mca_coll_accelerator_module_construct,
58-
mca_coll_accelerator_module_destruct);
49+
NULL);
5950

6051

6152
/*
@@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
9990

10091
/* Choose whether to use [intra|inter] */
10192
accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable;
93+
accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable;
10294

103-
accelerator_module->super.coll_allgather = NULL;
104-
accelerator_module->super.coll_allgatherv = NULL;
10595
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
106-
accelerator_module->super.coll_alltoall = NULL;
107-
accelerator_module->super.coll_alltoallv = NULL;
108-
accelerator_module->super.coll_alltoallw = NULL;
109-
accelerator_module->super.coll_barrier = NULL;
110-
accelerator_module->super.coll_bcast = NULL;
111-
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
112-
accelerator_module->super.coll_gather = NULL;
113-
accelerator_module->super.coll_gatherv = NULL;
11496
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
115-
accelerator_module->super.coll_reduce_scatter = NULL;
11697
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
117-
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
118-
accelerator_module->super.coll_scatter = NULL;
119-
accelerator_module->super.coll_scatterv = NULL;
98+
if (!OMPI_COMM_IS_INTER(comm)) {
99+
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
100+
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
101+
}
120102

121103
return &(accelerator_module->super);
122104
}
123105

124106

107+
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \
108+
do \
109+
{ \
110+
if ((__comm)->c_coll->coll_##__api) \
111+
{ \
112+
MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
113+
MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \
114+
} \
115+
else \
116+
{ \
117+
opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \
118+
"cuda", #__api, ompi_process_info.nodename, \
119+
mca_coll_accelerator_component.priority); \
120+
} \
121+
} while (0)
122+
123+
#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \
124+
do \
125+
{ \
126+
if (&(__module)->super == (__comm)->c_coll->coll_##__api##_module) { \
127+
MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
128+
(__module)->c_coll.coll_##__api##_module = NULL; \
129+
(__module)->c_coll.coll_##__api = NULL; \
130+
} \
131+
} while (0)
132+
125133
/*
126-
* Init module on the communicator
134+
* Init/Fini module on the communicator
127135
*/
128-
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
129-
struct ompi_communicator_t *comm)
136+
static int
137+
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
138+
struct ompi_communicator_t *comm)
130139
{
131-
bool good = true;
132-
char *msg = NULL;
133140
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
134141

135-
#define CHECK_AND_RETAIN(src, dst, name) \
136-
if (NULL == (src)->c_coll->coll_ ## name ## _module) { \
137-
good = false; \
138-
msg = #name; \
139-
} else if (good) { \
140-
(dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \
141-
(dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \
142-
OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \
143-
}
144-
145-
CHECK_AND_RETAIN(comm, s, allreduce);
146-
CHECK_AND_RETAIN(comm, s, reduce);
147-
CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
148-
CHECK_AND_RETAIN(comm, s, scatter);
142+
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
143+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
144+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
149145
if (!OMPI_COMM_IS_INTER(comm)) {
150146
/* MPI does not define scan/exscan on intercommunicators */
151-
CHECK_AND_RETAIN(comm, s, exscan);
152-
CHECK_AND_RETAIN(comm, s, scan);
147+
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
148+
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
153149
}
154150

155-
/* All done */
156-
if (good) {
157-
return OMPI_SUCCESS;
158-
}
159-
opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true,
160-
ompi_process_info.nodename,
161-
mca_coll_accelerator_component.priority, msg);
162-
return OMPI_ERR_NOT_FOUND;
151+
return OMPI_SUCCESS;
163152
}
164153

154+
static int
155+
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
156+
struct ompi_communicator_t *comm)
157+
{
158+
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
159+
160+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
161+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
162+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
163+
if (!OMPI_COMM_IS_INTER(comm))
164+
{
165+
/* MPI does not define scan/exscan on intercommunicators */
166+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan);
167+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan);
168+
}
169+
170+
return OMPI_SUCCESS;
171+
}

ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt

Lines changed: 0 additions & 29 deletions
This file was deleted.

ompi/mca/coll/adapt/coll_adapt_module.c

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* Copyright (c) 2021 Triad National Security, LLC. All rights
66
* reserved.
77
* Copyright (c) 2022 IBM Corporation. All rights reserved
8+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
89
*
910
* $COPYRIGHT$
1011
*
@@ -83,25 +84,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t,
8384
adapt_module_construct,
8485
adapt_module_destruct);
8586

86-
/*
87-
* In this macro, the following variables are supposed to have been declared
88-
* in the caller:
89-
* . ompi_communicator_t *comm
90-
* . mca_coll_adapt_module_t *adapt_module
91-
*/
92-
#define ADAPT_SAVE_PREV_COLL_API(__api) \
93-
do { \
94-
adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
95-
adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
96-
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
97-
opal_output_verbose(1, ompi_coll_base_framework.framework_output, \
98-
"(%s/%s): no underlying " # __api"; disqualifying myself", \
99-
ompi_comm_print_cid(comm), comm->c_name); \
100-
return OMPI_ERROR; \
101-
} \
102-
OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \
103-
} while(0)
104-
87+
#define ADAPT_INSTALL_COLL_API(__comm, __module, __api) \
88+
do \
89+
{ \
90+
if (__module->super.coll_##__api) \
91+
{ \
92+
MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \
93+
} \
94+
} while (0)
95+
#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api) \
96+
do \
97+
{ \
98+
if (__comm->c_coll->coll_##__api##_module == &__module->super) \
99+
{ \
100+
MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt"); \
101+
} \
102+
} while (0)
103+
#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api) \
104+
do \
105+
{ \
106+
if (__comm->c_coll->coll_##__api && __comm->c_coll->coll_##__api##_module) \
107+
{ \
108+
MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \
109+
MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \
110+
} \
111+
} while (0)
112+
#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api) \
113+
do \
114+
{ \
115+
if (__comm->c_coll->coll_##__api##_module == &__module->super) \
116+
{ \
117+
MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \
118+
__module->previous_##__api = NULL; \
119+
__module->previous_##__api##_module = NULL; \
120+
} \
121+
} while (0)
105122

106123
/*
107124
* Init module on the communicator
@@ -111,12 +128,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module,
111128
{
112129
mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module;
113130

114-
ADAPT_SAVE_PREV_COLL_API(reduce);
115-
ADAPT_SAVE_PREV_COLL_API(ireduce);
131+
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce);
132+
ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast);
133+
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce);
134+
ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast);
116135

117136
return OMPI_SUCCESS;
118137
}
138+
static int adapt_module_disable(mca_coll_base_module_t *module,
139+
struct ompi_communicator_t *comm)
140+
{
141+
mca_coll_adapt_module_t *adapt_module = (mca_coll_adapt_module_t *)module;
119142

143+
ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, reduce);
144+
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, bcast);
145+
ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, ireduce);
146+
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, ibcast);
147+
148+
return OMPI_SUCCESS;
149+
}
120150
/*
121151
* Initial query function that is invoked during MPI_INIT, allowing
122152
* this component to disqualify itself if it doesn't support the
@@ -165,24 +195,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t *
165195

166196
/* All is good -- return a module */
167197
adapt_module->super.coll_module_enable = adapt_module_enable;
168-
adapt_module->super.coll_allgather = NULL;
169-
adapt_module->super.coll_allgatherv = NULL;
170-
adapt_module->super.coll_allreduce = NULL;
171-
adapt_module->super.coll_alltoall = NULL;
172-
adapt_module->super.coll_alltoallw = NULL;
173-
adapt_module->super.coll_barrier = NULL;
198+
adapt_module->super.coll_module_disable = adapt_module_disable;
174199
adapt_module->super.coll_bcast = ompi_coll_adapt_bcast;
175-
adapt_module->super.coll_exscan = NULL;
176-
adapt_module->super.coll_gather = NULL;
177-
adapt_module->super.coll_gatherv = NULL;
178200
adapt_module->super.coll_reduce = ompi_coll_adapt_reduce;
179-
adapt_module->super.coll_reduce_scatter = NULL;
180-
adapt_module->super.coll_scan = NULL;
181-
adapt_module->super.coll_scatter = NULL;
182-
adapt_module->super.coll_scatterv = NULL;
183201
adapt_module->super.coll_ibcast = ompi_coll_adapt_ibcast;
184202
adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce;
185-
adapt_module->super.coll_iallreduce = NULL;
186203

187204
opal_output_verbose(10, ompi_coll_base_framework.framework_output,
188205
"coll:adapt:comm_query (%s/%s): pick me! pick me!",

0 commit comments

Comments
 (0)