diff --git a/contrib/platform/mellanox/optimized b/contrib/platform/mellanox/optimized index f49a0576c64..339518e483e 100644 --- a/contrib/platform/mellanox/optimized +++ b/contrib/platform/mellanox/optimized @@ -1,4 +1,4 @@ -enable_mca_no_build=coll-ml,btl-uct +enable_mca_no_build=coll-ml enable_debug_symbols=yes enable_orterun_prefix_by_default=yes with_verbs=no diff --git a/ompi/mca/osc/ucx/osc_ucx.h b/ompi/mca/osc/ucx/osc_ucx.h index 44dff95a845..5df4d6e8981 100644 --- a/ompi/mca/osc/ucx/osc_ucx.h +++ b/ompi/mca/osc/ucx/osc_ucx.h @@ -40,6 +40,12 @@ typedef struct ompi_osc_ucx_component { int num_incomplete_req_ops; int num_modules; unsigned int priority; + pthread_mutex_t worker_mutex; + char *worker_addr_buf; + int *worker_addr_disps; + char *mem_addr_buf; + int *mem_addr_disps; + pthread_t main_tid; } ompi_osc_ucx_component_t; OMPI_DECLSPEC extern ompi_osc_ucx_component_t mca_osc_ucx_component; diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c index ec760d4fda3..a604f6e94cf 100644 --- a/ompi/mca/osc/ucx/osc_ucx_comm.c +++ b/ompi/mca/osc/ucx/osc_ucx_comm.c @@ -29,6 +29,10 @@ typedef struct ucx_iovec { size_t len; } ucx_iovec_t; +OBJ_CLASS_INSTANCE(thread_local_info_t, opal_list_item_t, NULL, NULL); + +pthread_key_t my_thread_key = {0}; + static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target, bool is_req_ops) { if (is_req_ops == false) { @@ -83,6 +87,40 @@ static inline int incr_and_check_ops_num(ompi_osc_ucx_module_t *module, int targ return OMPI_SUCCESS; } +static inline int get_osc_metadata(ompi_osc_ucx_module_t *module, int target, + ucp_ep_h *ep_ptr, ucp_rkey_h *rkey_ptr, + pthread_mutex_t **mutex_pptr) { + pthread_t tid = pthread_self(); + int ret = OMPI_SUCCESS; + + if (pthread_equal(tid, mca_osc_ucx_component.main_tid)) { + (*ep_ptr) = OSC_UCX_GET_EP(module->comm, target); + (*rkey_ptr) = (module->win_info_array[target]).rkey; + (*mutex_pptr) = &mca_osc_ucx_component.worker_mutex; + } else { + 
thread_local_info_t *curr_thread_info = NULL; + if ((curr_thread_info = pthread_getspecific(my_thread_key)) == NULL) { + ret = opal_common_ucx_create_local_worker(mca_osc_ucx_component.ucp_context, + ompi_comm_size(module->comm), + mca_osc_ucx_component.worker_addr_buf, + mca_osc_ucx_component.worker_addr_disps, + mca_osc_ucx_component.mem_addr_buf, + mca_osc_ucx_component.mem_addr_disps); + if (ret != OMPI_SUCCESS) { + return ret; + } + curr_thread_info = pthread_getspecific(my_thread_key); + } + + assert(curr_thread_info != NULL); + (*rkey_ptr) = curr_thread_info->rkeys[target]; + (*ep_ptr) = curr_thread_info->eps[target]; + (*mutex_pptr) = &curr_thread_info->lock; + } + + return ret; +} + static inline int create_iov_list(const void *addr, int count, ompi_datatype_t *datatype, ucx_iovec_t **ucx_iov, uint32_t *ucx_iov_count) { int ret = OMPI_SUCCESS; @@ -139,7 +177,8 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, bool is_origin_contig, ptrdiff_t origin_lb, int target, ucp_ep_h ep, uint64_t remote_addr, ucp_rkey_h rkey, int target_count, struct ompi_datatype_t *target_dt, - bool is_target_contig, ptrdiff_t target_lb, bool is_get) { + bool is_target_contig, ptrdiff_t target_lb, bool is_get, + pthread_mutex_t *mutex_ptr) { ucx_iovec_t *origin_ucx_iov = NULL, *target_ucx_iov = NULL; uint32_t origin_ucx_iov_count = 0, target_ucx_iov_count = 0; uint32_t origin_ucx_iov_idx = 0, target_ucx_iov_idx = 0; @@ -168,6 +207,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, curr_len = MIN(origin_ucx_iov[origin_ucx_iov_idx].len, target_ucx_iov[target_ucx_iov_idx].len); + pthread_mutex_lock(mutex_ptr); if (!is_get) { status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len, remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), rkey); @@ -183,6 +223,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, return OMPI_ERROR; } } + pthread_mutex_unlock(mutex_ptr); ret = incr_and_check_ops_num(module, target, 
ep); if (ret != OMPI_SUCCESS) { @@ -208,6 +249,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, } else if (!is_origin_contig) { size_t prev_len = 0; while (origin_ucx_iov_idx < origin_ucx_iov_count) { + pthread_mutex_lock(mutex_ptr); if (!is_get) { status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, origin_ucx_iov[origin_ucx_iov_idx].len, @@ -225,6 +267,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, return OMPI_ERROR; } } + pthread_mutex_unlock(mutex_ptr); ret = incr_and_check_ops_num(module, target, ep); if (ret != OMPI_SUCCESS) { @@ -237,6 +280,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, } else { size_t prev_len = 0; while (target_ucx_iov_idx < target_ucx_iov_count) { + pthread_mutex_lock(mutex_ptr); if (!is_get) { status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb + prev_len), target_ucx_iov[target_ucx_iov_idx].len, @@ -254,6 +298,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, return OMPI_ERROR; } } + pthread_mutex_unlock(mutex_ptr); ret = incr_and_check_ops_num(module, target, ep); if (ret != OMPI_SUCCESS) { @@ -367,12 +412,13 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data int target, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); + ucp_ep_h ep; uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target); ucp_rkey_h rkey; bool is_origin_contig = false, is_target_contig = false; ptrdiff_t origin_lb, origin_extent, target_lb, target_extent; ucs_status_t status; + pthread_mutex_t *mutex_ptr = NULL; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, false); @@ -380,6 +426,11 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data return ret; } + ret = 
get_osc_metadata(module, target, &ep, &rkey, &mutex_ptr);
+    if (ret != OMPI_SUCCESS) {
+        return ret;
+    }
+
     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
         status = get_dynamic_win_info(remote_addr, module, ep, target);
         if (status != UCS_OK) {
@@ -393,8 +444,6 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
         return OMPI_SUCCESS;
     }
 
