Skip to content

Commit bf2068a

Browse files
authored
Merge pull request #12429 from bosilca/topic/fix_collective_init_fini
Fix collective modules initialization and finalization
2 parents 49382c3 + f2dfbba commit bf2068a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1734
-1421
lines changed

ompi/mca/coll/accelerator/Makefile.am

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
# Copyright (c) 2014 The University of Tennessee and The University
33
# of Tennessee Research Foundation. All rights
44
# reserved.
5-
# Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
5+
# Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
66
# Copyright (c) 2017 IBM Corporation. All rights reserved.
77
# $COPYRIGHT$
88
#
99
# Additional copyrights may follow
1010
#
1111
# $HEADER$
1212
#
13-
dist_ompidata_DATA = help-mpi-coll-accelerator.txt
1413

1514
sources = coll_accelerator_module.c coll_accelerator_reduce.c coll_accelerator_allreduce.c \
1615
coll_accelerator_reduce_scatter_block.c coll_accelerator_component.c \

ompi/mca/coll/accelerator/coll_accelerator.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Copyright (c) 2014 The University of Tennessee and The University
33
* of Tennessee Research Foundation. All rights
44
* reserved.
5-
* Copyright (c) 2014-2015 NVIDIA Corporation. All rights reserved.
5+
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
66
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
77
* $COPYRIGHT$
88
*
@@ -38,9 +38,6 @@ mca_coll_base_module_t
3838
*mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
3939
int *priority);
4040

41-
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
42-
struct ompi_communicator_t *comm);
43-
4441
int
4542
mca_coll_accelerator_allreduce(const void *sbuf, void *rbuf, int count,
4643
struct ompi_datatype_t *dtype,

ompi/mca/coll/accelerator/coll_accelerator_module.c

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Copyright (c) 2014-2017 The University of Tennessee and The University
33
* of Tennessee Research Foundation. All rights
44
* reserved.
5-
* Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
5+
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
66
* Copyright (c) 2019 Research Organization for Information Science
77
* and Technology (RIST). All rights reserved.
88
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
@@ -32,30 +32,21 @@
3232
#include "ompi/mca/coll/base/base.h"
3333
#include "coll_accelerator.h"
3434

35+
static int
36+
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
37+
struct ompi_communicator_t *comm);
38+
static int
39+
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
40+
struct ompi_communicator_t *comm);
3541

3642
static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module)
3743
{
3844
memset(&(module->c_coll), 0, sizeof(module->c_coll));
3945
}
4046

41-
static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module)
42-
{
43-
OBJ_RELEASE(module->c_coll.coll_allreduce_module);
44-
OBJ_RELEASE(module->c_coll.coll_reduce_module);
45-
OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module);
46-
OBJ_RELEASE(module->c_coll.coll_scatter_module);
47-
/* If the exscan module is not NULL, then this was an
48-
intracommunicator, and therefore scan will have a module as
49-
well. */
50-
if (NULL != module->c_coll.coll_exscan_module) {
51-
OBJ_RELEASE(module->c_coll.coll_exscan_module);
52-
OBJ_RELEASE(module->c_coll.coll_scan_module);
53-
}
54-
}
55-
5647
OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t,
5748
mca_coll_accelerator_module_construct,
58-
mca_coll_accelerator_module_destruct);
49+
NULL);
5950

6051

6152
/*
@@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
9990

10091
/* Choose whether to use [intra|inter] */
10192
accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable;
93+
accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable;
10294

103-
accelerator_module->super.coll_allgather = NULL;
104-
accelerator_module->super.coll_allgatherv = NULL;
10595
accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
106-
accelerator_module->super.coll_alltoall = NULL;
107-
accelerator_module->super.coll_alltoallv = NULL;
108-
accelerator_module->super.coll_alltoallw = NULL;
109-
accelerator_module->super.coll_barrier = NULL;
110-
accelerator_module->super.coll_bcast = NULL;
111-
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
112-
accelerator_module->super.coll_gather = NULL;
113-
accelerator_module->super.coll_gatherv = NULL;
11496
accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
115-
accelerator_module->super.coll_reduce_scatter = NULL;
11697
accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
117-
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
118-
accelerator_module->super.coll_scatter = NULL;
119-
accelerator_module->super.coll_scatterv = NULL;
98+
if (!OMPI_COMM_IS_INTER(comm)) {
99+
accelerator_module->super.coll_scan = mca_coll_accelerator_scan;
100+
accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan;
101+
}
120102

121103
return &(accelerator_module->super);
122104
}
123105

124106

