|
2 | 2 | * Copyright (c) 2014-2017 The University of Tennessee and The University
|
3 | 3 | * of Tennessee Research Foundation. All rights
|
4 | 4 | * reserved.
|
5 |
| - * Copyright (c) 2014 NVIDIA Corporation. All rights reserved. |
| 5 | + * Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. |
6 | 6 | * Copyright (c) 2019 Research Organization for Information Science
|
7 | 7 | * and Technology (RIST). All rights reserved.
|
8 | 8 | * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
|
|
32 | 32 | #include "ompi/mca/coll/base/base.h"
|
33 | 33 | #include "coll_accelerator.h"
|
34 | 34 |
|
| 35 | +static int |
| 36 | +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, |
| 37 | + struct ompi_communicator_t *comm); |
| 38 | +static int |
| 39 | +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, |
| 40 | + struct ompi_communicator_t *comm); |
35 | 41 |
|
36 | 42 | static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module)
|
37 | 43 | {
|
38 | 44 | memset(&(module->c_coll), 0, sizeof(module->c_coll));
|
39 | 45 | }
|
40 | 46 |
|
41 |
| -static void mca_coll_accelerator_module_destruct(mca_coll_accelerator_module_t *module) |
42 |
| -{ |
43 |
| - OBJ_RELEASE(module->c_coll.coll_allreduce_module); |
44 |
| - OBJ_RELEASE(module->c_coll.coll_reduce_module); |
45 |
| - OBJ_RELEASE(module->c_coll.coll_reduce_scatter_block_module); |
46 |
| - OBJ_RELEASE(module->c_coll.coll_scatter_module); |
47 |
| - /* If the exscan module is not NULL, then this was an |
48 |
| - intracommunicator, and therefore scan will have a module as |
49 |
| - well. */ |
50 |
| - if (NULL != module->c_coll.coll_exscan_module) { |
51 |
| - OBJ_RELEASE(module->c_coll.coll_exscan_module); |
52 |
| - OBJ_RELEASE(module->c_coll.coll_scan_module); |
53 |
| - } |
54 |
| -} |
55 |
| - |
56 | 47 | OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t, mca_coll_base_module_t,
|
57 | 48 | mca_coll_accelerator_module_construct,
|
58 |
| - mca_coll_accelerator_module_destruct); |
| 49 | + NULL); |
59 | 50 |
|
60 | 51 |
|
61 | 52 | /*
|
@@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
|
99 | 90 |
|
100 | 91 | /* Choose whether to use [intra|inter] */
|
101 | 92 | accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable;
|
| 93 | + accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable; |
102 | 94 |
|
103 |
| - accelerator_module->super.coll_allgather = NULL; |
104 |
| - accelerator_module->super.coll_allgatherv = NULL; |
105 | 95 | accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce;
|
106 |
| - accelerator_module->super.coll_alltoall = NULL; |
107 |
| - accelerator_module->super.coll_alltoallv = NULL; |
108 |
| - accelerator_module->super.coll_alltoallw = NULL; |
109 |
| - accelerator_module->super.coll_barrier = NULL; |
110 |
| - accelerator_module->super.coll_bcast = NULL; |
111 |
| - accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan; |
112 |
| - accelerator_module->super.coll_gather = NULL; |
113 |
| - accelerator_module->super.coll_gatherv = NULL; |
114 | 96 | accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce;
|
115 |
| - accelerator_module->super.coll_reduce_scatter = NULL; |
116 | 97 | accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block;
|
117 |
| - accelerator_module->super.coll_scan = mca_coll_accelerator_scan; |
118 |
| - accelerator_module->super.coll_scatter = NULL; |
119 |
| - accelerator_module->super.coll_scatterv = NULL; |
| 98 | + if (!OMPI_COMM_IS_INTER(comm)) { |
| 99 | + accelerator_module->super.coll_scan = mca_coll_accelerator_scan; |
| 100 | + accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan; |
| 101 | + } |
120 | 102 |
|
121 | 103 | return &(accelerator_module->super);
|
122 | 104 | }
|
123 | 105 |
|
124 | 106 |
|
| 107 | +#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \ |
| 108 | + do \ |
| 109 | + { \ |
| 110 | + if ((__comm)->c_coll->coll_##__api) \ |
| 111 | + { \ |
| 112 | + MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ |
| 113 | + MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &__module->super, "accelerator"); \ |
| 114 | + } \ |
| 115 | + else \ |
| 116 | + { \ |
| 117 | + opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \ |
| 118 | + "cuda", #__api, ompi_process_info.nodename, \ |
| 119 | + mca_coll_accelerator_component.priority); \ |
| 120 | + } \ |
| 121 | + } while (0) |
| 122 | + |
| 123 | +#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \ |
| 124 | + do \ |
| 125 | + { \ |
| 126 | + if (&(__module)->super == (__comm)->c_coll->coll_##__api##_module) { \ |
| 127 | + MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \ |
| 128 | + (__module)->c_coll.coll_##__api##_module = NULL; \ |
| 129 | + (__module)->c_coll.coll_##__api = NULL; \ |
| 130 | + } \ |
| 131 | + } while (0) |
| 132 | + |
125 | 133 | /*
|
126 |
| - * Init module on the communicator |
| 134 | + * Init/Fini module on the communicator |
127 | 135 | */
|
128 |
| -int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, |
129 |
| - struct ompi_communicator_t *comm) |
| 136 | +static int |
| 137 | +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, |
| 138 | + struct ompi_communicator_t *comm) |
130 | 139 | {
|
131 |
| - bool good = true; |
132 |
| - char *msg = NULL; |
133 | 140 | mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
|
134 | 141 |
|
135 |
| -#define CHECK_AND_RETAIN(src, dst, name) \ |
136 |
| - if (NULL == (src)->c_coll->coll_ ## name ## _module) { \ |
137 |
| - good = false; \ |
138 |
| - msg = #name; \ |
139 |
| - } else if (good) { \ |
140 |
| - (dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \ |
141 |
| - (dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \ |
142 |
| - OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \ |
143 |
| - } |
144 |
| - |
145 |
| - CHECK_AND_RETAIN(comm, s, allreduce); |
146 |
| - CHECK_AND_RETAIN(comm, s, reduce); |
147 |
| - CHECK_AND_RETAIN(comm, s, reduce_scatter_block); |
148 |
| - CHECK_AND_RETAIN(comm, s, scatter); |
| 142 | + ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce); |
| 143 | + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce); |
| 144 | + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block); |
149 | 145 | if (!OMPI_COMM_IS_INTER(comm)) {
|
150 | 146 | /* MPI does not define scan/exscan on intercommunicators */
|
151 |
| - CHECK_AND_RETAIN(comm, s, exscan); |
152 |
| - CHECK_AND_RETAIN(comm, s, scan); |
| 147 | + ACCELERATOR_INSTALL_COLL_API(comm, s, exscan); |
| 148 | + ACCELERATOR_INSTALL_COLL_API(comm, s, scan); |
153 | 149 | }
|
154 | 150 |
|
155 |
| - /* All done */ |
156 |
| - if (good) { |
157 |
| - return OMPI_SUCCESS; |
158 |
| - } |
159 |
| - opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true, |
160 |
| - ompi_process_info.nodename, |
161 |
| - mca_coll_accelerator_component.priority, msg); |
162 |
| - return OMPI_ERR_NOT_FOUND; |
| 151 | + return OMPI_SUCCESS; |
163 | 152 | }
|
164 | 153 |
|
| 154 | +static int |
| 155 | +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, |
| 156 | + struct ompi_communicator_t *comm) |
| 157 | +{ |
| 158 | + mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; |
| 159 | + |
| 160 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce); |
| 161 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce); |
| 162 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block); |
| 163 | + if (!OMPI_COMM_IS_INTER(comm)) |
| 164 | + { |
| 165 | + /* MPI does not define scan/exscan on intercommunicators */ |
| 166 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan); |
| 167 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan); |
| 168 | + } |
| 169 | + |
| 170 | + return OMPI_SUCCESS; |
| 171 | +} |
0 commit comments