-    rkey = (module->win_info_array[target]).rkey;
-
     ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
     ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
 
@@ -408,17 +457,19 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
         ompi_datatype_type_size(origin_dt, &origin_len);
         origin_len *= origin_count;
 
+        pthread_mutex_lock(mutex_ptr);
         status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb),
                              origin_len, remote_addr + target_lb, rkey);
+        pthread_mutex_unlock(mutex_ptr);
         if (status != UCS_OK && status != UCS_INPROGRESS) {
             OSC_UCX_VERBOSE(1, "ucp_put_nbi failed: %d", status);
             return OMPI_ERROR;
         }
 
         return incr_and_check_ops_num(module, target, ep);
     } else {
         return ddt_put_get(module, origin_addr, origin_count, origin_dt,
                            is_origin_contig, origin_lb, target, ep, remote_addr,
                            rkey, target_count, target_dt,
-                           is_target_contig, target_lb, false);
+                           is_target_contig, target_lb, false, mutex_ptr);
     }
 }
 
@@ -427,12 +478,13 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
                      int target, ptrdiff_t target_disp, int target_count,
                      struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
     ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
-    ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
+    ucp_ep_h ep;
     uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp *
         OSC_UCX_GET_DISP(module, target);
     ucp_rkey_h rkey;
     ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
     bool is_origin_contig = false, is_target_contig = false;
     ucs_status_t status;
+    pthread_mutex_t *mutex_ptr = NULL;
     int
ret = OMPI_SUCCESS;
 
     ret = check_sync_state(module, target, false);
@@ -440,6 +492,11 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
         return ret;
     }
 
