@@ -79,6 +79,7 @@ ompi_osc_ucx_component_t mca_osc_ucx_component = {
79
79
},
80
80
.wpool = NULL ,
81
81
.env_initialized = false,
82
+ .priority_is_set = false,
82
83
.num_modules = 0 ,
83
84
.acc_single_intrinsic = false,
84
85
.comm_world_size = 0 ,
@@ -280,17 +281,32 @@ static int ucp_context_init(bool enable_mt, int proc_world_size) {
280
281
}
281
282
282
283
static int component_init (bool enable_progress_threads , bool enable_mpi_threads ) {
283
- opal_common_ucx_support_level_t support_level = OPAL_COMMON_UCX_SUPPORT_NONE ;
284
- mca_base_var_source_t param_source = MCA_BASE_VAR_SOURCE_DEFAULT ;
285
- int ret = OMPI_SUCCESS ,
286
- param = -1 ;
287
284
288
285
mca_osc_ucx_component .enable_mpi_threads = enable_mpi_threads ;
289
286
mca_osc_ucx_component .wpool = opal_common_ucx_wpool_allocate ();
287
+ mca_osc_ucx_component .priority_is_set = false;
290
288
291
- ret = ucp_context_init (enable_mpi_threads , ompi_proc_world_size ());
292
- if (OMPI_ERROR == ret ) {
293
- return OMPI_ERR_NOT_AVAILABLE ;
289
+ return OMPI_SUCCESS ;
290
+ }
291
+
292
+ static int component_set_priority () {
293
+ int param , ret ;
294
+ opal_common_ucx_support_level_t support_level = OPAL_COMMON_UCX_SUPPORT_NONE ;
295
+ mca_base_var_source_t param_source = MCA_BASE_VAR_SOURCE_DEFAULT ;
296
+
297
+ if (mca_osc_ucx_component .priority_is_set == true) {
298
+ return OMPI_SUCCESS ;
299
+ }
300
+
301
+ if (mca_osc_ucx_component .wpool == NULL ) {
302
+ mca_osc_ucx_component .wpool = opal_common_ucx_wpool_allocate ();
303
+ }
304
+
305
+ if (mca_osc_ucx_component .wpool -> ucp_ctx == NULL ) {
306
+ ret = ucp_context_init (mca_osc_ucx_component .enable_mpi_threads , ompi_proc_world_size ());
307
+ if (OMPI_ERROR == ret ) {
308
+ return OMPI_ERR_NOT_AVAILABLE ;
309
+ }
294
310
}
295
311
296
312
support_level = opal_common_ucx_support_level (mca_osc_ucx_component .wpool -> ucp_ctx );
@@ -315,6 +331,8 @@ static int component_init(bool enable_progress_threads, bool enable_mpi_threads)
315
331
}
316
332
OSC_UCX_VERBOSE (2 , "returning priority %d" , mca_osc_ucx_component .priority );
317
333
334
+ mca_osc_ucx_component .priority_is_set = true;
335
+
318
336
return OMPI_SUCCESS ;
319
337
}
320
338
@@ -344,6 +362,14 @@ static int component_finalize(void) {
344
362
345
363
static int component_query (struct ompi_win_t * win , void * * base , size_t size , int disp_unit ,
346
364
struct ompi_communicator_t * comm , struct opal_info_t * info , int flavor ) {
365
+ int ret ;
366
+ if (mca_osc_ucx_component .priority_is_set == false) {
367
+ ret = component_set_priority ();
368
+ if (OMPI_SUCCESS != ret ) {
369
+ OSC_UCX_ERROR ("OSC UCX component priority set inside component query failed \n " );
370
+ return ret ;
371
+ }
372
+ }
347
373
return mca_osc_ucx_component .priority ;
348
374
}
349
375
@@ -507,6 +533,14 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
507
533
* we don't want to initialize in the component_init()
508
534
*/
509
535
536
+ if (mca_osc_ucx_component .priority_is_set == false) {
537
+ ret = component_set_priority ();
538
+ if (OMPI_SUCCESS != ret ) {
539
+ OSC_UCX_ERROR ("OSC UCX component priority set inside component select failed \n " );
540
+ return ret ;
541
+ }
542
+ }
543
+
510
544
OBJ_CONSTRUCT (& mca_osc_ucx_component .requests , opal_free_list_t );
511
545
ret = opal_free_list_init (& mca_osc_ucx_component .requests ,
512
546
sizeof (ompi_osc_ucx_generic_request_t ),
0 commit comments