107+
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \
108+
do \
109+
{ \
110+
if ((__comm)->c_coll->coll_##__api) \
111+
{ \
112+
MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
113+
MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \
114+
} \
115+
else \
116+
{ \
117+
opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \
118+
"cuda", #__api, ompi_process_info.nodename, \
119+
mca_coll_accelerator_component.priority); \
120+
} \
121+
} while (0)
122+
123+
#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \
124+
do \
125+
{ \
126+
if (&(__module)->super == (__comm)->c_coll->coll_##__api##_module) { \
127+
MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
128+
(__module)->c_coll.coll_##__api##_module = NULL; \
129+
(__module)->c_coll.coll_##__api = NULL; \
130+
} \
131+
} while (0)
132+
125133
/*
126-
* Init module on the communicator
134+
* Init/Fini module on the communicator
127135
*/
128-
int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
129-
struct ompi_communicator_t *comm)
136+
static int
137+
mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
138+
struct ompi_communicator_t *comm)
130139
{
131-
bool good = true;
132-
char *msg = NULL;
133140
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
134141

135-
#define CHECK_AND_RETAIN(src, dst, name) \
136-
if (NULL == (src)->c_coll->coll_ ## name ## _module) { \
137-
good = false; \
138-
msg = #name; \
139-
} else if (good) { \
140-
(dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \
141-
(dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \
142-
OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \
143-
}
144-
145-
CHECK_AND_RETAIN(comm, s, allreduce);
146-
CHECK_AND_RETAIN(comm, s, reduce);
147-
CHECK_AND_RETAIN(comm, s, reduce_scatter_block);
148-
CHECK_AND_RETAIN(comm, s, scatter);
142+
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
143+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
144+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
149145
if (!OMPI_COMM_IS_INTER(comm)) {
150146
/* MPI does not define scan/exscan on intercommunicators */
151-
CHECK_AND_RETAIN(comm, s, exscan);
152-
CHECK_AND_RETAIN(comm, s, scan);
147+
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
148+
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
153149
}
154150

155-
/* All done */
156-
if (good) {
157-
return OMPI_SUCCESS;
158-
}
159-
opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true,
160-
ompi_process_info.nodename,
161-
mca_coll_accelerator_component.priority, msg);
162-
return OMPI_ERR_NOT_FOUND;
151+
return OMPI_SUCCESS;
163152
}
164153

154+
static int
155+
mca_coll_accelerator_module_disable(mca_coll_base_module_t *module,
156+
struct ompi_communicator_t *comm)
157+
{
158+
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
159+
160+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce);
161+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce);
162+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block);
163+
if (!OMPI_COMM_IS_INTER(comm))
164+
{
165+
/* MPI does not define scan/exscan on intercommunicators */
166+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan);
167+
ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan);
168+
}
169+
170+
return OMPI_SUCCESS;
171+
}

ompi/mca/coll/accelerator/help-mpi-coll-accelerator.txt

Lines changed: 0 additions & 29 deletions
This file was deleted.

ompi/mca/coll/adapt/coll_adapt_module.c

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* Copyright (c) 2021 Triad National Security, LLC. All rights
66
* reserved.
77
* Copyright (c) 2022 IBM Corporation. All rights reserved
8+
* Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
89
*
910
* $COPYRIGHT$
1011
*
@@ -83,25 +84,41 @@ OBJ_CLASS_INSTANCE(mca_coll_adapt_module_t,
8384
adapt_module_construct,
8485
adapt_module_destruct);
8586