+    ret = get_osc_metadata(module, target, &ep, &rkey, &mutex_ptr);
+    if (ret != OMPI_SUCCESS) {
+        return ret;
+    }
+
     if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
         status = get_dynamic_win_info(remote_addr, module, ep, target);
         if (status != UCS_OK) {
@@ -453,8 +510,6 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
         return OMPI_SUCCESS;
     }
 
-    rkey = (module->win_info_array[target]).rkey;
-
     ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
     ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);
 
@@ -468,18 +523,20 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
         ompi_datatype_type_size(origin_dt, &origin_len);
         origin_len *= origin_count;
 
+        pthread_mutex_lock(mutex_ptr);
         status = ucp_get_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb),
                              origin_len, remote_addr + target_lb, rkey);
+        pthread_mutex_unlock(mutex_ptr);
         if (status != UCS_OK && status != UCS_INPROGRESS) {
             OSC_UCX_VERBOSE(1, "ucp_get_nbi failed: %d", status);
             return OMPI_ERROR;
         }
 
         return incr_and_check_ops_num(module, target, ep);
     } else {
         return ddt_put_get(module, origin_addr, origin_count, origin_dt,
                            is_origin_contig, origin_lb, target, ep, remote_addr,
                            rkey, target_count, target_dt,
-                           is_target_contig, target_lb, true);
+                           is_target_contig, target_lb, true, mutex_ptr);
     }
 }
 
diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c
index 6fd3291bad0..5715cdfacd4 100644
--- a/ompi/mca/osc/ucx/osc_ucx_component.c
+++ b/ompi/mca/osc/ucx/osc_ucx_component.c
@@ -35,6 +35,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
                             int flavor, int *model);
 static void ompi_osc_ucx_unregister_progress(void);
 
