@@ -29,6 +29,11 @@ typedef struct ucx_iovec {
29
29
size_t len ;
30
30
} ucx_iovec_t ;
31
31
32
+ OBJ_CLASS_INSTANCE (thread_local_info_t , opal_list_item_t , NULL , NULL ); /* class instance for thread_local_info_t with opal_list_item_t as second argument; the last two arguments are NULL (presumably the constructor/destructor slots of OPAL's macro -- TODO confirm) */
33
+
34
+ __thread thread_local_info_t * my_thread_info = NULL ; /* thread-local pointer, NULL until assigned; not referenced by the put/get code visible in this diff */
35
+ pthread_key_t my_thread_key = {0 }; /* key that ompi_osc_ucx_put()/ompi_osc_ucx_get() pass to pthread_getspecific() to fetch the calling thread's thread_local_info_t (its rkeys[] and eps[]); NOTE(review): where the key is created and the value set is not shown here -- presumably in opal_common_ucx_create_local_worker(), verify */
36
+
32
37
static inline int check_sync_state (ompi_osc_ucx_module_t * module , int target ,
33
38
bool is_req_ops ) {
34
39
if (is_req_ops == false) {
@@ -367,19 +372,42 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
367
372
int target , ptrdiff_t target_disp , int target_count ,
368
373
struct ompi_datatype_t * target_dt , struct ompi_win_t * win ) {
369
374
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
370
- ucp_ep_h ep = OSC_UCX_GET_EP ( module -> comm , target ) ;
375
+ ucp_ep_h ep ;
371
376
uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
372
377
ucp_rkey_h rkey ;
373
378
bool is_origin_contig = false, is_target_contig = false;
374
379
ptrdiff_t origin_lb , origin_extent , target_lb , target_extent ;
375
380
ucs_status_t status ;
381
+ pthread_t tid = pthread_self ();
376
382
int ret = OMPI_SUCCESS ;
377
383
378
384
ret = check_sync_state (module , target , false);
379
385
if (ret != OMPI_SUCCESS ) {
380
386
return ret ;
381
387
}
382
388
389
+ if (pthread_equal (tid , mca_osc_ucx_component .main_tid )) {
390
+ ep = OSC_UCX_GET_EP (module -> comm , target );
391
+ rkey = (module -> win_info_array [target ]).rkey ;
392
+ } else {
393
+ thread_local_info_t * curr_thread_info ;
394
+ if ((curr_thread_info = pthread_getspecific (my_thread_key )) == NULL ) {
395
+ ret = opal_common_ucx_create_local_worker (mca_osc_ucx_component .ucp_context ,
396
+ ompi_comm_size (module -> comm ),
397
+ mca_osc_ucx_component .worker_addr_buf ,
398
+ mca_osc_ucx_component .worker_addr_disps ,
399
+ mca_osc_ucx_component .mem_addr_buf ,
400
+ mca_osc_ucx_component .mem_addr_disps );
401
+ if (ret != OMPI_SUCCESS ) {
402
+ return ret ;
403
+ }
404
+ }
405
+
406
+ curr_thread_info = pthread_getspecific (my_thread_key );
407
+ rkey = curr_thread_info -> rkeys [target ];
408
+ ep = curr_thread_info -> eps [target ];
409
+ }
410
+
383
411
if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
384
412
status = get_dynamic_win_info (remote_addr , module , ep , target );
385
413
if (status != UCS_OK ) {
@@ -393,8 +421,6 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
393
421
return OMPI_SUCCESS ;
394
422
}
395
423
396
- rkey = (module -> win_info_array [target ]).rkey ;
397
-
398
424
ompi_datatype_get_true_extent (origin_dt , & origin_lb , & origin_extent );
399
425
ompi_datatype_get_true_extent (target_dt , & target_lb , & target_extent );
400
426
@@ -427,19 +453,42 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
427
453
int target , ptrdiff_t target_disp , int target_count ,
428
454
struct ompi_datatype_t * target_dt , struct ompi_win_t * win ) {
429
455
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
430
- ucp_ep_h ep = OSC_UCX_GET_EP ( module -> comm , target ) ;
456
+ ucp_ep_h ep ;
431
457
uint64_t remote_addr = (module -> win_info_array [target ]).addr + target_disp * OSC_UCX_GET_DISP (module , target );
432
458
ucp_rkey_h rkey ;
433
459
ptrdiff_t origin_lb , origin_extent , target_lb , target_extent ;
434
460
bool is_origin_contig = false, is_target_contig = false;
435
461
ucs_status_t status ;
462
+ pthread_t tid = pthread_self ();
436
463
int ret = OMPI_SUCCESS ;
437
464
438
465
ret = check_sync_state (module , target , false);
439
466
if (ret != OMPI_SUCCESS ) {
440
467
return ret ;
441
468
}
442
469
470
+ if (pthread_equal (tid , mca_osc_ucx_component .main_tid )) {
471
+ ep = OSC_UCX_GET_EP (module -> comm , target );
472
+ rkey = (module -> win_info_array [target ]).rkey ;
473
+ } else {
474
+ thread_local_info_t * curr_thread_info ;
475
+ if ((curr_thread_info = pthread_getspecific (my_thread_key )) == NULL ) {
476
+ ret = opal_common_ucx_create_local_worker (mca_osc_ucx_component .ucp_context ,
477
+ ompi_comm_size (module -> comm ),
478
+ mca_osc_ucx_component .worker_addr_buf ,
479
+ mca_osc_ucx_component .worker_addr_disps ,
480
+ mca_osc_ucx_component .mem_addr_buf ,
481
+ mca_osc_ucx_component .mem_addr_disps );
482
+ if (ret != OMPI_SUCCESS ) {
483
+ return ret ;
484
+ }
485
+ }
486
+
487
+ curr_thread_info = pthread_getspecific (my_thread_key );
488
+ rkey = curr_thread_info -> rkeys [target ];
489
+ ep = curr_thread_info -> eps [target ];
490
+ }
491
+
443
492
if (module -> flavor == MPI_WIN_FLAVOR_DYNAMIC ) {
444
493
status = get_dynamic_win_info (remote_addr , module , ep , target );
445
494
if (status != UCS_OK ) {
@@ -453,8 +502,6 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
453
502
return OMPI_SUCCESS ;
454
503
}
455
504
456
- rkey = (module -> win_info_array [target ]).rkey ;
457
-
458
505
ompi_datatype_get_true_extent (origin_dt , & origin_lb , & origin_extent );
459
506
ompi_datatype_get_true_extent (target_dt , & target_lb , & target_extent );
460
507
0 commit comments