86-
/*
87-
* In this macro, the following variables are supposed to have been declared
88-
* in the caller:
89-
* . ompi_communicator_t *comm
90-
* . mca_coll_adapt_module_t *adapt_module
91-
*/
92-
#define ADAPT_SAVE_PREV_COLL_API(__api) \
93-
do { \
94-
adapt_module->previous_ ## __api = comm->c_coll->coll_ ## __api; \
95-
adapt_module->previous_ ## __api ## _module = comm->c_coll->coll_ ## __api ## _module; \
96-
if (!comm->c_coll->coll_ ## __api || !comm->c_coll->coll_ ## __api ## _module) { \
97-
opal_output_verbose(1, ompi_coll_base_framework.framework_output, \
98-
"(%s/%s): no underlying " # __api"; disqualifying myself", \
99-
ompi_comm_print_cid(comm), comm->c_name); \
100-
return OMPI_ERROR; \
101-
} \
102-
OBJ_RETAIN(adapt_module->previous_ ## __api ## _module); \
103-
} while(0)
104-
87+
#define ADAPT_INSTALL_COLL_API(__comm, __module, __api) \
88+
do \
89+
{ \
90+
if (__module->super.coll_##__api) \
91+
{ \
92+
MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \
93+
} \
94+
} while (0)
95+
#define ADAPT_UNINSTALL_COLL_API(__comm, __module, __api) \
96+
do \
97+
{ \
98+
if (__comm->c_coll->coll_##__api##_module == &__module->super) \
99+
{ \
100+
MCA_COLL_INSTALL_API(__comm, __api, NULL, NULL, "adapt"); \
101+
} \
102+
} while (0)
103+
#define ADAPT_INSTALL_AND_SAVE_COLL_API(__comm, __module, __api) \
104+
do \
105+
{ \
106+
if (__comm->c_coll->coll_##__api && __comm->c_coll->coll_##__api##_module) \
107+
{ \
108+
MCA_COLL_SAVE_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \
109+
MCA_COLL_INSTALL_API(__comm, __api, __module->super.coll_##__api, &__module->super, "adapt"); \
110+
} \
111+
} while (0)
112+
#define ADAPT_UNINSTALL_AND_RESTORE_COLL_API(__comm, __module, __api) \
113+
do \
114+
{ \
115+
if (__comm->c_coll->coll_##__api##_module == &__module->super) \
116+
{ \
117+
MCA_COLL_INSTALL_API(__comm, __api, __module->previous_##__api, __module->previous_##__api##_module, "adapt"); \
118+
__module->previous_##__api = NULL; \
119+
__module->previous_##__api##_module = NULL; \
120+
} \
121+
} while (0)
105122

106123
/*
107124
* Init module on the communicator
@@ -111,12 +128,25 @@ static int adapt_module_enable(mca_coll_base_module_t * module,
111128
{
112129
mca_coll_adapt_module_t * adapt_module = (mca_coll_adapt_module_t*) module;
113130

114-
ADAPT_SAVE_PREV_COLL_API(reduce);
115-
ADAPT_SAVE_PREV_COLL_API(ireduce);
131+
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, reduce);
132+
ADAPT_INSTALL_COLL_API(comm, adapt_module, bcast);
133+
ADAPT_INSTALL_AND_SAVE_COLL_API(comm, adapt_module, ireduce);
134+
ADAPT_INSTALL_COLL_API(comm, adapt_module, ibcast);
116135

117136
return OMPI_SUCCESS;
118137
}
138+
static int adapt_module_disable(mca_coll_base_module_t *module,
139+
struct ompi_communicator_t *comm)
140+
{
141+
mca_coll_adapt_module_t *adapt_module = (mca_coll_adapt_module_t *)module;
119142

143+
ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, reduce);
144+
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, bcast);
145+
ADAPT_UNINSTALL_AND_RESTORE_COLL_API(comm, adapt_module, ireduce);
146+
ADAPT_UNINSTALL_COLL_API(comm, adapt_module, ibcast);
147+
148+
return OMPI_SUCCESS;
149+
}
120150
/*
121151
* Initial query function that is invoked during MPI_INIT, allowing
122152
* this component to disqualify itself if it doesn't support the
@@ -165,24 +195,11 @@ mca_coll_base_module_t *ompi_coll_adapt_comm_query(struct ompi_communicator_t *
165195

166196
/* All is good -- return a module */
167197
adapt_module->super.coll_module_enable = adapt_module_enable;
168-
adapt_module->super.coll_allgather = NULL;
169-
adapt_module->super.coll_allgatherv = NULL;
170-
adapt_module->super.coll_allreduce = NULL;
171-
adapt_module->super.coll_alltoall = NULL;
172-
adapt_module->super.coll_alltoallw = NULL;
173-
adapt_module->super.coll_barrier = NULL;
198+
adapt_module->super.coll_module_disable = adapt_module_disable;
174199
adapt_module->super.coll_bcast = ompi_coll_adapt_bcast;
175-
adapt_module->super.coll_exscan = NULL;
176-
adapt_module->super.coll_gather = NULL;
177-
adapt_module->super.coll_gatherv = NULL;
178200
adapt_module->super.coll_reduce = ompi_coll_adapt_reduce;
179-
adapt_module->super.coll_reduce_scatter = NULL;
180-
adapt_module->super.coll_scan = NULL;
181-
adapt_module->super.coll_scatter = NULL;
182-
adapt_module->super.coll_scatterv = NULL;
183201
adapt_module->super.coll_ibcast = ompi_coll_adapt_ibcast;
184202
adapt_module->super.coll_ireduce = ompi_coll_adapt_ireduce;
185-
adapt_module->super.coll_iallreduce = NULL;
186203

187204
opal_output_verbose(10, ompi_coll_base_framework.framework_output,
188205
"coll:adapt:comm_query (%s/%s): pick me! pick me!",

0 commit comments

Comments
 (0)