+opal_list_t active_workers = {{0}}, idle_workers = {{0}};
+pthread_mutex_t active_workers_mutex = PTHREAD_MUTEX_INITIALIZER;
+pthread_mutex_t idle_workers_mutex = PTHREAD_MUTEX_INITIALIZER; + ompi_osc_ucx_component_t mca_osc_ucx_component = { { /* ompi_osc_base_component_t */ .osc_version = { @@ -58,7 +62,12 @@ ompi_osc_ucx_component_t mca_osc_ucx_component = { .ucp_worker = NULL, .env_initialized = false, .num_incomplete_req_ops = 0, - .num_modules = 0 + .num_modules = 0, + .worker_mutex = PTHREAD_MUTEX_INITIALIZER, + .worker_addr_buf = NULL, + .worker_addr_disps = NULL, + .mem_addr_buf = NULL, + .mem_addr_disps = NULL }; ompi_osc_ucx_module_t ompi_osc_ucx_module_template = { @@ -126,13 +135,84 @@ static int progress_callback(void) { static int component_init(bool enable_progress_threads, bool enable_mpi_threads) { mca_osc_ucx_component.enable_mpi_threads = enable_mpi_threads; + mca_osc_ucx_component.main_tid = pthread_self(); + + OBJ_CONSTRUCT(&active_workers, opal_list_t); + OBJ_CONSTRUCT(&idle_workers, opal_list_t); + + pthread_mutex_init(&active_workers_mutex, NULL); + pthread_mutex_init(&idle_workers_mutex, NULL); + pthread_mutex_init(&mca_osc_ucx_component.worker_mutex, NULL); + + pthread_key_create(&my_thread_key, opal_common_ucx_cleanup_local_worker); opal_common_ucx_mca_register(); return OMPI_SUCCESS; } +static void cleanup_thread_local_info(thread_local_info_t *curr_worker) { + int i; + + if (curr_worker->rkeys != NULL) { + for (i = 0; i < curr_worker->comm_size; i++) { + if (curr_worker->rkeys[i] != NULL) { + ucp_rkey_destroy(curr_worker->rkeys[i]); + } + } + free(curr_worker->rkeys); + } + + if (curr_worker->eps != NULL) { + for (i = 0; i < curr_worker->comm_size; i++) { + if (curr_worker->eps[i] != NULL) { + ucp_ep_destroy(curr_worker->eps[i]); + } + } + free(curr_worker->eps); + } + + if (curr_worker->worker != NULL) { + ucp_worker_destroy(curr_worker->worker); + } + + pthread_mutex_destroy(&curr_worker->lock); +} + static int component_finalize(void) { int i; + + if (!opal_list_is_empty(&active_workers)) { + thread_local_info_t *curr_worker, *next; + 
OPAL_LIST_FOREACH_SAFE(curr_worker, next, &active_workers, thread_local_info_t) { + opal_list_remove_item(&active_workers, &curr_worker->super); + cleanup_thread_local_info(curr_worker); + } + } + OBJ_DESTRUCT(&active_workers); + + if (!opal_list_is_empty(&idle_workers)) { + thread_local_info_t *curr_worker, *next; + OPAL_LIST_FOREACH_SAFE(curr_worker, next, &idle_workers, thread_local_info_t) { + opal_list_remove_item(&idle_workers, &curr_worker->super); + cleanup_thread_local_info(curr_worker); + } + } + OBJ_DESTRUCT(&idle_workers); + + pthread_mutex_destroy(&active_workers_mutex); + pthread_mutex_destroy(&idle_workers_mutex); + pthread_mutex_destroy(&mca_osc_ucx_component.worker_mutex); + pthread_key_delete(my_thread_key); + + if (mca_osc_ucx_component.worker_addr_buf != NULL) + free(mca_osc_ucx_component.worker_addr_buf); + if (mca_osc_ucx_component.worker_addr_disps != NULL) + free(mca_osc_ucx_component.worker_addr_disps); + if (mca_osc_ucx_component.mem_addr_buf != NULL) + free(mca_osc_ucx_component.mem_addr_buf); + if (mca_osc_ucx_component.mem_addr_disps != NULL) + free(mca_osc_ucx_component.mem_addr_disps); + for (i = 0; i < ompi_proc_world_size(); i++) { ucp_ep_h ep = OSC_UCX_GET_EP(&(ompi_mpi_comm_world.comm), i); if (ep != NULL) { @@ -273,13 +353,11 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in bool eps_created = false, env_initialized = false; ucp_address_t *my_addr = NULL; size_t my_addr_len; - char *recv_buf = NULL; void *rkey_buffer = NULL, *state_rkey_buffer = NULL; size_t rkey_buffer_size, state_rkey_buffer_size; void *state_base = NULL; void * my_info = NULL; size_t my_info_len; - int disps[comm_size]; int rkey_sizes[comm_size]; uint64_t zero = 0; size_t info_offset; @@ -295,7 +373,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in ucp_config_t *config = NULL; ucp_params_t context_params; ucp_worker_params_t worker_params; - ucp_worker_attr_t worker_attr; status = 
ucp_config_read("MPI", NULL, &config); if (UCS_OK != status) { @@ -323,7 +400,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in UCP_PARAM_FIELD_REQUEST_INIT | UCP_PARAM_FIELD_REQUEST_SIZE; context_params.features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64; - context_params.mt_workers_shared = 0; + context_params.mt_workers_shared = 1; context_params.estimated_num_eps = ompi_proc_world_size(); context_params.request_init = internal_req_init; context_params.request_size = sizeof(ompi_osc_ucx_internal_request_t); @@ -339,8 +416,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in assert(mca_osc_ucx_component.ucp_worker == NULL); memset(&worker_params, 0, sizeof(worker_params)); worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = (mca_osc_ucx_component.enable_mpi_threads == true) - ? UCS_THREAD_MODE_MULTI : UCS_THREAD_MODE_SINGLE; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; status = ucp_worker_create(mca_osc_ucx_component.ucp_context, &worker_params, &(mca_osc_ucx_component.ucp_worker)); if (UCS_OK != status) { @@ -349,22 +425,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in goto error_nomem; } - /* query UCP worker attributes */ - worker_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; - status = ucp_worker_query(mca_osc_ucx_component.ucp_worker, &worker_attr); - if (UCS_OK != status) { - OSC_UCX_VERBOSE(1, "ucp_worker_query failed: %d", status); - ret = OMPI_ERROR; - goto error_nomem; - } - - if (mca_osc_ucx_component.enable_mpi_threads == true && - worker_attr.thread_mode != UCS_THREAD_MODE_MULTI) { - OSC_UCX_VERBOSE(1, "ucx does not support multithreading"); - ret = OMPI_ERROR; - goto error_nomem; - } - mca_osc_ucx_component.env_initialized = true; env_initialized = true; } @@ -442,6 +502,11 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in goto error; } + if 
(mca_osc_ucx_component.worker_addr_disps == NULL) + mca_osc_ucx_component.worker_addr_disps = malloc(comm_size * sizeof(int)); + if (mca_osc_ucx_component.mem_addr_disps == NULL) + mca_osc_ucx_component.mem_addr_disps = malloc(comm_size * sizeof(int)); + if (!is_eps_ready) { status = ucp_worker_get_address(mca_osc_ucx_component.ucp_worker, &my_addr, &my_addr_len); @@ -451,8 +516,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in goto error; } + assert(mca_osc_ucx_component.worker_addr_buf == NULL); ret = allgather_len_and_info(my_addr, (int)my_addr_len, - &recv_buf, disps, module->comm); + &(mca_osc_ucx_component.worker_addr_buf), + mca_osc_ucx_component.worker_addr_disps, module->comm); if (ret != OMPI_SUCCESS) { goto error; } @@ -461,9 +528,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in if (OSC_UCX_GET_EP(module->comm, i) == NULL) { ucp_ep_params_t ep_params; ucp_ep_h ep; + info_offset = mca_osc_ucx_component.worker_addr_disps[i]; memset(&ep_params, 0, sizeof(ucp_ep_params_t)); ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = (ucp_address_t *)&(recv_buf[disps[i]]); + ep_params.address = (ucp_address_t *)&(mca_osc_ucx_component.worker_addr_buf[info_offset]); status = ucp_ep_create(mca_osc_ucx_component.ucp_worker, &ep_params, &ep); if (status != UCS_OK) { OSC_UCX_VERBOSE(1, "ucp_ep_create failed: %d", status); @@ -477,9 +545,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in ucp_worker_release_address(mca_osc_ucx_component.ucp_worker, my_addr); my_addr = NULL; - free(recv_buf); - recv_buf = NULL; - eps_created = true; } @@ -549,7 +614,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in assert(my_info_len == info_offset); - ret = allgather_len_and_info(my_info, (int)my_info_len, &recv_buf, disps, module->comm); + assert(mca_osc_ucx_component.mem_addr_buf == NULL); + ret = 
allgather_len_and_info(my_info, (int)my_info_len, + &mca_osc_ucx_component.mem_addr_buf, + mca_osc_ucx_component.mem_addr_disps, module->comm); if (ret != OMPI_SUCCESS) { goto error; } @@ -566,18 +634,20 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in uint64_t dest_size; assert(ep != NULL); - info_offset = disps[i]; + info_offset = mca_osc_ucx_component.mem_addr_disps[i]; - memcpy(&(module->win_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t)); + memcpy(&(module->win_info_array[i]).addr, + &mca_osc_ucx_component.mem_addr_buf[info_offset], sizeof(uint64_t)); info_offset += sizeof(uint64_t); - memcpy(&(module->state_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t)); + memcpy(&(module->state_info_array[i]).addr, + &mca_osc_ucx_component.mem_addr_buf[info_offset], sizeof(uint64_t)); info_offset += sizeof(uint64_t); - memcpy(&dest_size, &recv_buf[info_offset], sizeof(uint64_t)); + memcpy(&dest_size, &mca_osc_ucx_component.mem_addr_buf[info_offset], sizeof(uint64_t)); info_offset += sizeof(uint64_t); (module->win_info_array[i]).rkey_init = false; if (dest_size > 0 && (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE)) { - status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset], + status = ucp_ep_rkey_unpack(ep, &mca_osc_ucx_component.mem_addr_buf[info_offset], &((module->win_info_array[i]).rkey)); if (status != UCS_OK) { OSC_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status); @@ -588,7 +658,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in (module->win_info_array[i]).rkey_init = true; } - status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset], + status = ucp_ep_rkey_unpack(ep, &mca_osc_ucx_component.mem_addr_buf[info_offset], &((module->state_info_array[i]).rkey)); if (status != UCS_OK) { OSC_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status); @@ -599,7 +669,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in } 
free(my_info);
-    free(recv_buf);
 
     if (rkey_buffer_size != 0) {
         ucp_rkey_buffer_release(rkey_buffer);
@@ -656,7 +725,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
 error:
     if (my_addr) ucp_worker_release_address(mca_osc_ucx_component.ucp_worker, my_addr);
-    if (recv_buf) free(recv_buf);
     if (my_info) free(my_info);
     for (i = 0; i < comm_size; i++) {
         if ((module->win_info_array[i]).rkey != NULL) {
diff --git a/ompi/mca/osc/ucx/osc_ucx_passive_target.c b/ompi/mca/osc/ucx/osc_ucx_passive_target.c
index 3a7ad3e9e24..2b8dc97c874 100644
--- a/ompi/mca/osc/ucx/osc_ucx_passive_target.c
+++ b/ompi/mca/osc/ucx/osc_ucx_passive_target.c
@@ -294,11 +294,29 @@ int ompi_osc_ucx_flush(int target, struct ompi_win_t *win) {
         return OMPI_ERR_RMA_SYNC;
     }
 
+    pthread_mutex_lock(&mca_osc_ucx_component.worker_mutex);
     ep = OSC_UCX_GET_EP(module->comm, target);
     ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker);
+    pthread_mutex_unlock(&mca_osc_ucx_component.worker_mutex);
     if (ret != OMPI_SUCCESS) {
         return ret;
     }
+
+
+    pthread_mutex_lock(&active_workers_mutex);
+    if (!opal_list_is_empty(&active_workers)) {
+        thread_local_info_t *curr_worker, *next;
+        OPAL_LIST_FOREACH_SAFE(curr_worker, next, &active_workers, thread_local_info_t) {
+            pthread_mutex_lock(&curr_worker->lock);
+            ep = curr_worker->eps[target];
+            ret = opal_common_ucx_ep_flush(ep, curr_worker->worker);
+            pthread_mutex_unlock(&curr_worker->lock);
+            if (ret != OMPI_SUCCESS) {
+                pthread_mutex_unlock(&active_workers_mutex); return ret;
+            }
+        }
+    }
+    pthread_mutex_unlock(&active_workers_mutex);
 
     module->global_ops_num -= module->per_target_ops_nums[target];
     module->per_target_ops_nums[target] = 0;
@@ -315,10 +333,26 @@ int ompi_osc_ucx_flush_all(struct ompi_win_t *win) {
         return OMPI_ERR_RMA_SYNC;
     }
 
+    pthread_mutex_lock(&mca_osc_ucx_component.worker_mutex);
     ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker);
+    pthread_mutex_unlock(&mca_osc_ucx_component.worker_mutex);
     if (ret != OMPI_SUCCESS) {
         return ret;
     }
+
pthread_mutex_lock(&active_workers_mutex);
+    if (!opal_list_is_empty(&active_workers)) {
+        thread_local_info_t *curr_worker, *next;
+        OPAL_LIST_FOREACH_SAFE(curr_worker, next, &active_workers, thread_local_info_t) {
+            pthread_mutex_lock(&curr_worker->lock);
+            ret = opal_common_ucx_worker_flush(curr_worker->worker);
+            pthread_mutex_unlock(&curr_worker->lock);
+            if (ret != OMPI_SUCCESS) {
+                pthread_mutex_unlock(&active_workers_mutex); return ret;
+            }
+        }
+    }
+    pthread_mutex_unlock(&active_workers_mutex);
 
     module->global_ops_num = 0;
     memset(module->per_target_ops_nums, 0,
diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h
index e25dd23b821..b46269bab6b 100644
--- a/opal/mca/common/ucx/common_ucx.h
+++ b/opal/mca/common/ucx/common_ucx.h
@@ -16,6 +16,7 @@
 #include "opal_config.h"
 
 #include <stdint.h>
+#include <pthread.h>
 
 #include <ucp/api/ucp.h>
 
@@ -96,6 +97,22 @@ typedef struct opal_common_ucx_del_proc {
 
 extern opal_common_ucx_module_t opal_common_ucx;
 
+typedef struct thread_local_info {
+    opal_list_item_t super;
+    ucp_worker_h worker;
+    ucp_ep_h *eps;
+    ucp_rkey_h *rkeys;
+    int comm_size;
+    pthread_mutex_t lock;
+} thread_local_info_t;
+
+OBJ_CLASS_DECLARATION(thread_local_info_t);
+
+extern pthread_key_t my_thread_key;
+
+extern opal_list_t active_workers, idle_workers;
+extern pthread_mutex_t active_workers_mutex, idle_workers_mutex;
+
 OPAL_DECLSPEC void opal_common_ucx_mca_register(void);
 OPAL_DECLSPEC void opal_common_ucx_mca_deregister(void);
 OPAL_DECLSPEC void opal_common_ucx_empty_complete_cb(void *request, ucs_status_t status);
@@ -202,6 +219,83 @@ int opal_common_ucx_atomic_cswap(ucp_ep_h ep, uint64_t compare,
     return ret;
 }
 
+static inline void opal_common_ucx_cleanup_local_worker(void *arg) {
+    thread_local_info_t *my_thread_info = (thread_local_info_t *)arg;
+
+    assert(my_thread_info != NULL);
+
+    pthread_mutex_lock(&active_workers_mutex);
+    opal_list_remove_item(&active_workers, &my_thread_info->super);
+    pthread_mutex_unlock(&active_workers_mutex);
+
+    pthread_mutex_lock(&idle_workers_mutex);
+    opal_list_append(&idle_workers, &my_thread_info->super);
+    pthread_mutex_unlock(&idle_workers_mutex);
+}
+
+static inline int opal_common_ucx_create_local_worker(ucp_context_h context, int comm_size,
+                                                      char *worker_buf, int *worker_disps,
+                                                      char *mem_buf, int *mem_disps)
+{
+    ucp_worker_params_t worker_params;
+    ucs_status_t status;
+    thread_local_info_t *my_thread_info;
+    int i, ret = OPAL_SUCCESS;
+
+    pthread_mutex_lock(&idle_workers_mutex);
+    if (!opal_list_is_empty(&idle_workers)) {
+        my_thread_info = (thread_local_info_t *)opal_list_get_first(&idle_workers);
+        opal_list_remove_item(&idle_workers, &my_thread_info->super);
+        pthread_mutex_unlock(&idle_workers_mutex);
+    } else { pthread_mutex_unlock(&idle_workers_mutex);
+        my_thread_info = OBJ_NEW(thread_local_info_t);
+        my_thread_info->worker = NULL; my_thread_info->eps = NULL; my_thread_info->rkeys = NULL;
+        pthread_mutex_init(&(my_thread_info->lock), NULL);
+
+        my_thread_info->comm_size = comm_size;
+
+        memset(&worker_params, 0, sizeof(worker_params));
+        worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
+        worker_params.thread_mode = UCS_THREAD_MODE_SINGLE;
+        status = ucp_worker_create(context, &worker_params,
+                                   &(my_thread_info->worker));
+        if (UCS_OK != status) {
+            ret = OPAL_ERROR;
+        }
+
+        my_thread_info->eps = calloc(comm_size, sizeof(ucp_ep_h));
+        my_thread_info->rkeys = calloc(comm_size, sizeof(ucp_rkey_h));
+
+        for (i = 0; i < comm_size; i++) {
+            ucp_ep_params_t ep_params;
+
+            memset(&ep_params, 0, sizeof(ucp_ep_params_t));
+            ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
+            ep_params.address = (ucp_address_t *)&(worker_buf[worker_disps[i]]);
+            status = ucp_ep_create(my_thread_info->worker, &ep_params,
+                                   &my_thread_info->eps[i]);
+            if (status != UCS_OK) {
+                ret = OPAL_ERROR;
+            }
+
+            status = ucp_ep_rkey_unpack(my_thread_info->eps[i],
+                                        &(mem_buf[mem_disps[i] + 3 * sizeof(uint64_t)]),
+                                        &(my_thread_info->rkeys[i]));
+            if (status != UCS_OK) {
+                ret = OPAL_ERROR;
+            }
+        }
+    }
+
+    pthread_mutex_lock(&active_workers_mutex);
+    opal_list_append(&active_workers, &my_thread_info->super);
+    pthread_mutex_unlock(&active_workers_mutex);
+
+    pthread_setspecific(my_thread_key, my_thread_info);
+
+    return ret;
+}
+
 END_C_DECLS
 
 #endif