Skip to content

Commit cf66f98

Browse files
committed
Modify OSC UCX multi-threading (MT) code
1 parent efe72d3 commit cf66f98

File tree

5 files changed

+277
-43
lines changed

5 files changed

+277
-43
lines changed

ompi/mca/osc/ucx/osc_ucx.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,18 @@ typedef struct ompi_osc_ucx_component {
3434
ompi_osc_base_component_t super;
3535
ucp_context_h ucp_context;
3636
ucp_worker_h ucp_worker;
37+
pthread_mutex_t worker_mutex;
3738
bool enable_mpi_threads;
3839
opal_free_list_t requests; /* request free list for the r* communication variants */
3940
bool env_initialized; /* UCX environment is initialized or not */
4041
int num_incomplete_req_ops;
4142
int num_modules;
4243
unsigned int priority;
44+
char *worker_addr_buf;
45+
int *worker_addr_disps;
46+
char *mem_addr_buf;
47+
int *mem_addr_disps;
48+
pthread_t main_tid;
4349
} ompi_osc_ucx_component_t;
4450

4551
OMPI_DECLSPEC extern ompi_osc_ucx_component_t mca_osc_ucx_component;

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ typedef struct ucx_iovec {
2929
size_t len;
3030
} ucx_iovec_t;
3131

32+
OBJ_CLASS_INSTANCE(thread_local_info_t, opal_list_item_t, NULL, NULL);
33+
34+
__thread thread_local_info_t *my_thread_info = NULL;
35+
pthread_key_t my_thread_key = {0};
36+
3237
static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target,
3338
bool is_req_ops) {
3439
if (is_req_ops == false) {
@@ -367,19 +372,42 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
367372
int target, ptrdiff_t target_disp, int target_count,
368373
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
369374
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
370-
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
375+
ucp_ep_h ep;
371376
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
372377
ucp_rkey_h rkey;
373378
bool is_origin_contig = false, is_target_contig = false;
374379
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
375380
ucs_status_t status;
381+
pthread_t tid = pthread_self();
376382
int ret = OMPI_SUCCESS;
377383

378384
ret = check_sync_state(module, target, false);
379385
if (ret != OMPI_SUCCESS) {
380386
return ret;
381387
}
382388

389+
if (pthread_equal(tid, mca_osc_ucx_component.main_tid)) {
390+
ep = OSC_UCX_GET_EP(module->comm, target);
391+
rkey = (module->win_info_array[target]).rkey;
392+
} else {
393+
thread_local_info_t *curr_thread_info;
394+
if ((curr_thread_info = pthread_getspecific(my_thread_key)) == NULL) {
395+
ret = opal_common_ucx_create_local_worker(mca_osc_ucx_component.ucp_context,
396+
ompi_comm_size(module->comm),
397+
mca_osc_ucx_component.worker_addr_buf,
398+
mca_osc_ucx_component.worker_addr_disps,
399+
mca_osc_ucx_component.mem_addr_buf,
400+
mca_osc_ucx_component.mem_addr_disps);
401+
if (ret != OMPI_SUCCESS) {
402+
return ret;
403+
}
404+
}
405+
406+
curr_thread_info = pthread_getspecific(my_thread_key);
407+
rkey = curr_thread_info->rkeys[target];
408+
ep = curr_thread_info->eps[target];
409+
}
410+
383411
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
384412
status = get_dynamic_win_info(remote_addr, module, ep, target);
385413
if (status != UCS_OK) {
@@ -393,8 +421,6 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
393421
return OMPI_SUCCESS;
394422
}
395423

396-
rkey = (module->win_info_array[target]).rkey;
397-
398424
ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
399425
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
400426

@@ -427,19 +453,42 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
427453
int target, ptrdiff_t target_disp, int target_count,
428454
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
429455
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
430-
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
456+
ucp_ep_h ep;
431457
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
432458
ucp_rkey_h rkey;
433459
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
434460
bool is_origin_contig = false, is_target_contig = false;
435461
ucs_status_t status;
462+
pthread_t tid = pthread_self();
436463
int ret = OMPI_SUCCESS;
437464

438465
ret = check_sync_state(module, target, false);
439466
if (ret != OMPI_SUCCESS) {
440467
return ret;
441468
}
442469

470+
if (pthread_equal(tid, mca_osc_ucx_component.main_tid)) {
471+
ep = OSC_UCX_GET_EP(module->comm, target);
472+
rkey = (module->win_info_array[target]).rkey;
473+
} else {
474+
thread_local_info_t *curr_thread_info;
475+
if ((curr_thread_info = pthread_getspecific(my_thread_key)) == NULL) {
476+
ret = opal_common_ucx_create_local_worker(mca_osc_ucx_component.ucp_context,
477+
ompi_comm_size(module->comm),
478+
mca_osc_ucx_component.worker_addr_buf,
479+
mca_osc_ucx_component.worker_addr_disps,
480+
mca_osc_ucx_component.mem_addr_buf,
481+
mca_osc_ucx_component.mem_addr_disps);
482+
if (ret != OMPI_SUCCESS) {
483+
return ret;
484+
}
485+
}
486+
487+
curr_thread_info = pthread_getspecific(my_thread_key);
488+
rkey = curr_thread_info->rkeys[target];
489+
ep = curr_thread_info->eps[target];
490+
}
491+
443492
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
444493
status = get_dynamic_win_info(remote_addr, module, ep, target);
445494
if (status != UCS_OK) {
@@ -453,8 +502,6 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
453502
return OMPI_SUCCESS;
454503
}
455504

456-
rkey = (module->win_info_array[target]).rkey;
457-
458505
ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
459506
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
460507

0 commit comments

Comments
 (0)