@@ -26,6 +26,7 @@ static int component_query(struct ompi_win_t *win, void **base, size_t size, int
26
26
static int component_select (struct ompi_win_t * win , void * * base , size_t size , int disp_unit ,
27
27
struct ompi_communicator_t * comm , struct opal_info_t * info ,
28
28
int flavor , int * model );
29
+ static void ompi_osc_ucx_unregister_progress (void );
29
30
30
31
ompi_osc_ucx_component_t mca_osc_ucx_component = {
31
32
{ /* ompi_osc_base_component_t */
@@ -45,7 +46,12 @@ ompi_osc_ucx_component_t mca_osc_ucx_component = {
45
46
.osc_query = component_query ,
46
47
.osc_select = component_select ,
47
48
.osc_finalize = component_finalize ,
48
- }
49
+ },
50
+ .ucp_context = NULL ,
51
+ .ucp_worker = NULL ,
52
+ .env_initialized = false,
53
+ .num_incomplete_req_ops = 0 ,
54
+ .num_modules = 0
49
55
};
50
56
51
57
ompi_osc_ucx_module_t ompi_osc_ucx_module_template = {
@@ -105,24 +111,15 @@ static int component_register(void) {
105
111
}
106
112
107
113
static int progress_callback (void ) {
108
- if (mca_osc_ucx_component .ucp_worker != NULL &&
109
- mca_osc_ucx_component .num_incomplete_req_ops > 0 ) {
110
- ucp_worker_progress (mca_osc_ucx_component .ucp_worker );
111
- }
114
+ ucp_worker_progress (mca_osc_ucx_component .ucp_worker );
112
115
return 0 ;
113
116
}
114
117
115
118
static int component_init (bool enable_progress_threads , bool enable_mpi_threads ) {
116
- int ret = OMPI_SUCCESS ;
117
-
118
- mca_osc_ucx_component .ucp_context = NULL ;
119
- mca_osc_ucx_component .ucp_worker = NULL ;
120
119
mca_osc_ucx_component .enable_mpi_threads = enable_mpi_threads ;
121
- mca_osc_ucx_component .env_initialized = false;
122
- mca_osc_ucx_component .num_incomplete_req_ops = 0 ;
123
120
124
121
opal_common_ucx_mca_register ();
125
- return ret ;
122
+ return OMPI_SUCCESS ;
126
123
}
127
124
128
125
static int component_finalize (void ) {
@@ -141,7 +138,6 @@ static int component_finalize(void) {
141
138
assert (mca_osc_ucx_component .num_incomplete_req_ops == 0 );
142
139
if (mca_osc_ucx_component .env_initialized == true) {
143
140
OBJ_DESTRUCT (& mca_osc_ucx_component .requests );
144
- opal_progress_unregister (progress_callback );
145
141
ucp_cleanup (mca_osc_ucx_component .ucp_context );
146
142
mca_osc_ucx_component .env_initialized = false;
147
143
}
@@ -241,6 +237,20 @@ static inline int mem_map(void **base, size_t size, ucp_mem_h *memh_ptr,
241
237
return ret ;
242
238
}
243
239
240
+ static void ompi_osc_ucx_unregister_progress ()
241
+ {
242
+ int ret ;
243
+
244
+ mca_osc_ucx_component .num_modules -- ;
245
+ OSC_UCX_ASSERT (mca_osc_ucx_component .num_modules >= 0 );
246
+ if (0 == mca_osc_ucx_component .num_modules ) {
247
+ ret = opal_progress_unregister (progress_callback );
248
+ if (OMPI_SUCCESS != ret ) {
249
+ OSC_UCX_VERBOSE (1 , "opal_progress_unregister failed: %d" , ret );
250
+ }
251
+ }
252
+ }
253
+
244
254
static int component_select (struct ompi_win_t * win , void * * base , size_t size , int disp_unit ,
245
255
struct ompi_communicator_t * comm , struct opal_info_t * info ,
246
256
int flavor , int * model ) {
@@ -251,7 +261,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
251
261
ucs_status_t status ;
252
262
int i , comm_size = ompi_comm_size (comm );
253
263
int is_eps_ready ;
254
- bool progress_registered = false, eps_created = false, env_initialized = false;
264
+ bool eps_created = false, env_initialized = false;
255
265
ucp_address_t * my_addr = NULL ;
256
266
size_t my_addr_len ;
257
267
char * recv_buf = NULL ;
@@ -328,13 +338,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
328
338
goto error_nomem ;
329
339
}
330
340
331
- ret = opal_progress_register (progress_callback );
332
- progress_registered = true;
333
- if (OMPI_SUCCESS != ret ) {
334
- OSC_UCX_VERBOSE (1 , "opal_progress_register failed: %d" , ret );
335
- goto error ;
336
- }
337
-
338
341
/* query UCP worker attributes */
339
342
worker_attr .field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE ;
340
343
status = ucp_worker_query (mca_osc_ucx_component .ucp_worker , & worker_attr );
@@ -362,6 +365,8 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
362
365
goto error_nomem ;
363
366
}
364
367
368
+ mca_osc_ucx_component .num_modules ++ ;
369
+
365
370
/* fill in the function pointer part */
366
371
memcpy (module , & ompi_osc_ucx_module_template , sizeof (ompi_osc_base_module_t ));
367
372
@@ -616,6 +621,14 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
616
621
goto error ;
617
622
}
618
623
624
+ OSC_UCX_ASSERT (mca_osc_ucx_component .num_modules > 0 );
625
+ if (1 == mca_osc_ucx_component .num_modules ) {
626
+ ret = opal_progress_register (progress_callback );
627
+ if (OMPI_SUCCESS != ret ) {
628
+ OSC_UCX_VERBOSE (1 , "opal_progress_register failed: %d" , ret );
629
+ goto error ;
630
+ }
631
+ }
619
632
return ret ;
620
633
621
634
error :
@@ -643,8 +656,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
643
656
ucp_ep_destroy (ep );
644
657
}
645
658
}
646
- if (progress_registered ) opal_progress_unregister (progress_callback );
647
- if (module ) free (module );
659
+ if (module ) {
660
+ free (module );
661
+ ompi_osc_ucx_unregister_progress ();
662
+ }
648
663
649
664
error_nomem :
650
665
if (env_initialized == true) {
@@ -812,6 +827,7 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) {
812
827
ompi_comm_free (& module -> comm );
813
828
814
829
free (module );
830
+ ompi_osc_ucx_unregister_progress ();
815
831
816
832
return ret ;
817
833
}
0 commit comments