Skip to content

Commit 7d751a0

Browse files
authored
Merge pull request #9098 from wfaderhold21/topic/scoll_ucc_fix
scoll/ucc: fix group creation/destruction
2 parents 520d55b + 933d830 commit 7d751a0

File tree

3 files changed

+45
-32
lines changed

3 files changed

+45
-32
lines changed

oshmem/mca/scoll/ucc/scoll_ucc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ struct mca_scoll_ucc_component_t {
4242
int ucc_np;
4343
char * cls;
4444
char * cts;
45+
int nr_modules;
4546
bool libucc_initialized;
4647
ucc_lib_h ucc_lib;
4748
ucc_lib_attr_t ucc_lib_attr;

oshmem/mca/scoll/ucc/scoll_ucc_component.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ mca_scoll_ucc_component_t mca_scoll_ucc_component = {
6161
0, /* ucc_enable */
6262
2, /* ucc_np */
6363
"basic", /* cls */
64-
SCOLL_UCC_CTS_STR /* cts */
64+
SCOLL_UCC_CTS_STR, /* cts */
65+
0, /* nr_modules */
66+
false /* libucc_initialized */
6567
};
6668

6769
static int ucc_register(void)

oshmem/mca/scoll/ucc/scoll_ucc_module.c

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
#include "scoll_ucc.h"
1919
#include "scoll_ucc_debug.h"
2020

21-
#include "oshmem/mca/spml/spml.h"
22-
2321
#include <ucc/api/ucc.h>
2422

2523
#define OBJ_RELEASE_IF_NOT_NULL( obj ) if( NULL != (obj) ) OBJ_RELEASE( obj );
@@ -51,17 +49,21 @@ int mca_scoll_ucc_progress(void)
5149

5250
static void mca_scoll_ucc_module_destruct(mca_scoll_ucc_module_t *ucc_module)
5351
{
54-
ucc_team_destroy(ucc_module->ucc_team);
52+
if (ucc_module->ucc_team) {
53+
ucc_team_destroy(ucc_module->ucc_team);
54+
--mca_scoll_ucc_component.nr_modules;
55+
}
5556

56-
if (ucc_module->group->ompi_comm == (ompi_communicator_t *) &oshmem_comm_world) {
57-
if (mca_scoll_ucc_component.libucc_initialized) {
57+
if (1 == mca_scoll_ucc_component.nr_modules) {
58+
if (mca_scoll_ucc_component.libucc_initialized) {
5859
UCC_VERBOSE(1, "finalizing ucc library");
5960
opal_progress_unregister(mca_scoll_ucc_progress);
6061
ucc_context_destroy(mca_scoll_ucc_component.ucc_context);
6162
ucc_finalize(mca_scoll_ucc_component.ucc_lib);
63+
mca_scoll_ucc_component.libucc_initialized = false;
6264
}
63-
}
64-
65+
}
66+
6567
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_alltoall_module);
6668
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_collect_module);
6769
OBJ_RELEASE_IF_NOT_NULL(ucc_module->previous_reduce_module);
@@ -142,15 +144,14 @@ static inline ucc_status_t oob_probe_test(oob_allgather_req_t *oob_req)
142144
static ucc_status_t oob_allgather_test(void *req)
143145
{
144146
oob_allgather_req_t *oob_req = (oob_allgather_req_t*) req;
145-
oshmem_group_t *osh_group = (oshmem_group_t *) oob_req->oob_coll_ctx;
146-
ompi_communicator_t *comm = osh_group->ompi_comm;
147+
ompi_communicator_t *comm = (ompi_communicator_t *) oob_req->oob_coll_ctx;
147148
char *tmpsend = NULL;
148149
char *tmprecv = NULL;
149150
size_t msglen = oob_req->msglen;
150151
int rank, size, sendto, recvfrom, recvdatafrom, senddatafrom;
151152

152-
size = osh_group->proc_count;
153-
rank = osh_group->my_pe;
153+
rank = ompi_comm_rank(comm);
154+
size = ompi_comm_size(comm);
154155

155156
if (0 == oob_req->iter) {
156157
tmprecv = (char *)oob_req->rbuf + (ptrdiff_t)rank * (ptrdiff_t)msglen;
@@ -229,8 +230,9 @@ static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
229230
ctx_params.oob.allgather = oob_allgather;
230231
ctx_params.oob.req_test = oob_allgather_test;
231232
ctx_params.oob.req_free = oob_allgather_free;
232-
ctx_params.oob.coll_info = (void *) osh_group;
233-
ctx_params.oob.participants = osh_group->proc_count;
233+
ctx_params.oob.coll_info = (void *) oshmem_comm_world;
234+
ctx_params.oob.participants = ompi_comm_size(oshmem_comm_world);
235+
234236
if (UCC_OK != ucc_context_config_read(cm->ucc_lib, NULL, &ctx_config)) {
235237
UCC_ERROR("UCC context config read failed");
236238
goto cleanup_lib;
@@ -278,7 +280,9 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
278280
{
279281
mca_scoll_ucc_component_t *cm = &mca_scoll_ucc_component;
280282
mca_scoll_ucc_module_t *ucc_module = (mca_scoll_ucc_module_t *) module;
281-
ucc_status_t status;
283+
ucc_status_t status = UCC_OK;
284+
285+
ucc_module->ucc_team = NULL;
282286

283287
ucc_team_params_t team_params = {
284288
.mask = UCC_TEAM_PARAM_FIELD_EP |
@@ -288,10 +292,10 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
288292
.allgather = oob_allgather,
289293
.req_test = oob_allgather_test,
290294
.req_free = oob_allgather_free,
291-
.coll_info = (void *)osh_group,
292-
.participants = osh_group->proc_count,
295+
.coll_info = (void *)osh_group->ompi_comm,
296+
.participants = ompi_comm_size(osh_group->ompi_comm),
293297
},
294-
.ep = osh_group->my_pe,
298+
.ep = ompi_comm_rank(osh_group->ompi_comm),
295299
.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG,
296300
};
297301

@@ -304,15 +308,18 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
304308

305309
return OSHMEM_ERROR;
306310
}
311+
312+
++cm->nr_modules;
313+
if (cm->ucc_context) {
314+
if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1,
315+
&team_params, &ucc_module->ucc_team)) {
316+
UCC_ERROR("ucc_team_create_post failed");
317+
}
307318

308-
if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1,
309-
&team_params, &ucc_module->ucc_team)) {
310-
UCC_ERROR("ucc_team_create_post failed");
311-
}
312-
313-
while (UCC_INPROGRESS == (status = ucc_team_create_test(ucc_module->ucc_team))) {
314-
opal_progress();
315-
}
319+
while (UCC_INPROGRESS == (status = ucc_team_create_test(ucc_module->ucc_team))) {
320+
opal_progress();
321+
}
322+
}
316323

