16
16
17
17
#include "opal/runtime/opal.h"
18
18
#include "opal/mca/pmix/pmix.h"
19
+ #include "ompi/attribute/attribute.h"
19
20
#include "ompi/message/message.h"
20
21
#include "ompi/mca/pml/base/pml_base_bsend.h"
21
22
#include "opal/mca/common/ucx/common_ucx.h"
@@ -190,9 +191,9 @@ int mca_pml_ucx_close(void)
190
191
int mca_pml_ucx_init (void )
191
192
{
192
193
ucp_worker_params_t params ;
193
- ucs_status_t status ;
194
194
ucp_worker_attr_t attr ;
195
- int rc ;
195
+ ucs_status_t status ;
196
+ int i , rc ;
196
197
197
198
PML_UCX_VERBOSE (1 , "mca_pml_ucx_init" );
198
199
@@ -209,30 +210,34 @@ int mca_pml_ucx_init(void)
209
210
& ompi_pml_ucx .ucp_worker );
210
211
if (UCS_OK != status ) {
211
212
PML_UCX_ERROR ("Failed to create UCP worker" );
212
- return OMPI_ERROR ;
213
+ rc = OMPI_ERROR ;
214
+ goto err ;
213
215
}
214
216
215
217
attr .field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE ;
216
218
status = ucp_worker_query (ompi_pml_ucx .ucp_worker , & attr );
217
219
if (UCS_OK != status ) {
218
- ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
219
- ompi_pml_ucx .ucp_worker = NULL ;
220
220
PML_UCX_ERROR ("Failed to query UCP worker thread level" );
221
- return OMPI_ERROR ;
221
+ rc = OMPI_ERROR ;
222
+ goto err_destroy_worker ;
222
223
}
223
224
224
- if (ompi_mpi_thread_multiple && attr .thread_mode != UCS_THREAD_MODE_MULTI ) {
225
+ if (ompi_mpi_thread_multiple && ( attr .thread_mode != UCS_THREAD_MODE_MULTI ) ) {
225
226
/* UCX does not support multithreading, disqualify current PML for now */
226
227
/* TODO: we should let OMPI to fallback to THREAD_SINGLE mode */
227
- ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
228
- ompi_pml_ucx .ucp_worker = NULL ;
229
228
PML_UCX_ERROR ("UCP worker does not support MPI_THREAD_MULTIPLE" );
230
- return OMPI_ERROR ;
229
+ rc = OMPI_ERR_NOT_SUPPORTED ;
230
+ goto err_destroy_worker ;
231
231
}
232
232
233
233
rc = mca_pml_ucx_send_worker_address ();
234
234
if (rc < 0 ) {
235
- return rc ;
235
+ goto err_destroy_worker ;
236
+ }
237
+
238
+ ompi_pml_ucx .datatype_attr_keyval = MPI_KEYVAL_INVALID ;
239
+ for (i = 0 ; i < OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
240
+ ompi_pml_ucx .predefined_types [i ] = PML_UCX_DATATYPE_INVALID ;
236
241
}
237
242
238
243
/* Initialize the free lists */
@@ -249,14 +254,33 @@ int mca_pml_ucx_init(void)
249
254
(void * )ompi_pml_ucx .ucp_context ,
250
255
(void * )ompi_pml_ucx .ucp_worker );
251
256
return OMPI_SUCCESS ;
257
+
258
+ err_destroy_worker :
259
+ ucp_worker_destroy (ompi_pml_ucx .ucp_worker );
260
+ ompi_pml_ucx .ucp_worker = NULL ;
261
+ err :
262
+ return OMPI_ERROR ;
252
263
}
253
264
254
265
int mca_pml_ucx_cleanup (void )
255
266
{
267
+ int i ;
268
+
256
269
PML_UCX_VERBOSE (1 , "mca_pml_ucx_cleanup" );
257
270
258
271
opal_progress_unregister (mca_pml_ucx_progress );
259
272
273
+ if (ompi_pml_ucx .datatype_attr_keyval != MPI_KEYVAL_INVALID ) {
274
+ ompi_attr_free_keyval (TYPE_ATTR , & ompi_pml_ucx .datatype_attr_keyval , false);
275
+ }
276
+
277
+ for (i = 0 ; i < OMPI_DATATYPE_MAX_PREDEFINED ; ++ i ) {
278
+ if (ompi_pml_ucx .predefined_types [i ] != PML_UCX_DATATYPE_INVALID ) {
279
+ ucp_dt_destroy (ompi_pml_ucx .predefined_types [i ]);
280
+ ompi_pml_ucx .predefined_types [i ] = PML_UCX_DATATYPE_INVALID ;
281
+ }
282
+ }
283
+
260
284
ompi_pml_ucx .completed_send_req .req_state = OMPI_REQUEST_INVALID ;
261
285
OMPI_REQUEST_FINI (& ompi_pml_ucx .completed_send_req );
262
286
OBJ_DESTRUCT (& ompi_pml_ucx .completed_send_req );
@@ -398,6 +422,22 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
398
422
399
423
int mca_pml_ucx_enable (bool enable )
400
424
{
425
+ ompi_attribute_fn_ptr_union_t copy_fn ;
426
+ ompi_attribute_fn_ptr_union_t del_fn ;
427
+ int ret ;
428
+
429
+ /* Create a key for adding custom attributes to datatypes */
430
+ copy_fn .attr_datatype_copy_fn =
431
+ (MPI_Type_internal_copy_attr_function * )MPI_TYPE_NULL_COPY_FN ;
432
+ del_fn .attr_datatype_delete_fn = mca_pml_ucx_datatype_attr_del_fn ;
433
+ ret = ompi_attr_create_keyval (TYPE_ATTR , copy_fn , del_fn ,
434
+ & ompi_pml_ucx .datatype_attr_keyval , NULL , 0 ,
435
+ NULL );
436
+ if (ret != OMPI_SUCCESS ) {
437
+ PML_UCX_ERROR ("Failed to create keyval for UCX datatypes: %d" , ret );
438
+ return ret ;
439
+ }
440
+
401
441
PML_UCX_FREELIST_INIT (& ompi_pml_ucx .persistent_reqs ,
402
442
mca_pml_ucx_persistent_request_t ,
403
443
128 , -1 , 128 );
0 commit comments