@@ -207,10 +207,71 @@ static int progress_callback(void) {
207
207
return 0 ;
208
208
}
209
209
210
+ static int ucp_context_init (bool enable_mt , int proc_world_size ) {
211
+ int ret = OMPI_SUCCESS ;
212
+ ucs_status_t status ;
213
+ ucp_config_t * config = NULL ;
214
+ ucp_params_t context_params ;
215
+
216
+ status = ucp_config_read ("MPI" , NULL , & config );
217
+ if (UCS_OK != status ) {
218
+ OSC_UCX_VERBOSE (1 , "ucp_config_read failed: %d" , status );
219
+ return OMPI_ERROR ;
220
+ }
221
+
222
+ /* initialize UCP context */
223
+ memset (& context_params , 0 , sizeof (context_params ));
224
+ context_params .field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_MT_WORKERS_SHARED
225
+ | UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | UCP_PARAM_FIELD_REQUEST_INIT
226
+ | UCP_PARAM_FIELD_REQUEST_SIZE ;
227
+ context_params .features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64 ;
228
+ context_params .mt_workers_shared = (enable_mt ? 1 : 0 );
229
+ context_params .estimated_num_eps = proc_world_size ;
230
+ context_params .request_init = opal_common_ucx_req_init ;
231
+ context_params .request_size = sizeof (opal_common_ucx_request_t );
232
+
233
+ #if HAVE_DECL_UCP_PARAM_FIELD_ESTIMATED_NUM_PPN
234
+ context_params .estimated_num_ppn = opal_process_info .num_local_peers + 1 ;
235
+ context_params .field_mask |= UCP_PARAM_FIELD_ESTIMATED_NUM_PPN ;
236
+ #endif
237
+
238
+ status = ucp_init (& context_params , config , & mca_osc_ucx_component .wpool -> ucp_ctx );
239
+ if (UCS_OK != status ) {
240
+ OSC_UCX_VERBOSE (1 , "ucp_init failed: %d" , status );
241
+ ret = OMPI_ERROR ;
242
+ }
243
+ ucp_config_release (config );
244
+
245
+ return ret ;
246
+ }
247
+
210
248
static int component_init (bool enable_progress_threads , bool enable_mpi_threads ) {
249
+ opal_common_ucx_support_level_t support_level ;
250
+ int ret = OMPI_SUCCESS ;
251
+
211
252
mca_osc_ucx_component .enable_mpi_threads = enable_mpi_threads ;
212
253
mca_osc_ucx_component .wpool = opal_common_ucx_wpool_allocate ();
213
254
opal_common_ucx_mca_register ();
255
+
256
+ ret = ucp_context_init (enable_mpi_threads , ompi_proc_world_size ());
257
+ if (OMPI_ERROR == ret ) {
258
+ return OMPI_ERR_NOT_AVAILABLE ;
259
+ }
260
+
261
+ support_level = opal_common_ucx_support_level (mca_osc_ucx_component .wpool -> ucp_ctx );
262
+ if (OPAL_COMMON_UCX_SUPPORT_NONE == support_level ) {
263
+ ucp_cleanup (mca_osc_ucx_component .wpool -> ucp_ctx );
264
+ mca_osc_ucx_component .wpool -> ucp_ctx = NULL ;
265
+ return OMPI_ERR_NOT_AVAILABLE ;
266
+ }
267
+
268
+ /*
269
+ * Retain priority if we have supported devices and transports.
270
+ * Lower priority if we have supported transports, but not supported devices.
271
+ */
272
+ mca_osc_ucx_component .priority = (support_level == OPAL_COMMON_UCX_SUPPORT_DEVICE ) ?
273
+ mca_osc_ucx_component .priority : 19 ;
274
+ OSC_UCX_VERBOSE (2 , "returning priority %d" , mca_osc_ucx_component .priority );
214
275
return OMPI_SUCCESS ;
215
276
}
216
277
@@ -395,9 +456,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
395
456
goto select_unlock ;
396
457
}
397
458
398
- ret = opal_common_ucx_wpool_init (mca_osc_ucx_component .wpool ,
399
- ompi_proc_world_size (),
400
- mca_osc_ucx_component .enable_mpi_threads );
459
+ ret = opal_common_ucx_wpool_init (mca_osc_ucx_component .wpool );
401
460
if (OMPI_SUCCESS != ret ) {
402
461
OSC_UCX_VERBOSE (1 , "opal_common_ucx_wpool_init failed: %d" , ret );
403
462
goto select_unlock ;
0 commit comments