18
18
#include "scoll_ucc.h"
19
19
#include "scoll_ucc_debug.h"
20
20
21
- #include "oshmem/mca/spml/spml.h"
22
-
23
21
#include <ucc/api/ucc.h>
24
22
25
23
#define OBJ_RELEASE_IF_NOT_NULL ( obj ) if( NULL != (obj) ) OBJ_RELEASE( obj );
@@ -51,17 +49,21 @@ int mca_scoll_ucc_progress(void)
51
49
52
50
static void mca_scoll_ucc_module_destruct (mca_scoll_ucc_module_t * ucc_module )
53
51
{
54
- ucc_team_destroy (ucc_module -> ucc_team );
52
+ if (ucc_module -> ucc_team ) {
53
+ ucc_team_destroy (ucc_module -> ucc_team );
54
+ -- mca_scoll_ucc_component .nr_modules ;
55
+ }
55
56
56
- if (ucc_module -> group -> ompi_comm == ( ompi_communicator_t * ) & oshmem_comm_world ) {
57
- if (mca_scoll_ucc_component .libucc_initialized ) {
57
+ if (1 == mca_scoll_ucc_component . nr_modules ) {
58
+ if (mca_scoll_ucc_component .libucc_initialized ) {
58
59
UCC_VERBOSE (1 , "finalizing ucc library" );
59
60
opal_progress_unregister (mca_scoll_ucc_progress );
60
61
ucc_context_destroy (mca_scoll_ucc_component .ucc_context );
61
62
ucc_finalize (mca_scoll_ucc_component .ucc_lib );
63
+ mca_scoll_ucc_component .libucc_initialized = false;
62
64
}
63
- }
64
-
65
+ }
66
+
65
67
OBJ_RELEASE_IF_NOT_NULL (ucc_module -> previous_alltoall_module );
66
68
OBJ_RELEASE_IF_NOT_NULL (ucc_module -> previous_collect_module );
67
69
OBJ_RELEASE_IF_NOT_NULL (ucc_module -> previous_reduce_module );
@@ -142,15 +144,14 @@ static inline ucc_status_t oob_probe_test(oob_allgather_req_t *oob_req)
142
144
static ucc_status_t oob_allgather_test (void * req )
143
145
{
144
146
oob_allgather_req_t * oob_req = (oob_allgather_req_t * ) req ;
145
- oshmem_group_t * osh_group = (oshmem_group_t * ) oob_req -> oob_coll_ctx ;
146
- ompi_communicator_t * comm = osh_group -> ompi_comm ;
147
+ ompi_communicator_t * comm = (ompi_communicator_t * ) oob_req -> oob_coll_ctx ;
147
148
char * tmpsend = NULL ;
148
149
char * tmprecv = NULL ;
149
150
size_t msglen = oob_req -> msglen ;
150
151
int rank , size , sendto , recvfrom , recvdatafrom , senddatafrom ;
151
152
152
- size = osh_group -> proc_count ;
153
- rank = osh_group -> my_pe ;
153
+ rank = ompi_comm_rank ( comm ) ;
154
+ size = ompi_comm_size ( comm );
154
155
155
156
if (0 == oob_req -> iter ) {
156
157
tmprecv = (char * )oob_req -> rbuf + (ptrdiff_t )rank * (ptrdiff_t )msglen ;
@@ -229,8 +230,9 @@ static int mca_scoll_ucc_init_ctx(oshmem_group_t *osh_group)
229
230
ctx_params .oob .allgather = oob_allgather ;
230
231
ctx_params .oob .req_test = oob_allgather_test ;
231
232
ctx_params .oob .req_free = oob_allgather_free ;
232
- ctx_params .oob .coll_info = (void * ) osh_group ;
233
- ctx_params .oob .participants = osh_group -> proc_count ;
233
+ ctx_params .oob .coll_info = (void * ) oshmem_comm_world ;
234
+ ctx_params .oob .participants = ompi_comm_size (oshmem_comm_world );
235
+
234
236
if (UCC_OK != ucc_context_config_read (cm -> ucc_lib , NULL , & ctx_config )) {
235
237
UCC_ERROR ("UCC context config read failed" );
236
238
goto cleanup_lib ;
@@ -278,7 +280,9 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
278
280
{
279
281
mca_scoll_ucc_component_t * cm = & mca_scoll_ucc_component ;
280
282
mca_scoll_ucc_module_t * ucc_module = (mca_scoll_ucc_module_t * ) module ;
281
- ucc_status_t status ;
283
+ ucc_status_t status = UCC_OK ;
284
+
285
+ ucc_module -> ucc_team = NULL ;
282
286
283
287
ucc_team_params_t team_params = {
284
288
.mask = UCC_TEAM_PARAM_FIELD_EP |
@@ -288,10 +292,10 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
288
292
.allgather = oob_allgather ,
289
293
.req_test = oob_allgather_test ,
290
294
.req_free = oob_allgather_free ,
291
- .coll_info = (void * )osh_group ,
292
- .participants = osh_group -> proc_count ,
295
+ .coll_info = (void * )osh_group -> ompi_comm ,
296
+ .participants = ompi_comm_size ( osh_group -> ompi_comm ) ,
293
297
},
294
- .ep = osh_group -> my_pe ,
298
+ .ep = ompi_comm_rank ( osh_group -> ompi_comm ) ,
295
299
.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG ,
296
300
};
297
301
@@ -304,15 +308,18 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
304
308
305
309
return OSHMEM_ERROR ;
306
310
}
311
+
312
+ ++ cm -> nr_modules ;
313
+ if (cm -> ucc_context ) {
314
+ if (UCC_OK != ucc_team_create_post (& cm -> ucc_context , 1 ,
315
+ & team_params , & ucc_module -> ucc_team )) {
316
+ UCC_ERROR ("ucc_team_create_post failed" );
317
+ }
307
318
308
- if (UCC_OK != ucc_team_create_post (& cm -> ucc_context , 1 ,
309
- & team_params , & ucc_module -> ucc_team )) {
310
- UCC_ERROR ("ucc_team_create_post failed" );
311
- }
312
-
313
- while (UCC_INPROGRESS == (status = ucc_team_create_test (ucc_module -> ucc_team ))) {
314
- opal_progress ();
315
- }
319
+ while (UCC_INPROGRESS == (status = ucc_team_create_test (ucc_module -> ucc_team ))) {
320
+ opal_progress ();
321
+ }
322
+ }
316
323
317
324
if (UCC_OK != status ) {
318
325
UCC_ERROR ("ucc_team_create_test failed" );
@@ -347,10 +354,11 @@ static int mca_scoll_ucc_module_enable(mca_scoll_base_module_t *module,
347
354
mca_scoll_base_module_t *
348
355
mca_scoll_ucc_comm_query (oshmem_group_t * osh_group , int * priority )
349
356
{
350
- mca_scoll_base_module_t * module ;
351
- mca_scoll_ucc_module_t * ucc_module ;
352
- * priority = 0 ;
357
+ mca_scoll_base_module_t * module ;
358
+ mca_scoll_ucc_module_t * ucc_module ;
353
359
mca_scoll_ucc_component_t * cm ;
360
+
361
+ * priority = 0 ;
354
362
cm = & mca_scoll_ucc_component ;
355
363
356
364
if (!cm -> ucc_enable ) {
@@ -363,9 +371,11 @@ mca_scoll_ucc_comm_query(oshmem_group_t *osh_group, int *priority)
363
371
OPAL_TIMING_ENV_INIT (comm_query );
364
372
365
373
if (!cm -> libucc_initialized ) {
366
- if (OSHMEM_SUCCESS != mca_scoll_ucc_init_ctx (osh_group )) {
367
- cm -> ucc_enable = 0 ;
368
- return NULL ;
374
+ if (0 < cm -> nr_modules ) {
375
+ if (OSHMEM_SUCCESS != mca_scoll_ucc_init_ctx (osh_group )) {
376
+ cm -> ucc_enable = 0 ;
377
+ return NULL ;
378
+ }
369
379
}
370
380
}
371
381
0 commit comments