@@ -103,7 +103,7 @@ int mca_spml_ucx_enable(bool enable)
103
103
int mca_spml_ucx_del_procs (ompi_proc_t * * procs , size_t nprocs )
104
104
{
105
105
opal_common_ucx_del_proc_t * del_procs ;
106
- size_t i , j ;
106
+ size_t i ;
107
107
int ret ;
108
108
109
109
oshmem_shmem_barrier ();
@@ -118,12 +118,6 @@ int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs)
118
118
}
119
119
120
120
for (i = 0 ; i < nprocs ; ++ i ) {
121
- for (j = 0 ; j < MCA_MEMHEAP_SEG_COUNT ; j ++ ) {
122
- if (mca_spml_ucx_ctx_default .ucp_peers [i ].mkeys [j ].key .rkey != NULL ) {
123
- ucp_rkey_destroy (mca_spml_ucx_ctx_default .ucp_peers [i ].mkeys [j ].key .rkey );
124
- }
125
- }
126
-
127
121
del_procs [i ].ep = mca_spml_ucx_ctx_default .ucp_peers [i ].ucp_conn ;
128
122
del_procs [i ].vpid = i ;
129
123
@@ -349,16 +343,21 @@ spml_ucx_mkey_t * mca_spml_ucx_get_mkey_slow(shmem_ctx_t ctx, int pe, void *va,
349
343
350
344
void mca_spml_ucx_rmkey_free (sshmem_mkey_t * mkey )
351
345
{
346
+ spml_ucx_mkey_t * ucx_mkey ;
347
+
348
+ if (!mkey -> spml_context ) {
349
+ return ;
350
+ }
351
+ ucx_mkey = (spml_ucx_mkey_t * )(mkey -> spml_context );
352
+ ucp_rkey_destroy (ucx_mkey -> rkey );
352
353
}
353
354
354
- void * mca_spml_ucx_rmkey_ptr (const void * dst_addr , sshmem_mkey_t * key , int pe )
355
+ void * mca_spml_ucx_rmkey_ptr (const void * dst_addr , sshmem_mkey_t * mkey , int pe )
355
356
{
356
357
#if (((UCP_API_MAJOR >= 1 ) && (UCP_API_MINOR >= 3 )) || (UCP_API_MAJOR >= 2 ))
357
358
void * rva ;
358
359
ucs_status_t err ;
359
- mca_spml_ucx_ctx_t * ucx_ctx = (mca_spml_ucx_ctx_t * )& mca_spml_ucx_ctx_default ;
360
- uint32_t segno = memheap_find_segnum ((void * )dst_addr );
361
- spml_ucx_mkey_t * ucx_mkey = & ucx_ctx -> ucp_peers [pe ].mkeys [segno ].key ;
360
+ spml_ucx_mkey_t * ucx_mkey = (spml_ucx_mkey_t * )(mkey -> spml_context );
362
361
363
362
err = ucp_rkey_ptr (ucx_mkey -> rkey , (uint64_t )dst_addr , & rva );
364
363
if (UCS_OK != err ) {
@@ -386,6 +385,9 @@ void mca_spml_ucx_rmkey_unpack(shmem_ctx_t ctx, sshmem_mkey_t *mkey, uint32_t se
386
385
goto error_fatal ;
387
386
}
388
387
388
+ if (ucx_ctx == & mca_spml_ucx_ctx_default ) {
389
+ mkey -> spml_context = ucx_mkey ;
390
+ }
389
391
mca_spml_ucx_cache_mkey (ucx_ctx , mkey , segno , pe );
390
392
return ;
391
393
@@ -448,6 +450,7 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
448
450
mem_seg = memheap_find_seg (segno );
449
451
450
452
ucx_mkey = & mca_spml_ucx_ctx_default .ucp_peers [my_pe ].mkeys [segno ].key ;
453
+ mkeys [0 ].spml_context = ucx_mkey ;
451
454
452
455
/* if possible use mem handle already created by ucx allocator */
453
456
if (MAP_SEGMENT_ALLOC_UCX != mem_seg -> type ) {
@@ -514,12 +517,14 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
514
517
int my_pe = oshmem_my_proc_id ();
515
518
516
519
MCA_SPML_CALL (quiet (oshmem_ctx_default ));
517
- if (!mkeys || !mkeys [0 ].va_base )
520
+ if (!mkeys )
521
+ return OSHMEM_SUCCESS ;
522
+
523
+ if (!mkeys [0 ].spml_context )
518
524
return OSHMEM_SUCCESS ;
519
525
520
526
mem_seg = memheap_find_va (mkeys [0 ].va_base );
521
- segno = memheap_find_segnum (mkeys [0 ].va_base );
522
- ucx_mkey = & mca_spml_ucx_ctx_default .ucp_peers [my_pe ].mkeys [segno ].key ;
527
+ ucx_mkey = (spml_ucx_mkey_t * )mkeys [0 ].spml_context ;
523
528
524
529
if (OPAL_UNLIKELY (NULL == mem_seg )) {
525
530
return OSHMEM_ERROR ;
0 commit comments