317324
if (UCC_OK != status) {
318325
UCC_ERROR("ucc_team_create_test failed");
@@ -347,10 +354,11 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
347354
mca_scoll_base_module_t *
348355
mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority)
349356
{
350-
mca_scoll_base_module_t *module;
351-
mca_scoll_ucc_module_t *ucc_module;
352-
*priority = 0;
357+
mca_scoll_base_module_t *module;
358+
mca_scoll_ucc_module_t *ucc_module;
353359
mca_scoll_ucc_component_t *cm;
360+
361+
*priority = 0;
354362
cm = &mca_scoll_ucc_component;
355363

356364
if (!cm->ucc_enable) {
@@ -363,9 +371,11 @@ mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority)
363371
OPAL_TIMING_ENV_INIT(comm_query);
364372

365373
if (!cm->libucc_initialized) {
366-
if (OSHMEM_SUCCESS != mca_scoll_ucc_init_ctx(osh_group)) {
367-
cm->ucc_enable = 0;
368-
return NULL;
374+
if (0 < cm->nr_modules) {
375+
if (OSHMEM_SUCCESS != mca_scoll_ucc_init_ctx(osh_group)) {
376+
cm->ucc_enable = 0;
377+
return NULL;
378+
}
369379
}
370380
}
371381

0 commit comments

Comments
 (0)