diff --git a/contrib/platform/mellanox/optimized b/contrib/platform/mellanox/optimized index f49a0576c64..339518e483e 100644 --- a/contrib/platform/mellanox/optimized +++ b/contrib/platform/mellanox/optimized @@ -1,4 +1,4 @@ -enable_mca_no_build=coll-ml,btl-uct +enable_mca_no_build=coll-ml enable_debug_symbols=yes enable_orterun_prefix_by_default=yes with_verbs=no diff --git a/ompi/mca/osc/ucx/osc_ucx.h b/ompi/mca/osc/ucx/osc_ucx.h index 44dff95a845..7b2b97e910b 100644 --- a/ompi/mca/osc/ucx/osc_ucx.h +++ b/ompi/mca/osc/ucx/osc_ucx.h @@ -24,16 +24,9 @@ #define OMPI_OSC_UCX_ATTACH_MAX 32 #define OMPI_OSC_UCX_RKEY_BUF_MAX 1024 -typedef struct ompi_osc_ucx_win_info { - ucp_rkey_h rkey; - uint64_t addr; - bool rkey_init; -} ompi_osc_ucx_win_info_t; - typedef struct ompi_osc_ucx_component { ompi_osc_base_component_t super; - ucp_context_h ucp_context; - ucp_worker_h ucp_worker; + opal_common_ucx_wpool_t *wpool; bool enable_mpi_threads; opal_free_list_t requests; /* request free list for the r* communication variants */ bool env_initialized; /* UCX environment is initialized or not */ @@ -97,12 +90,10 @@ typedef struct ompi_osc_ucx_state { typedef struct ompi_osc_ucx_module { ompi_osc_base_module_t super; struct ompi_communicator_t *comm; - ucp_mem_h memh; /* remote accessible memory */ int flavor; size_t size; - ucp_mem_h state_memh; - ompi_osc_ucx_win_info_t *win_info_array; - ompi_osc_ucx_win_info_t *state_info_array; + uint64_t *addrs; + uint64_t *state_addrs; int disp_unit; /* if disp_unit >= 0, then everyone has the same * disp unit size; if disp_unit == -1, then we * need to look at disp_units */ @@ -122,6 +113,9 @@ typedef struct ompi_osc_ucx_module { uint64_t req_result; int *start_grp_ranks; bool lock_all_is_nocheck; + opal_common_ucx_ctx_t *ctx; + opal_common_ucx_mem_t *mem; + opal_common_ucx_mem_t *state_mem; } ompi_osc_ucx_module_t; typedef enum locktype { diff --git a/ompi/mca/osc/ucx/osc_ucx_active_target.c b/ompi/mca/osc/ucx/osc_ucx_active_target.c index 3c0a1488eec..be69d209776 100644 --- a/ompi/mca/osc/ucx/osc_ucx_active_target.c +++ b/ompi/mca/osc/ucx/osc_ucx_active_target.c @@ -60,7 +60,7 @@ static inline void ompi_osc_ucx_handle_incoming_post(ompi_osc_ucx_module_t *modu int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - int ret; + int ret = OMPI_SUCCESS; if (module->epoch_type.access != NONE_EPOCH && module->epoch_type.access != FENCE_EPOCH) { @@ -74,7 +74,7 @@ int ompi_osc_ucx_fence(int assert, struct ompi_win_t *win) { } if (!(assert & MPI_MODE_NOPRECEDE)) { - ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0/*ignore*/); if (ret != OMPI_SUCCESS) { return ret; } @@ -147,7 +147,7 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t ompi_osc_ucx_handle_incoming_post(module, &(module->state.post_state[i]), ranks_in_win_grp, size); } - ucp_worker_progress(mca_osc_ucx_component.ucp_worker); + opal_common_ucx_workers_progress(mca_osc_ucx_component.wpool); } module->post_count = 0; @@ -163,7 +163,6 @@ int ompi_osc_ucx_start(struct ompi_group_t *group, int assert, struct ompi_win_t int ompi_osc_ucx_complete(struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucs_status_t status; int i, size; int ret = OMPI_SUCCESS; @@ -173,29 +172,30 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) { module->epoch_type.access = 
NONE_EPOCH; - ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0/*ignore*/); if (ret != OMPI_SUCCESS) { return ret; } + module->global_ops_num = 0; memset(module->per_target_ops_nums, 0, sizeof(int) * ompi_comm_size(module->comm)); size = ompi_group_size(module->start_group); for (i = 0; i < size; i++) { - uint64_t remote_addr = (module->state_info_array)[module->start_grp_ranks[i]].addr + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; /* write to state.complete_count on remote side */ - ucp_rkey_h rkey = (module->state_info_array)[module->start_grp_ranks[i]].rkey; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, module->start_grp_ranks[i]); - - status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, 1, - sizeof(uint64_t), remote_addr, rkey); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_atomic_post failed: %d", status); + uint64_t remote_addr = module->state_addrs[module->start_grp_ranks[i]] + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; // write to state.complete_count on remote side + + ret = opal_common_ucx_mem_post(module->mem, UCP_ATOMIC_POST_OP_ADD, + 1, module->start_grp_ranks[i], sizeof(uint64_t), + remote_addr); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_post failed: %d", ret); } - ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); - if (OMPI_SUCCESS != ret) { - OSC_UCX_VERBOSE(1, "opal_common_ucx_ep_flush failed: %d", ret); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, + module->start_grp_ranks[i]); + if (ret != OMPI_SUCCESS) { + return ret; } } @@ -243,25 +243,29 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int assert, struct ompi_win_t } for (i = 0; i < size; i++) { - uint64_t remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_INDEX_OFFSET; /* write to state.post_index on remote side */ - ucp_rkey_h rkey = (module->state_info_array)[ranks_in_win_grp[i]].rkey; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, ranks_in_win_grp[i]); + uint64_t remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_INDEX_OFFSET; // write to state.post_index on remote side uint64_t curr_idx = 0, result = 0; /* do fop first to get an post index */ - opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_FADD, 1, - &result, sizeof(result), - remote_addr, rkey, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_fetch(module->mem, UCP_ATOMIC_FETCH_OP_FADD, + 1, ranks_in_win_grp[i], &result, + sizeof(result), remote_addr); + if (ret != OMPI_SUCCESS) { + return OMPI_ERROR; + } curr_idx = result & (OMPI_OSC_UCX_POST_PEER_MAX - 1); - remote_addr = (module->state_info_array)[ranks_in_win_grp[i]].addr + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx; + remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_STATE_OFFSET + sizeof(uint64_t) * curr_idx; /* do cas to send post message */ do { - opal_common_ucx_atomic_cswap(ep, 0, (uint64_t)myrank + 1, &result, - sizeof(result), remote_addr, rkey, - mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_cmpswp(module->mem, 0, result, + myrank + 1, &result, sizeof(result), + remote_addr); + if (ret != OMPI_SUCCESS) { + return OMPI_ERROR; + } if (result == 0) break; @@ -302,7 +306,7 @@ int ompi_osc_ucx_wait(struct ompi_win_t *win) { while (module->state.complete_count != (uint64_t)size) { /* not sure if this is required */ - ucp_worker_progress(mca_osc_ucx_component.ucp_worker); + 
opal_common_ucx_workers_progress(mca_osc_ucx_component.wpool); } module->state.complete_count = 0; diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c index ec760d4fda3..b315e281d06 100644 --- a/ompi/mca/osc/ucx/osc_ucx_comm.c +++ b/ompi/mca/osc/ucx/osc_ucx_comm.c @@ -66,14 +66,12 @@ static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target, return OMPI_SUCCESS; } -static inline int incr_and_check_ops_num(ompi_osc_ucx_module_t *module, int target, - ucp_ep_h ep) { +static inline int incr_and_check_ops_num(ompi_osc_ucx_module_t *module, int target) { int status; - module->global_ops_num++; module->per_target_ops_nums[target]++; if (module->global_ops_num >= OSC_UCX_OPS_THRESHOLD) { - status = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); + status = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target); if (status != OMPI_SUCCESS) { return status; } @@ -137,13 +135,13 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt, bool is_origin_contig, ptrdiff_t origin_lb, - int target, ucp_ep_h ep, uint64_t remote_addr, ucp_rkey_h rkey, + int target, uint64_t remote_addr, int target_count, struct ompi_datatype_t *target_dt, bool is_target_contig, ptrdiff_t target_lb, bool is_get) { ucx_iovec_t *origin_ucx_iov = NULL, *target_ucx_iov = NULL; uint32_t origin_ucx_iov_count = 0, target_ucx_iov_count = 0; uint32_t origin_ucx_iov_idx = 0, target_ucx_iov_idx = 0; - ucs_status_t status; + int status; int ret = OMPI_SUCCESS; if (!is_origin_contig) { @@ -164,27 +162,24 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, if (!is_origin_contig && !is_target_contig) { size_t curr_len = 0; + opal_common_ucx_op_t op; while (origin_ucx_iov_idx < origin_ucx_iov_count) { curr_len = MIN(origin_ucx_iov[origin_ucx_iov_idx].len, target_ucx_iov[target_ucx_iov_idx].len); - - if (!is_get) { - status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len, - remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, "ucp_put_nbi failed: %d", status); - return OMPI_ERROR; - } + if (is_get) { + op = OPAL_COMMON_UCX_GET; } else { - status = ucp_get_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len, - remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, "ucp_get_nbi failed: %d",status); - return OMPI_ERROR; - } + op = OPAL_COMMON_UCX_PUT; + } + status = opal_common_ucx_mem_putget(module->mem, op, target, + origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len, + remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr)); + if (OPAL_SUCCESS != status) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status); + return OMPI_ERROR; } - ret = incr_and_check_ops_num(module, target, ep); + ret = incr_and_check_ops_num(module, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -207,26 +202,23 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, } else if (!is_origin_contig) { size_t prev_len = 0; + opal_common_ucx_op_t op; while (origin_ucx_iov_idx < origin_ucx_iov_count) { - if (!is_get) { - status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, - origin_ucx_iov[origin_ucx_iov_idx].len, - remote_addr + target_lb + prev_len, rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, 
"ucp_put_nbi failed: %d", status); - return OMPI_ERROR; - } + if (is_get) { + op = OPAL_COMMON_UCX_GET; } else { - status = ucp_get_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, - origin_ucx_iov[origin_ucx_iov_idx].len, - remote_addr + target_lb + prev_len, rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, "ucp_get_nbi failed: %d", status); - return OMPI_ERROR; - } + op = OPAL_COMMON_UCX_PUT; + } + status = opal_common_ucx_mem_putget(module->mem, op, target, + origin_ucx_iov[origin_ucx_iov_idx].addr, + origin_ucx_iov[origin_ucx_iov_idx].len, + remote_addr + target_lb + prev_len); + if (OPAL_SUCCESS != status) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status); + return OMPI_ERROR; } - ret = incr_and_check_ops_num(module, target, ep); + ret = incr_and_check_ops_num(module, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -236,26 +228,23 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, } } else { size_t prev_len = 0; + opal_common_ucx_op_t op; while (target_ucx_iov_idx < target_ucx_iov_count) { - if (!is_get) { - status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb + prev_len), - target_ucx_iov[target_ucx_iov_idx].len, - remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, "ucp_put_nbi failed: %d", status); - return OMPI_ERROR; - } + if (is_get) { + op = OPAL_COMMON_UCX_GET; } else { - status = ucp_get_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb + prev_len), - target_ucx_iov[target_ucx_iov_idx].len, - remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, "ucp_get_nbi failed: %d", status); - return OMPI_ERROR; - } + op = OPAL_COMMON_UCX_PUT; } - ret = incr_and_check_ops_num(module, target, ep); + status = opal_common_ucx_mem_putget(module->mem, op, target, + (void *)((intptr_t)origin_addr + origin_lb + prev_len), + target_ucx_iov[target_ucx_iov_idx].len, + remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr)); + if (OPAL_SUCCESS != status) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status); + return OMPI_ERROR; + } + ret = incr_and_check_ops_num(module, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -275,46 +264,47 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, return ret; } -static inline int start_atomicity(ompi_osc_ucx_module_t *module, ucp_ep_h ep, int target) { +static inline int start_atomicity(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = -1; - ucp_rkey_h rkey = (module->state_info_array)[target].rkey; - uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_ACC_LOCK_OFFSET; - ucs_status_t status; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET; + int ret = OMPI_SUCCESS; while (result_value != TARGET_LOCK_UNLOCKED) { - status = opal_common_ucx_atomic_cswap(ep, TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, - &result_value, sizeof(result_value), - remote_addr, rkey, - mca_osc_ucx_component.ucp_worker); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_atomic_cswap64 failed: %d", status); + ret = opal_common_ucx_mem_cmpswp(module->state_mem, + TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, + target, &result_value, sizeof(result_value), + remote_addr); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_cmpswp failed: %d", ret); return OMPI_ERROR; } 
} - return OMPI_SUCCESS; + return ret; } -static inline int end_atomicity(ompi_osc_ucx_module_t *module, ucp_ep_h ep, int target) { +static inline int end_atomicity(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = 0; - ucp_rkey_h rkey = (module->state_info_array)[target].rkey; - uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_ACC_LOCK_OFFSET; - int ret; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET; + int ret = OMPI_SUCCESS; - ret = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, - &result_value, sizeof(result_value), - remote_addr, rkey, mca_osc_ucx_component.ucp_worker); - if (OMPI_SUCCESS != ret) { - return ret; + ret = opal_common_ucx_mem_fetch(module->state_mem, + UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, + target, &result_value, sizeof(result_value), + remote_addr); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fetch failed: %d", ret); + return OMPI_ERROR; } assert(result_value == TARGET_LOCK_EXCLUSIVE); - return OMPI_SUCCESS; + return ret; } static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module_t *module, ucp_ep_h ep, int target) { +/* ucp_rkey_h state_rkey = (module->state_info_array)[target].rkey; uint64_t remote_state_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET; size_t len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX; @@ -361,18 +351,17 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module free(temp_buf); return status; + */ + return OMPI_SUCCESS; } int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt, int target, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target); - ucp_rkey_h rkey; + uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); bool is_origin_contig = false, is_target_contig = false; ptrdiff_t origin_lb, origin_extent, target_lb, target_extent; - ucs_status_t status; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, false); @@ -380,21 +369,17 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data return ret; } - if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { +/* if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { status = get_dynamic_win_info(remote_addr, module, ep, target); if (status != UCS_OK) { return OMPI_ERROR; } - } - - CHECK_VALID_RKEY(module, target, target_count); + } */ if (!target_count) { return OMPI_SUCCESS; } - rkey = (module->win_info_array[target]).rkey; - ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent); ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent); @@ -408,16 +393,17 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data ompi_datatype_type_size(origin_dt, &origin_len); origin_len *= origin_count; - status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb), origin_len, - remote_addr + target_lb, rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, "ucp_put_nbi failed: %d", status); + ret = opal_common_ucx_mem_putget(module->mem, 
OPAL_COMMON_UCX_PUT, target, + (void *)((intptr_t)origin_addr + origin_lb), + origin_len, remote_addr + target_lb); + if (OPAL_SUCCESS != ret) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret); return OMPI_ERROR; } - return incr_and_check_ops_num(module, target, ep); + return incr_and_check_ops_num(module, target); } else { return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig, - origin_lb, target, ep, remote_addr, rkey, target_count, target_dt, + origin_lb, target, remote_addr, target_count, target_dt, is_target_contig, target_lb, false); } } @@ -427,12 +413,9 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, int target, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target); - ucp_rkey_h rkey; + uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); ptrdiff_t origin_lb, origin_extent, target_lb, target_extent; bool is_origin_contig = false, is_target_contig = false; - ucs_status_t status; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, false); @@ -440,20 +423,17 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, return ret; } - if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { +/* if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { status = get_dynamic_win_info(remote_addr, module, ep, target); if (status != UCS_OK) { return OMPI_ERROR; } - } - - CHECK_VALID_RKEY(module, target, target_count); + } */ if (!target_count) { return OMPI_SUCCESS; } - rkey = (module->win_info_array[target]).rkey; ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent); ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent); @@ -468,17 +448,18 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, ompi_datatype_type_size(origin_dt, &origin_len); origin_len *= origin_count; - status = ucp_get_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb), origin_len, - remote_addr + target_lb, rkey); - if (status != UCS_OK && status != UCS_INPROGRESS) { - OSC_UCX_VERBOSE(1, "ucp_get_nbi failed: %d", status); + ret = opal_common_ucx_mem_putget(module->mem, OPAL_COMMON_UCX_GET, target, + (void *)((intptr_t)origin_addr + origin_lb), + origin_len, remote_addr + target_lb); + if (OPAL_SUCCESS != ret) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret); return OMPI_ERROR; } - return incr_and_check_ops_num(module, target, ep); + return incr_and_check_ops_num(module, target); } else { return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig, - origin_lb, target, ep, remote_addr, rkey, target_count, target_dt, + origin_lb, target, remote_addr, target_count, target_dt, is_target_contig, target_lb, true); } } @@ -489,7 +470,6 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, struct ompi_datatype_t *target_dt, struct ompi_op_t *op, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, false); @@ -501,7 +481,7 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, return ret; } - ret = start_atomicity(module, ep, target); + ret = start_atomicity(module, 
target); if (ret != OMPI_SUCCESS) { return ret; } @@ -541,7 +521,7 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, return ret; } - ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -595,7 +575,7 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, return ret; } - ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -603,9 +583,7 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, free(temp_addr_holder); } - ret = end_atomicity(module, ep, target); - - return ret; + return end_atomicity(module, target); } int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr, @@ -613,47 +591,41 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a int target, ptrdiff_t target_disp, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target); - ucp_rkey_h rkey; + uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); size_t dt_bytes; - ompi_osc_ucx_internal_request_t *req = NULL; int ret = OMPI_SUCCESS; - ucs_status_t status; ret = check_sync_state(module, target, false); if (ret != OMPI_SUCCESS) { return ret; } - ret = start_atomicity(module, ep, target); + ret = start_atomicity(module, target); if (ret != OMPI_SUCCESS) { return ret; } - if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { +/* if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { status = get_dynamic_win_info(remote_addr, module, ep, target); if (status != UCS_OK) { return OMPI_ERROR; } - } - - rkey = (module->win_info_array[target]).rkey; + } */ ompi_datatype_type_size(dt, &dt_bytes); - memcpy(result_addr, origin_addr, dt_bytes); - req = ucp_atomic_fetch_nb(ep, UCP_ATOMIC_FETCH_OP_CSWAP, *(uint64_t *)compare_addr, - result_addr, dt_bytes, remote_addr, rkey, req_completion); - if (UCS_PTR_IS_PTR(req)) { - ucp_request_release(req); + ret = opal_common_ucx_mem_cmpswp(module->mem,*(uint64_t *)compare_addr, + *(uint64_t *)origin_addr, target, + result_addr, dt_bytes, remote_addr); + if (ret != OMPI_SUCCESS) { + return ret; } - ret = incr_and_check_ops_num(module, target, ep); + ret = incr_and_check_ops_num(module, target); if (ret != OMPI_SUCCESS) { return ret; } - return end_atomicity(module, ep, target); + return end_atomicity(module, target); } int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, @@ -670,28 +642,22 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, if (op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op || op == &ompi_mpi_op_sum.op) { - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target); - ucp_rkey_h rkey; + uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); uint64_t value = *(uint64_t *)origin_addr; ucp_atomic_fetch_op_t opcode; size_t dt_bytes; - ompi_osc_ucx_internal_request_t *req = NULL; - ucs_status_t status; - ret = start_atomicity(module, ep, target); + ret = 
start_atomicity(module, target); if (ret != OMPI_SUCCESS) { return ret; } - if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { +/* if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { status = get_dynamic_win_info(remote_addr, module, ep, target); if (status != UCS_OK) { return OMPI_ERROR; } - } - - rkey = (module->win_info_array[target]).rkey; + } */ ompi_datatype_type_size(dt, &dt_bytes); @@ -704,18 +670,18 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, } } - req = ucp_atomic_fetch_nb(ep, opcode, value, result_addr, - dt_bytes, remote_addr, rkey, req_completion); - if (UCS_PTR_IS_PTR(req)) { - ucp_request_release(req); + ret = opal_common_ucx_mem_fetch(module->mem, opcode, value, target, + (void *)origin_addr, dt_bytes, remote_addr); + if (ret != OMPI_SUCCESS) { + return ret; } - ret = incr_and_check_ops_num(module, target, ep); + ret = incr_and_check_ops_num(module, target); if (ret != OMPI_SUCCESS) { return ret; } - return end_atomicity(module, ep, target); + return end_atomicity(module, target); } else { return ompi_osc_ucx_get_accumulate(origin_addr, 1, dt, result_addr, 1, dt, target, target_disp, 1, dt, op, win); @@ -730,7 +696,6 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t *op, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, false); @@ -738,7 +703,7 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, return ret; } - ret = start_atomicity(module, ep, target); + ret = start_atomicity(module, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -786,7 +751,7 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, return ret; } - ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -839,7 +804,7 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, return ret; } - ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -848,9 +813,7 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, } } - ret = end_atomicity(module, ep, target); - - return ret; + return end_atomicity(module, target); } int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, @@ -859,12 +822,9 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win, struct ompi_request_t **request) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - uint64_t remote_addr = (module->state_info_array[target]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET; - ucp_rkey_h rkey; + uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); ompi_osc_ucx_request_t *ucx_req = NULL; ompi_osc_ucx_internal_request_t *internal_req = NULL; - ucs_status_t status; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, true); @@ -872,16 +832,12 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, return ret; } - if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { +/* if (module->flavor 
== MPI_WIN_FLAVOR_DYNAMIC) { status = get_dynamic_win_info(remote_addr, module, ep, target); if (status != UCS_OK) { return OMPI_ERROR; } - } - - CHECK_VALID_RKEY(module, target, target_count); - - rkey = (module->win_info_array[target]).rkey; + } */ OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); assert(NULL != ucx_req); @@ -892,15 +848,19 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, return ret; } - status = ucp_worker_fence(mca_osc_ucx_component.ucp_worker); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_worker_fence failed: %d", status); + ret = opal_common_ucx_mem_fence(module->mem); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); return OMPI_ERROR; } - internal_req = ucp_atomic_fetch_nb(ep, UCP_ATOMIC_FETCH_OP_FADD, 0, - &(module->req_result), sizeof(uint64_t), - remote_addr, rkey, req_completion); + ret = opal_common_ucx_mem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD, + 0, target, &(module->req_result), + sizeof(uint64_t), remote_addr, + (ucs_status_ptr_t *)&internal_req); + if (ret != OMPI_SUCCESS) { + return ret; + } if (UCS_PTR_IS_PTR(internal_req)) { internal_req->external_req = ucx_req; @@ -911,7 +871,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, *request = &ucx_req->super; - return incr_and_check_ops_num(module, target, ep); + return incr_and_check_ops_num(module, target); } int ompi_osc_ucx_rget(void *origin_addr, int origin_count, @@ -920,12 +880,9 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win, struct ompi_request_t **request) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - uint64_t remote_addr = (module->state_info_array[target]).addr + OSC_UCX_STATE_REQ_FLAG_OFFSET; - ucp_rkey_h rkey; + uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); ompi_osc_ucx_request_t *ucx_req = NULL; ompi_osc_ucx_internal_request_t *internal_req = NULL; - ucs_status_t status; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, true); @@ -933,16 +890,12 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, return ret; } - if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { +/* if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { status = get_dynamic_win_info(remote_addr, module, ep, target); if (status != UCS_OK) { return OMPI_ERROR; } - } - - CHECK_VALID_RKEY(module, target, target_count); - - rkey = (module->win_info_array[target]).rkey; + } */ OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); assert(NULL != ucx_req); @@ -953,15 +906,19 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, return ret; } - status = ucp_worker_fence(mca_osc_ucx_component.ucp_worker); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_worker_fence failed: %d", status); + ret = opal_common_ucx_mem_fence(module->mem); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); return OMPI_ERROR; } - internal_req = ucp_atomic_fetch_nb(ep, UCP_ATOMIC_FETCH_OP_FADD, 0, - &(module->req_result), sizeof(uint64_t), - remote_addr, rkey, req_completion); + ret = opal_common_ucx_mem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD, + 0, target, &(module->req_result), + sizeof(uint64_t), remote_addr, + (ucs_status_ptr_t *)&internal_req); + if (ret != OMPI_SUCCESS) { + return ret; + } if (UCS_PTR_IS_PTR(internal_req)) { internal_req->external_req = ucx_req; @@ -972,7 +929,7 @@ int 
ompi_osc_ucx_rget(void *origin_addr, int origin_count, *request = &ucx_req->super; - return incr_and_check_ops_num(module, target, ep); + return incr_and_check_ops_num(module, target); } int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count, diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c index 6fd3291bad0..793ebdb763d 100644 --- a/ompi/mca/osc/ucx/osc_ucx_component.c +++ b/ompi/mca/osc/ucx/osc_ucx_component.c @@ -54,8 +54,7 @@ ompi_osc_ucx_component_t mca_osc_ucx_component = { .osc_select = component_select, .osc_finalize = component_finalize, }, - .ucp_context = NULL, - .ucp_worker = NULL, + .wpool = NULL, .env_initialized = false, .num_incomplete_req_ops = 0, .num_modules = 0 @@ -120,37 +119,22 @@ static int component_register(void) { } static int progress_callback(void) { - ucp_worker_progress(mca_osc_ucx_component.ucp_worker); + if (mca_osc_ucx_component.wpool != NULL) { + ucp_worker_progress(mca_osc_ucx_component.wpool->recv_worker); + } return 0; } static int component_init(bool enable_progress_threads, bool enable_mpi_threads) { mca_osc_ucx_component.enable_mpi_threads = enable_mpi_threads; - + mca_osc_ucx_component.wpool = opal_common_ucx_wpool_allocate(); opal_common_ucx_mca_register(); return OMPI_SUCCESS; } static int component_finalize(void) { - int i; - for (i = 0; i < ompi_proc_world_size(); i++) { - ucp_ep_h ep = OSC_UCX_GET_EP(&(ompi_mpi_comm_world.comm), i); - if (ep != NULL) { - ucp_ep_destroy(ep); - } - } - - if (mca_osc_ucx_component.ucp_worker != NULL) { - ucp_worker_destroy(mca_osc_ucx_component.ucp_worker); - } - - assert(mca_osc_ucx_component.num_incomplete_req_ops == 0); - if (mca_osc_ucx_component.env_initialized == true) { - OBJ_DESTRUCT(&mca_osc_ucx_component.requests); - ucp_cleanup(mca_osc_ucx_component.ucp_context); - mca_osc_ucx_component.env_initialized = false; - } opal_common_ucx_mca_deregister(); + opal_common_ucx_wpool_free(mca_osc_ucx_component.wpool); return OMPI_SUCCESS; } @@ -160,9 +144,11 @@ static int component_query(struct ompi_win_t *win, void **base, size_t size, int return mca_osc_ucx_component.priority; } -static inline int allgather_len_and_info(void *my_info, int my_info_len, char **recv_info, - int *disps, struct ompi_communicator_t *comm) { +static int exchange_len_info(void *my_info, size_t my_info_len, char **recv_info_ptr, + int **disps_ptr, void *metadata) +{ int ret = OMPI_SUCCESS; + struct ompi_communicator_t *comm = (struct ompi_communicator_t *)metadata; int comm_size = ompi_comm_size(comm); int lens[comm_size]; int total_len, i; @@ -175,15 +161,15 @@ static inline int allgather_len_and_info(void *my_info, int my_info_len, char ** } total_len = 0; + (*disps_ptr) = (int *)calloc(comm_size, sizeof(int)); for (i = 0; i < comm_size; i++) { - disps[i] = total_len; + (*disps_ptr)[i] = total_len; total_len += lens[i]; } - (*recv_info) = (char *)malloc(total_len); - + (*recv_info_ptr) = (char *)calloc(total_len, sizeof(char)); ret = comm->c_coll->coll_allgatherv(my_info, my_info_len, MPI_BYTE, - (void *)(*recv_info), lens, disps, MPI_BYTE, + (void *)(*recv_info_ptr), lens, (*disps_ptr), MPI_BYTE, comm, comm->c_coll->coll_allgatherv_module); if (OMPI_SUCCESS != ret) { return ret; @@ -192,60 +178,6 @@ static inline int allgather_len_and_info(void *my_info, int my_info_len, char ** return ret; } -static inline int mem_map(void **base, size_t size, ucp_mem_h *memh_ptr, - ompi_osc_ucx_module_t *module, int flavor) { - ucp_mem_map_params_t mem_params; - ucp_mem_attr_t mem_attrs; - 
ucs_status_t status; - int ret = OMPI_SUCCESS; - - if (!(flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE) - || size == 0) { - return ret; - } - - memset(&mem_params, 0, sizeof(ucp_mem_map_params_t)); - mem_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | - UCP_MEM_MAP_PARAM_FIELD_LENGTH | - UCP_MEM_MAP_PARAM_FIELD_FLAGS; - mem_params.length = size; - if (flavor == MPI_WIN_FLAVOR_ALLOCATE) { - mem_params.address = NULL; - mem_params.flags = UCP_MEM_MAP_ALLOCATE; - } else { - mem_params.address = (*base); - } - - /* memory map */ - - status = ucp_mem_map(mca_osc_ucx_component.ucp_context, &mem_params, memh_ptr); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_mem_map failed: %d", status); - ret = OMPI_ERROR; - goto error; - } - - mem_attrs.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS | UCP_MEM_ATTR_FIELD_LENGTH; - status = ucp_mem_query((*memh_ptr), &mem_attrs); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_mem_query failed: %d", status); - ret = OMPI_ERROR; - goto error; - } - - assert(mem_attrs.length >= size); - if (flavor == MPI_WIN_FLAVOR_CREATE) { - assert(mem_attrs.address == (*base)); - } else { - (*base) = mem_attrs.address; - } - - return ret; - error: - ucp_mem_unmap(mca_osc_ucx_component.ucp_context, (*memh_ptr)); - return ret; -} - static void ompi_osc_ucx_unregister_progress() { int ret; @@ -267,23 +199,14 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in char *name = NULL; long values[2]; int ret = OMPI_SUCCESS; - ucs_status_t status; + //ucs_status_t status; int i, comm_size = ompi_comm_size(comm); - int is_eps_ready; - bool eps_created = false, env_initialized = false; - ucp_address_t *my_addr = NULL; - size_t my_addr_len; - char *recv_buf = NULL; - void *rkey_buffer = NULL, *state_rkey_buffer = NULL; - size_t rkey_buffer_size, state_rkey_buffer_size; + bool env_initialized = false; void *state_base = NULL; - void * my_info = NULL; - size_t my_info_len; - int disps[comm_size]; - int rkey_sizes[comm_size]; + opal_common_ucx_mem_type_t mem_type; uint64_t zero = 0; - size_t info_offset; - uint64_t size_u64; + void * my_info = NULL; + char *recv_buf = NULL; /* the osc/sm component is the exclusive provider for support for * shared memory windows */ @@ -292,16 +215,6 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in } if (mca_osc_ucx_component.env_initialized == false) { - ucp_config_t *config = NULL; - ucp_params_t context_params; - ucp_worker_params_t worker_params; - ucp_worker_attr_t worker_attr; - - status = ucp_config_read("MPI", NULL, &config); - if (UCS_OK != status) { - OSC_UCX_VERBOSE(1, "ucp_config_read failed: %d", status); - return OMPI_ERROR; - } OBJ_CONSTRUCT(&mca_osc_ucx_component.requests, opal_free_list_t); ret = opal_free_list_init (&mca_osc_ucx_component.requests, @@ -314,57 +227,16 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in goto error; } - /* initialize UCP context */ - - memset(&context_params, 0, sizeof(context_params)); - context_params.field_mask = UCP_PARAM_FIELD_FEATURES | - UCP_PARAM_FIELD_MT_WORKERS_SHARED | - UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | - UCP_PARAM_FIELD_REQUEST_INIT | - UCP_PARAM_FIELD_REQUEST_SIZE; - context_params.features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64; - context_params.mt_workers_shared = 0; - context_params.estimated_num_eps = ompi_proc_world_size(); - context_params.request_init = internal_req_init; - context_params.request_size = sizeof(ompi_osc_ucx_internal_request_t); 
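/*
 * [Editor's sketch: annotation only, not part of the patch.]
 * The ucp_config_read()/ucp_init()/ucp_worker_create() sequence removed in
 * this hunk moves behind the component-wide worker pool, which also owns the
 * worker that progress_callback() polls (wpool->recv_worker). Condensed from
 * the calls this patch adds elsewhere (error handling trimmed):
 *
 *     // component_init(): allocate the pool, defer UCP bring-up
 *     mca_osc_ucx_component.wpool = opal_common_ucx_wpool_allocate();
 *
 *     // component_select(), first window: initialize the pool once
 *     ret = opal_common_ucx_wpool_init(mca_osc_ucx_component.wpool,
 *                                      ompi_proc_world_size(),
 *                                      internal_req_init,
 *                                      sizeof(ompi_osc_ucx_internal_request_t),
 *                                      mca_osc_ucx_component.enable_mpi_threads);
 *     if (OMPI_SUCCESS != ret) {
 *         goto error;
 *     }
 *
 *     // component_finalize(): release the pool
 *     opal_common_ucx_wpool_free(mca_osc_ucx_component.wpool);
 */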
- - status = ucp_init(&context_params, config, &mca_osc_ucx_component.ucp_context); - ucp_config_release(config); - if (UCS_OK != status) { - OSC_UCX_VERBOSE(1, "ucp_init failed: %d", status); - ret = OMPI_ERROR; + ret = opal_common_ucx_wpool_init(mca_osc_ucx_component.wpool, + ompi_proc_world_size(), + internal_req_init, + sizeof(ompi_osc_ucx_internal_request_t), + mca_osc_ucx_component.enable_mpi_threads); + if (OMPI_SUCCESS != ret) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_wpool_init failed: %d", ret); goto error; } - assert(mca_osc_ucx_component.ucp_worker == NULL); - memset(&worker_params, 0, sizeof(worker_params)); - worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = (mca_osc_ucx_component.enable_mpi_threads == true) - ? UCS_THREAD_MODE_MULTI : UCS_THREAD_MODE_SINGLE; - status = ucp_worker_create(mca_osc_ucx_component.ucp_context, &worker_params, - &(mca_osc_ucx_component.ucp_worker)); - if (UCS_OK != status) { - OSC_UCX_VERBOSE(1, "ucp_worker_create failed: %d", status); - ret = OMPI_ERROR; - goto error_nomem; - } - - /* query UCP worker attributes */ - worker_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; - status = ucp_worker_query(mca_osc_ucx_component.ucp_worker, &worker_attr); - if (UCS_OK != status) { - OSC_UCX_VERBOSE(1, "ucp_worker_query failed: %d", status); - ret = OMPI_ERROR; - goto error_nomem; - } - - if (mca_osc_ucx_component.enable_mpi_threads == true && - worker_attr.thread_mode != UCS_THREAD_MODE_MULTI) { - OSC_UCX_VERBOSE(1, "ucx does not support multithreading"); - ret = OMPI_ERROR; - goto error_nomem; - } - mca_osc_ucx_component.env_initialized = true; env_initialized = true; } @@ -425,187 +297,72 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in } } - /* exchange endpoints if necessary */ - is_eps_ready = 1; - for (i = 0; i < comm_size; i++) { - if (OSC_UCX_GET_EP(module->comm, i) == NULL) { - is_eps_ready = 0; - break; - } - } - - ret = module->comm->c_coll->coll_allreduce(MPI_IN_PLACE, &is_eps_ready, 1, MPI_INT, - MPI_LAND, - module->comm, - module->comm->c_coll->coll_allreduce_module); + ret = opal_common_ucx_ctx_create(mca_osc_ucx_component.wpool, comm_size, + &exchange_len_info, (void *)module->comm, + &module->ctx); if (OMPI_SUCCESS != ret) { goto error; } - if (!is_eps_ready) { - status = ucp_worker_get_address(mca_osc_ucx_component.ucp_worker, - &my_addr, &my_addr_len); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_worker_get_address failed: %d", status); - ret = OMPI_ERROR; - goto error; + if (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE) { + switch (flavor) { + case MPI_WIN_FLAVOR_ALLOCATE: + mem_type = OPAL_COMMON_UCX_MEM_ALLOCATE_MAP; + break; + case MPI_WIN_FLAVOR_CREATE: + mem_type = OPAL_COMMON_UCX_MEM_MAP; + break; } - ret = allgather_len_and_info(my_addr, (int)my_addr_len, - &recv_buf, disps, module->comm); + ret = opal_common_ucx_mem_create(module->ctx, comm_size, base, size, + mem_type, &exchange_len_info, + (void *)module->comm, &module->mem); if (ret != OMPI_SUCCESS) { goto error; } - for (i = 0; i < comm_size; i++) { - if (OSC_UCX_GET_EP(module->comm, i) == NULL) { - ucp_ep_params_t ep_params; - ucp_ep_h ep; - memset(&ep_params, 0, sizeof(ucp_ep_params_t)); - ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = (ucp_address_t *)&(recv_buf[disps[i]]); - status = ucp_ep_create(mca_osc_ucx_component.ucp_worker, &ep_params, &ep); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_ep_create failed: %d", 
status); - ret = OMPI_ERROR; - goto error; - } - - ompi_comm_peer_lookup(module->comm, i)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_UCX] = ep; - } - } - - ucp_worker_release_address(mca_osc_ucx_component.ucp_worker, my_addr); - my_addr = NULL; - free(recv_buf); - recv_buf = NULL; - - eps_created = true; - } - - ret = mem_map(base, size, &(module->memh), module, flavor); - if (ret != OMPI_SUCCESS) { - goto error; } state_base = (void *)&(module->state); - ret = mem_map(&state_base, sizeof(ompi_osc_ucx_state_t), &(module->state_memh), - module, MPI_WIN_FLAVOR_CREATE); + ret = opal_common_ucx_mem_create(module->ctx, comm_size, &state_base, + sizeof(ompi_osc_ucx_state_t), + OPAL_COMMON_UCX_MEM_MAP, &exchange_len_info, + (void *)module->comm, &module->state_mem); if (ret != OMPI_SUCCESS) { goto error; } - module->win_info_array = calloc(comm_size, sizeof(ompi_osc_ucx_win_info_t)); - if (module->win_info_array == NULL) { - ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE; - goto error; - } - - module->state_info_array = calloc(comm_size, sizeof(ompi_osc_ucx_win_info_t)); - if (module->state_info_array == NULL) { - ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE; - goto error; - } - - if (size > 0 && (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE)) { - status = ucp_rkey_pack(mca_osc_ucx_component.ucp_context, module->memh, - &rkey_buffer, &rkey_buffer_size); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_rkey_pack failed: %d", status); - ret = OMPI_ERROR; - goto error; - } - } else { - rkey_buffer_size = 0; - } - - status = ucp_rkey_pack(mca_osc_ucx_component.ucp_context, module->state_memh, - &state_rkey_buffer, &state_rkey_buffer_size); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_rkey_pack failed: %d", status); - ret = OMPI_ERROR; - goto error; - } - - size_u64 = (uint64_t)size; - my_info_len = 3 * sizeof(uint64_t) + rkey_buffer_size + state_rkey_buffer_size; - my_info = malloc(my_info_len); + /* exchange window addrs */ + my_info = malloc(2 * sizeof(uint64_t)); if (my_info == NULL) { ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE; goto error; } - info_offset = 0; - if (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE) { - memcpy_off(my_info, base, sizeof(uint64_t), info_offset); + memcpy(my_info, base, sizeof(uint64_t)); } else { - memcpy_off(my_info, &zero, sizeof(uint64_t), info_offset); + memcpy(my_info, &zero, sizeof(uint64_t)); } - memcpy_off(my_info, &state_base, sizeof(uint64_t), info_offset); - memcpy_off(my_info, &size_u64, sizeof(uint64_t), info_offset); - memcpy_off(my_info, rkey_buffer, rkey_buffer_size, info_offset); - memcpy_off(my_info, state_rkey_buffer, state_rkey_buffer_size, info_offset); + memcpy((char*)my_info + sizeof(uint64_t), &state_base, sizeof(uint64_t)); - assert(my_info_len == info_offset); - - ret = allgather_len_and_info(my_info, (int)my_info_len, &recv_buf, disps, module->comm); + recv_buf = (char *)calloc(comm_size, 2 * sizeof(uint64_t)); + ret = comm->c_coll->coll_allgather((void *)my_info, 2 * sizeof(uint64_t), + MPI_BYTE, recv_buf, 2 * sizeof(uint64_t), + MPI_BYTE, comm, comm->c_coll->coll_allgather_module); if (ret != OMPI_SUCCESS) { goto error; } - ret = comm->c_coll->coll_allgather((void *)&rkey_buffer_size, 1, MPI_INT, - rkey_sizes, 1, MPI_INT, comm, - comm->c_coll->coll_allgather_module); - if (OMPI_SUCCESS != ret) { - goto error; - } - + module->addrs = calloc(comm_size, sizeof(uint64_t)); + module->state_addrs = calloc(comm_size, sizeof(uint64_t)); for (i = 0; i < comm_size; i++) { - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, i); - 
uint64_t dest_size; - assert(ep != NULL); - - info_offset = disps[i]; - - memcpy(&(module->win_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t)); - info_offset += sizeof(uint64_t); - memcpy(&(module->state_info_array[i]).addr, &recv_buf[info_offset], sizeof(uint64_t)); - info_offset += sizeof(uint64_t); - memcpy(&dest_size, &recv_buf[info_offset], sizeof(uint64_t)); - info_offset += sizeof(uint64_t); - - (module->win_info_array[i]).rkey_init = false; - if (dest_size > 0 && (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE)) { - status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset], - &((module->win_info_array[i]).rkey)); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status); - ret = OMPI_ERROR; - goto error; - } - info_offset += rkey_sizes[i]; - (module->win_info_array[i]).rkey_init = true; - } - - status = ucp_ep_rkey_unpack(ep, &recv_buf[info_offset], - &((module->state_info_array[i]).rkey)); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status); - ret = OMPI_ERROR; - goto error; - } - (module->state_info_array[i]).rkey_init = true; + memcpy(&(module->addrs[i]), recv_buf + i * 2 * sizeof(uint64_t), sizeof(uint64_t)); + memcpy(&(module->state_addrs[i]), recv_buf + i * 2 * sizeof(uint64_t) + sizeof(uint64_t), sizeof(uint64_t)); } - - free(my_info); free(recv_buf); - if (rkey_buffer_size != 0) { - ucp_rkey_buffer_release(rkey_buffer); - } - ucp_rkey_buffer_release(state_rkey_buffer); - + /* init window state */ module->state.lock = TARGET_LOCK_UNLOCKED; module->state.post_index = 0; memset((void *)module->state.post_state, 0, sizeof(uint64_t) * OMPI_OSC_UCX_POST_PEER_MAX); @@ -655,30 +412,9 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in return ret; error: - if (my_addr) ucp_worker_release_address(mca_osc_ucx_component.ucp_worker, my_addr); - if (recv_buf) free(recv_buf); - if (my_info) free(my_info); - for (i = 0; i < comm_size; i++) { - if ((module->win_info_array[i]).rkey != NULL) { - ucp_rkey_destroy((module->win_info_array[i]).rkey); - } - if ((module->state_info_array[i]).rkey != NULL) { - ucp_rkey_destroy((module->state_info_array[i]).rkey); - } - } - if (rkey_buffer) ucp_rkey_buffer_release(rkey_buffer); - if (state_rkey_buffer) ucp_rkey_buffer_release(state_rkey_buffer); - if (module->win_info_array) free(module->win_info_array); - if (module->state_info_array) free(module->state_info_array); if (module->disp_units) free(module->disp_units); if (module->comm) ompi_comm_free(&module->comm); if (module->per_target_ops_nums) free(module->per_target_ops_nums); - if (eps_created) { - for (i = 0; i < comm_size; i++) { - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, i); - ucp_ep_destroy(ep); - } - } if (module) { free(module); ompi_osc_ucx_unregister_progress(); @@ -686,9 +422,8 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in error_nomem: if (env_initialized == true) { + opal_common_ucx_wpool_finalize(mca_osc_ucx_component.wpool); OBJ_DESTRUCT(&mca_osc_ucx_component.requests); - ucp_worker_destroy(mca_osc_ucx_component.ucp_worker); - ucp_cleanup(mca_osc_ucx_component.ucp_context); mca_osc_ucx_component.env_initialized = false; } return ret; @@ -716,6 +451,7 @@ int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_ } int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) { +/* ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; int insert_index = -1, 
contain_index; void *rkey_buffer; @@ -746,8 +482,8 @@ int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) { (OMPI_OSC_UCX_ATTACH_MAX - (insert_index + 1)) * sizeof(ompi_osc_dynamic_win_info_t)); } else { insert_index = 0; - } - + }*/ +/* ret = mem_map(&base, len, &(module->local_dynamic_win_info[insert_index].memh), module, MPI_WIN_FLAVOR_CREATE); if (ret != OMPI_SUCCESS) { @@ -775,9 +511,12 @@ int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) { ucp_rkey_buffer_release(rkey_buffer); return ret; + */ + return OMPI_SUCCESS; } int ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base) { +/* ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; int insert, contain; @@ -788,7 +527,7 @@ int ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base) { (uint64_t)base, 1, &insert); assert(contain >= 0 && (uint64_t)contain < module->state.dynamic_win_count); - /* if we can't find region - just exit */ + // if we can't find region - just exit if (contain < 0) { return OMPI_SUCCESS; } @@ -808,11 +547,16 @@ int ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base) { } return OMPI_SUCCESS; + */ + return OMPI_SUCCESS; } int ompi_osc_ucx_free(struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - int i, ret; + int ret; + + DBG_OUT("ompi_osc_ucx_free: start, mem = %p lock flag = %d\n", + (void *)module->mem, (int)module->state.lock); assert(module->global_ops_num == 0); assert(module->lock_count == 0); @@ -820,19 +564,23 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) { OBJ_DESTRUCT(&module->outstanding_locks); OBJ_DESTRUCT(&module->pending_posts); - while (module->state.lock != TARGET_LOCK_UNLOCKED) { - /* not sure if this is required */ - ucp_worker_progress(mca_osc_ucx_component.ucp_worker); - } + opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0); - ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker); - if (OMPI_SUCCESS != ret) { - OSC_UCX_VERBOSE(1, "opal_common_ucx_worker_flush failed: %d", ret); + DBG_OUT("ompi_osc_ucx_free: after mem_flush, mem = %p lock flag = %d\n", + (void *)module->mem, (int)module->state.lock); + + /* + while (module->state.lock != TARGET_LOCK_UNLOCKED) { + ucp_worker_progress(mca_osc_ucx_component.wpool->recv_worker); } + */ ret = module->comm->c_coll->coll_barrier(module->comm, module->comm->c_coll->coll_barrier_module); + DBG_OUT("ompi_osc_ucx_free: after barrier, mem = %p\n", (void *)module->mem); + +/* for (i = 0; i < ompi_comm_size(module->comm); i++) { if ((module->win_info_array[i]).rkey_init == true) { ucp_rkey_destroy((module->win_info_array[i]).rkey); @@ -843,19 +591,24 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) { free(module->win_info_array); free(module->state_info_array); - free(module->per_target_ops_nums); - if ((module->flavor == MPI_WIN_FLAVOR_ALLOCATE || module->flavor == MPI_WIN_FLAVOR_CREATE) && module->size > 0) { ucp_mem_unmap(mca_osc_ucx_component.ucp_context, module->memh); } ucp_mem_unmap(mca_osc_ucx_component.ucp_context, module->state_memh); + + return ret; + */ + free(module->per_target_ops_nums); + + opal_common_ucx_wpool_finalize(mca_osc_ucx_component.wpool); + if (module->disp_units) free(module->disp_units); ompi_comm_free(&module->comm); free(module); ompi_osc_ucx_unregister_progress(); - return ret; + return OMPI_SUCCESS; } diff --git a/ompi/mca/osc/ucx/osc_ucx_passive_target.c b/ompi/mca/osc/ucx/osc_ucx_passive_target.c index 
3a7ad3e9e24..2c658a38cd4 100644 --- a/ompi/mca/osc/ucx/osc_ucx_passive_target.c +++ b/ompi/mca/osc/ucx/osc_ucx_passive_target.c @@ -20,88 +20,75 @@ OBJ_CLASS_INSTANCE(ompi_osc_ucx_lock_t, opal_object_t, NULL, NULL); static inline int start_shared(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = -1; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - ucp_rkey_h rkey = (module->state_info_array)[target].rkey; - uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_LOCK_OFFSET; - ucs_status_t status; - int ret; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + int ret = OMPI_SUCCESS; while (true) { - ret = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_FADD, 1, - &result_value, sizeof(result_value), - remote_addr, rkey, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_fetch(module->state_mem, UCP_ATOMIC_FETCH_OP_FADD, 1, + target, &result_value, sizeof(result_value), + remote_addr); if (OMPI_SUCCESS != ret) { return ret; } + + DBG_OUT("start_shared: after fadd, result_value = %d", (int)result_value); + assert((int64_t)result_value >= 0); if (result_value >= TARGET_LOCK_EXCLUSIVE) { - status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, (-1), sizeof(uint64_t), - remote_addr, rkey); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_atomic_add64 failed: %d", status); - return OMPI_ERROR; + ret = opal_common_ucx_mem_post(module->state_mem, + UCP_ATOMIC_POST_OP_ADD, (-1), target, + sizeof(uint64_t), remote_addr); + if (OMPI_SUCCESS != ret) { + return ret; } } else { break; } } - return OMPI_SUCCESS; + return ret; } static inline int end_shared(ompi_osc_ucx_module_t *module, int target) { - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - ucp_rkey_h rkey = (module->state_info_array)[target].rkey; - uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_LOCK_OFFSET; - ucs_status_t status; - - status = ucp_atomic_post(ep, UCP_ATOMIC_POST_OP_ADD, (-1), sizeof(uint64_t), - remote_addr, rkey); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_atomic_post(OP_ADD) failed: %d", status); - return OMPI_ERROR; - } - - return OMPI_SUCCESS; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + return opal_common_ucx_mem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD, + (-1), target, sizeof(uint64_t), remote_addr); } static inline int start_exclusive(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = -1; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - ucp_rkey_h rkey = (module->state_info_array)[target].rkey; - uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_LOCK_OFFSET; - ucs_status_t status; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + int ret = OMPI_SUCCESS; while (result_value != TARGET_LOCK_UNLOCKED) { - status = opal_common_ucx_atomic_cswap(ep, TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, - &result_value, sizeof(result_value), - remote_addr, rkey, - mca_osc_ucx_component.ucp_worker); - if (status != UCS_OK) { - return OMPI_ERROR; + ret = opal_common_ucx_mem_cmpswp(module->state_mem, + TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, + target, &result_value, sizeof(result_value), + remote_addr); + if (OMPI_SUCCESS != ret) { + return ret; } } - return OMPI_SUCCESS; + return ret; } static inline int end_exclusive(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = 0; - ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); - ucp_rkey_h 
rkey = (module->state_info_array)[target].rkey; - uint64_t remote_addr = (module->state_info_array)[target].addr + OSC_UCX_STATE_LOCK_OFFSET; - int ret; - - ret = opal_common_ucx_atomic_fetch(ep, UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, - &result_value, sizeof(result_value), - remote_addr, rkey, mca_osc_ucx_component.ucp_worker); + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + int ret = OMPI_SUCCESS; + + ret = opal_common_ucx_mem_fetch(module->state_mem, + UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, + target, &result_value, sizeof(result_value), + remote_addr); if (OMPI_SUCCESS != ret) { return ret; } assert(result_value >= TARGET_LOCK_EXCLUSIVE); - return OMPI_SUCCESS; + return ret; } int ompi_osc_ucx_lock(int lock_type, int target, int assert, struct ompi_win_t *win) { @@ -158,7 +145,6 @@ int ompi_osc_ucx_unlock(int target, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module; ompi_osc_ucx_lock_t *lock = NULL; int ret = OMPI_SUCCESS; - ucp_ep_h ep; if (module->epoch_type.access != PASSIVE_EPOCH) { return OMPI_ERR_RMA_SYNC; @@ -172,8 +158,7 @@ int ompi_osc_ucx_unlock(int target, struct ompi_win_t *win) { opal_hash_table_remove_value_uint32(&module->outstanding_locks, (uint32_t)target); - ep = OSC_UCX_GET_EP(module->comm, target); - ret = opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -238,17 +223,21 @@ int ompi_osc_ucx_unlock_all(struct ompi_win_t *win) { int comm_size = ompi_comm_size(module->comm); int ret = OMPI_SUCCESS; + DBG_OUT("ompi_osc_ucx_unlock_all: start, mem = %p\n", (void *)module->mem); + if (module->epoch_type.access != PASSIVE_ALL_EPOCH) { return OMPI_ERR_RMA_SYNC; } assert(module->lock_count == 0); - ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0); if (ret != OMPI_SUCCESS) { return ret; } + DBG_OUT("ompi_osc_ucx_unlock_all: after flush, mem = %p\n", (void *)module->mem); + module->global_ops_num = 0; memset(module->per_target_ops_nums, 0, sizeof(int) * comm_size); @@ -261,12 +250,14 @@ int ompi_osc_ucx_unlock_all(struct ompi_win_t *win) { module->epoch_type.access = NONE_EPOCH; + DBG_OUT("ompi_osc_ucx_unlock_all: end, mem = %p\n", (void *)module->mem); + return ret; } int ompi_osc_ucx_sync(struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module; - ucs_status_t status; + int ret = OMPI_SUCCESS; if (module->epoch_type.access != PASSIVE_EPOCH && module->epoch_type.access != PASSIVE_ALL_EPOCH) { @@ -275,27 +266,24 @@ int ompi_osc_ucx_sync(struct ompi_win_t *win) { opal_atomic_mb(); - status = ucp_worker_fence(mca_osc_ucx_component.ucp_worker); - if (status != UCS_OK) { - OSC_UCX_VERBOSE(1, "ucp_worker_fence failed: %d", status); - return OMPI_ERROR; + ret = opal_common_ucx_mem_fence(module->mem); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); } - return OMPI_SUCCESS; + return ret; } int ompi_osc_ucx_flush(int target, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ucp_ep_h ep; - int ret; + int ret = OMPI_SUCCESS; if (module->epoch_type.access != PASSIVE_EPOCH && module->epoch_type.access != PASSIVE_ALL_EPOCH) { return OMPI_ERR_RMA_SYNC; } - ep = OSC_UCX_GET_EP(module->comm, target); - ret = 
opal_common_ucx_ep_flush(ep, mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_EP, target); if (ret != OMPI_SUCCESS) { return ret; } @@ -308,14 +296,14 @@ int ompi_osc_ucx_flush(int target, struct ompi_win_t *win) { int ompi_osc_ucx_flush_all(struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module; - int ret; + int ret = OMPI_SUCCESS; if (module->epoch_type.access != PASSIVE_EPOCH && module->epoch_type.access != PASSIVE_ALL_EPOCH) { return OMPI_ERR_RMA_SYNC; } - ret = opal_common_ucx_worker_flush(mca_osc_ucx_component.ucp_worker); + ret = opal_common_ucx_mem_flush(module->mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0); if (ret != OMPI_SUCCESS) { return ret; } diff --git a/ompi/mca/osc/ucx/osc_ucx_request.h b/ompi/mca/osc/ucx/osc_ucx_request.h index b33bc54c2de..86934ae30eb 100644 --- a/ompi/mca/osc/ucx/osc_ucx_request.h +++ b/ompi/mca/osc/ucx/osc_ucx_request.h @@ -32,9 +32,8 @@ typedef struct ompi_osc_ucx_internal_request { do { \ item = opal_free_list_get(&mca_osc_ucx_component.requests); \ if (item == NULL) { \ - if (mca_osc_ucx_component.ucp_worker != NULL && \ - mca_osc_ucx_component.num_incomplete_req_ops > 0) { \ - ucp_worker_progress(mca_osc_ucx_component.ucp_worker); \ + if (mca_osc_ucx_component.num_incomplete_req_ops > 0) { \ + opal_common_ucx_workers_progress(mca_osc_ucx_component.wpool); \ } \ } \ } while (item == NULL); \ diff --git a/opal/mca/common/ucx/common_ucx.c b/opal/mca/common/ucx/common_ucx.c index 84e26b221d3..b942b259ad5 100644 --- a/opal/mca/common/ucx/common_ucx.c +++ b/opal/mca/common/ucx/common_ucx.c @@ -17,6 +17,95 @@ #include + +/***********************************************************************/ + +typedef struct { + opal_mutex_t mutex; + ucp_worker_h worker; + ucp_ep_h *endpoints; + size_t comm_size; +} _worker_info_t; + +typedef struct { + int ctx_id; + // TODO: make sure that this is being set by external thread + int is_freed; + opal_common_ucx_ctx_t *gctx; + _worker_info_t *winfo; +} _tlocal_ctx_t; + +typedef struct { + _worker_info_t *worker; + ucp_rkey_h *rkeys; +} _mem_info_t; + +typedef struct { + int mem_id; + int is_freed; + opal_common_ucx_mem_t *gmem; + _mem_info_t *mem; +} _tlocal_mem_t; + +typedef struct { + opal_list_item_t super; + _worker_info_t *ptr; +} _idle_list_item_t; + +OBJ_CLASS_DECLARATION(_idle_list_item_t); +OBJ_CLASS_INSTANCE(_idle_list_item_t, opal_list_item_t, NULL, NULL); + +typedef struct { + opal_list_item_t super; + _tlocal_ctx_t *ptr; +} _worker_list_item_t; + +OBJ_CLASS_DECLARATION(_worker_list_item_t); +OBJ_CLASS_INSTANCE(_worker_list_item_t, opal_list_item_t, NULL, NULL); + +typedef struct { + opal_list_item_t super; + _tlocal_mem_t *ptr; +} _mem_region_list_item_t; + +OBJ_CLASS_DECLARATION(_mem_region_list_item_t); +OBJ_CLASS_INSTANCE(_mem_region_list_item_t, opal_list_item_t, NULL, NULL); + +/* thread-local table */ +typedef struct { + opal_list_item_t super; + opal_common_ucx_wpool_t *wpool; + _tlocal_ctx_t **ctx_tbl; + size_t ctx_tbl_size; + _tlocal_mem_t **mem_tbl; + size_t mem_tbl_size; +} _tlocal_table_t; + +OBJ_CLASS_DECLARATION(_tlocal_table_t); +OBJ_CLASS_INSTANCE(_tlocal_table_t, opal_list_item_t, NULL, NULL); + +static pthread_key_t _tlocal_key = {0}; + + +static int _tlocal_tls_ctxtbl_extend(_tlocal_table_t *tbl, size_t append); +static int _tlocal_tls_memtbl_extend(_tlocal_table_t *tbl, size_t append); +static _tlocal_table_t* _common_ucx_tls_init(opal_common_ucx_wpool_t *wpool); +static void 
_common_ucx_tls_cleanup(_tlocal_table_t *tls); +static inline _tlocal_ctx_t *_tlocal_ctx_search(_tlocal_table_t *tls, int ctx_id); +static int _tlocal_ctx_record_cleanup(_tlocal_ctx_t *ctx_rec); +static _tlocal_ctx_t *_tlocal_add_ctx(_tlocal_table_t *tls, opal_common_ucx_ctx_t *ctx); +static int _tlocal_ctx_connect(_tlocal_ctx_t *ctx, int target); +static int _tlocal_ctx_release(opal_common_ucx_ctx_t *ctx); +static inline _tlocal_mem_t *_tlocal_search_mem(_tlocal_table_t *tls, int mem_id); +static _tlocal_mem_t *_tlocal_add_mem(_tlocal_table_t *tls, opal_common_ucx_mem_t *mem); +static int _tlocal_mem_create_rkey(_tlocal_mem_t *mem_rec, ucp_ep_h ep, int target); +// TOD: Return the error from it +static void _tlocal_mem_record_cleanup(_tlocal_mem_t *mem_rec); + + +__thread FILE *tls_pf = NULL; +__thread int initialized = 0; + /***********************************************************************/ extern mca_base_framework_t opal_memory_base_framework; @@ -218,3 +307,1311 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, s return OPAL_SUCCESS; } +/***********************************************************************/ + +static inline void _cleanup_tlocal(void *arg) +{ + _tlocal_table_t *item = NULL, *next; + _tlocal_table_t *tls = (_tlocal_table_t *)arg; + opal_common_ucx_wpool_t *wpool = NULL; + + DBG_OUT("_cleanup_tlocal: start\n"); + + if (NULL == tls) { + return; + } + + wpool = tls->wpool; + /* 1. Remove us from tls_list */ + tls->wpool = wpool; + opal_mutex_lock(&wpool->mutex); + OPAL_LIST_FOREACH_SAFE(item, next, &wpool->tls_list, _tlocal_table_t) { + if (item == tls) { + opal_list_remove_item(&wpool->tls_list, &item->super); + break; + } + } + opal_mutex_unlock(&wpool->mutex); + _common_ucx_tls_cleanup(tls); +} + +static +ucp_worker_h _create_ctx_worker(opal_common_ucx_wpool_t *wpool) +{ + ucp_worker_params_t worker_params; + ucp_worker_h worker; + ucs_status_t status; + + memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; + status = ucp_worker_create(wpool->ucp_ctx, &worker_params, &worker); + if (UCS_OK != status) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_worker_create failed: %d", status); + return NULL; + } + + DBG_OUT("_create_ctx_worker: worker = %p\n", (void *)worker); + + return worker; +} + +static +int _wpool_add_to_idle(opal_common_ucx_wpool_t *wpool, _worker_info_t *winfo) +{ + _idle_list_item_t *item; + + if(winfo->comm_size != 0) { + size_t i; + for (i = 0; i < winfo->comm_size; i++) { + if (NULL != winfo->endpoints[i]){ + ucp_ep_destroy(winfo->endpoints[i]); + } + } + free(winfo->endpoints); + winfo->endpoints = NULL; + winfo->comm_size = 0; + } + + item = OBJ_NEW(_idle_list_item_t); + if (NULL == item) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + item->ptr = winfo; + + opal_mutex_lock(&wpool->mutex); + opal_list_append(&wpool->idle_workers, &item->super); + opal_mutex_unlock(&wpool->mutex); + + DBG_OUT("_wpool_add_to_idle: wpool = %p winfo = %p\n", (void *)wpool, (void *)winfo); + return OPAL_SUCCESS; +} + +static +_worker_info_t* _wpool_remove_from_idle(opal_common_ucx_wpool_t *wpool) +{ + _worker_info_t *wkr = NULL; + _idle_list_item_t *item = NULL; + + opal_mutex_lock(&wpool->mutex); + if (!opal_list_is_empty(&wpool->idle_workers)) { + item = (_idle_list_item_t *)opal_list_get_first(&wpool->idle_workers); + opal_list_remove_item(&wpool->idle_workers, &item->super); + } + opal_mutex_unlock(&wpool->mutex); + + if (item != 
NULL) { + wkr = item->ptr; + OBJ_RELEASE(item); + } + + DBG_OUT("_wpool_remove_from_idle: wpool = %p\n", (void *)wpool); + return wkr; +} + +OPAL_DECLSPEC +opal_common_ucx_wpool_t * opal_common_ucx_wpool_allocate(void) +{ + opal_common_ucx_wpool_t *ptr = calloc(1, sizeof(opal_common_ucx_wpool_t)); + ptr->refcnt = 0; + + DBG_OUT("opal_common_ucx_wpool_allocate: wpool = %p\n", (void *)ptr); + return ptr; +} + +OPAL_DECLSPEC +void opal_common_ucx_wpool_free(opal_common_ucx_wpool_t *wpool) +{ + assert(wpool->refcnt == 0); + + DBG_OUT("opal_common_ucx_wpool_free: wpool = %p\n", (void *)wpool); + + free(wpool); +} + +OPAL_DECLSPEC +int opal_common_ucx_wpool_init(opal_common_ucx_wpool_t *wpool, + int proc_world_size, + ucp_request_init_callback_t req_init_ptr, + size_t req_size, bool enable_mt) +{ + ucp_config_t *config = NULL; + ucp_params_t context_params; + _worker_info_t *wkr; + ucs_status_t status; + int rc = OPAL_SUCCESS; + + if (wpool->refcnt > 0) { + wpool->refcnt++; + return rc; + } + + wpool->refcnt++; + wpool->cur_ctxid = wpool->cur_memid = 0; + OBJ_CONSTRUCT(&wpool->mutex, opal_mutex_t); + OBJ_CONSTRUCT(&wpool->tls_list, opal_list_t); + + status = ucp_config_read("MPI", NULL, &config); + if (UCS_OK != status) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_config_read failed: %d", status); + return OPAL_ERROR; + } + + /* initialize UCP context */ + memset(&context_params, 0, sizeof(context_params)); + context_params.field_mask = UCP_PARAM_FIELD_FEATURES | + UCP_PARAM_FIELD_MT_WORKERS_SHARED | + UCP_PARAM_FIELD_ESTIMATED_NUM_EPS | + UCP_PARAM_FIELD_REQUEST_INIT | + UCP_PARAM_FIELD_REQUEST_SIZE; + context_params.features = UCP_FEATURE_RMA | UCP_FEATURE_AMO32 | UCP_FEATURE_AMO64; + context_params.mt_workers_shared = (enable_mt ? 1 : 0); + context_params.estimated_num_eps = proc_world_size; + context_params.request_init = req_init_ptr; + context_params.request_size = req_size; + + status = ucp_init(&context_params, config, &wpool->ucp_ctx); + ucp_config_release(config); + if (UCS_OK != status) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_init failed: %d", status); + rc = OPAL_ERROR; + goto err_ucp_init; + } + + /* create recv worker and add to idle pool */ + OBJ_CONSTRUCT(&wpool->idle_workers, opal_list_t); + wpool->recv_worker = _create_ctx_worker(wpool); + if (wpool->recv_worker == NULL) { + MCA_COMMON_UCX_VERBOSE(1, "_create_ctx_worker failed"); + rc = OPAL_ERROR; + goto err_worker_create; + } + + wkr = calloc(1, sizeof(_worker_info_t)); + OBJ_CONSTRUCT(&wkr->mutex, opal_mutex_t); + + wkr->worker = wpool->recv_worker; + wkr->endpoints = NULL; + wkr->comm_size = 0; + + status = ucp_worker_get_address(wpool->recv_worker, + &wpool->recv_waddr, &wpool->recv_waddr_len); + if (status != UCS_OK) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_worker_get_address failed: %d", status); + rc = OPAL_ERROR; + goto err_get_addr; + } + + rc = _wpool_add_to_idle(wpool, wkr); + if (rc) { + goto err_wpool_add; + } + + pthread_key_create(&_tlocal_key, _cleanup_tlocal); + + DBG_OUT("opal_common_ucx_wpool_init: wpool = %p\n", (void *)wpool); + return rc; + +err_wpool_add: + free(wpool->recv_waddr); +err_get_addr: + if (NULL != wpool->recv_worker) { + ucp_worker_destroy(wpool->recv_worker); + } + err_worker_create: + ucp_cleanup(wpool->ucp_ctx); + err_ucp_init: + return rc; +} + +OPAL_DECLSPEC +void opal_common_ucx_wpool_finalize(opal_common_ucx_wpool_t *wpool) +{ + _tlocal_table_t *tls_item = NULL, *tls_next; + + DBG_OUT("opal_common_ucx_wpool_finalize(start): wpool = %p\n", (void *)wpool); + + wpool->refcnt--; + if (wpool->refcnt > 0) { + 
DBG_OUT("opal_common_ucx_wpool_finalize: wpool = %p\n", (void *)wpool); + return; + } + + pthread_key_delete(_tlocal_key); + + opal_mutex_lock(&wpool->mutex); + OPAL_LIST_FOREACH_SAFE(tls_item, tls_next, &wpool->tls_list, _tlocal_table_t) { + opal_list_remove_item(&wpool->tls_list, &tls_item->super); + opal_mutex_unlock(&wpool->mutex); + + _common_ucx_tls_cleanup(tls_item); + + opal_mutex_lock(&wpool->mutex); + DBG_OUT("opal_common_ucx_wpool_finalize: cleanup wpool = %p\n", (void *)wpool); + } + opal_mutex_unlock(&wpool->mutex); + + /* Release the address here. recv worker will be released + * below along with other idle workers */ + ucp_worker_release_address(wpool->recv_worker, wpool->recv_waddr); + + opal_mutex_lock(&wpool->mutex); + /* Go over the list, free idle list items */ + if (!opal_list_is_empty(&wpool->idle_workers)) { + _idle_list_item_t *item, *next; + OPAL_LIST_FOREACH_SAFE(item, next, &wpool->idle_workers, _idle_list_item_t) { + _worker_info_t *curr_worker; + opal_list_remove_item(&wpool->idle_workers, &item->super); + curr_worker = item->ptr; + OBJ_DESTRUCT(&curr_worker->mutex); + ucp_worker_destroy(curr_worker->worker); + free(curr_worker); + OBJ_RELEASE(item); + } + } + opal_mutex_unlock(&wpool->mutex); + + OBJ_DESTRUCT(&wpool->idle_workers); + OBJ_DESTRUCT(&wpool->tls_list); + OBJ_DESTRUCT(&wpool->mutex); + ucp_cleanup(wpool->ucp_ctx); + DBG_OUT("opal_common_ucx_wpool_finalize: wpool = %p\n", (void *)wpool); + return; +} + +OPAL_DECLSPEC +int opal_common_ucx_ctx_create(opal_common_ucx_wpool_t *wpool, int comm_size, + opal_common_ucx_exchange_func_t exchange_func, + void *exchange_metadata, + opal_common_ucx_ctx_t **ctx_ptr) +{ + opal_common_ucx_ctx_t *ctx = calloc(1, sizeof(*ctx)); + int ret = OPAL_SUCCESS; + + ctx->ctx_id = OPAL_ATOMIC_ADD_FETCH32(&wpool->cur_ctxid, 1); + DBG_OUT("ctx_create: ctx_id = %d\n", (int)ctx->ctx_id); + + OBJ_CONSTRUCT(&ctx->mutex, opal_mutex_t); + OBJ_CONSTRUCT(&ctx->workers, opal_list_t); + ctx->wpool = wpool; + ctx->comm_size = comm_size; + + ctx->recv_worker_addrs = NULL; + ctx->recv_worker_displs = NULL; + ret = exchange_func(wpool->recv_waddr, wpool->recv_waddr_len, + &ctx->recv_worker_addrs, + &ctx->recv_worker_displs, exchange_metadata); + if (ret != OPAL_SUCCESS) { + goto error; + } + + (*ctx_ptr) = ctx; + DBG_OUT("opal_common_ucx_ctx_create: wpool = %p, (*ctx_ptr) = %p\n", (void *)wpool, (void *)(*ctx_ptr)); + return ret; + + error: + OBJ_DESTRUCT(&ctx->mutex); + OBJ_DESTRUCT(&ctx->workers); + free(ctx); + (*ctx_ptr) = NULL; + return ret; +} + +static void _common_ucx_ctx_free(opal_common_ucx_ctx_t *ctx) +{ + free(ctx->recv_worker_addrs); + free(ctx->recv_worker_displs); + OBJ_DESTRUCT(&ctx->mutex); + OBJ_DESTRUCT(&ctx->workers); + DBG_OUT("_common_ucx_ctx_free: ctx = %p\n", (void *)ctx); + free(ctx); +} + +OPAL_DECLSPEC void +opal_common_ucx_ctx_release(opal_common_ucx_ctx_t *ctx) +{ + // TODO: implement + DBG_OUT("opal_common_ucx_ctx_release: ctx = %p\n", (void *)ctx); + _tlocal_ctx_release(ctx); +} + +static int +_common_ucx_ctx_append(opal_common_ucx_ctx_t *ctx, _tlocal_ctx_t *ctx_rec) +{ + _worker_list_item_t *item = OBJ_NEW(_worker_list_item_t); + if (NULL == item) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + item->ptr = ctx_rec; + opal_mutex_lock(&ctx->mutex); + opal_list_append(&ctx->workers, &item->super); + opal_mutex_unlock(&ctx->mutex); + DBG_OUT("_common_ucx_ctx_append: ctx = %p, ctx_rec = %p\n", (void *)ctx, (void *)ctx_rec); + return OPAL_SUCCESS; +} + +static void +_common_ucx_ctx_remove(opal_common_ucx_ctx_t *ctx, 
_tlocal_ctx_t *ctx_rec) +{ + int can_free = 0; + _worker_list_item_t *item = NULL, *next; + + opal_mutex_lock(&ctx->mutex); + OPAL_LIST_FOREACH_SAFE(item, next, &ctx->workers, _worker_list_item_t) { + if (ctx_rec == item->ptr) { + opal_list_remove_item(&ctx->workers, &item->super); + OBJ_RELEASE(item); + break; + } + } + if (0 == opal_list_get_size(&ctx->workers)) { + can_free = 1; + } + opal_mutex_unlock(&ctx->mutex); + + if (can_free) { + /* All references to this data structure are removed + * we can safely release communication context structure */ + _common_ucx_ctx_free(ctx); + } + DBG_OUT("_common_ucx_ctx_remove: ctx = %p, ctx_rec = %p\n", (void *)ctx, (void *)ctx_rec); + return; +} + +static int _comm_ucx_mem_map(opal_common_ucx_wpool_t *wpool, + void **base, size_t size, ucp_mem_h *memh_ptr, + opal_common_ucx_mem_type_t mem_type) +{ + ucp_mem_map_params_t mem_params; + ucp_mem_attr_t mem_attrs; + ucs_status_t status; + int ret = OPAL_SUCCESS; + + memset(&mem_params, 0, sizeof(ucp_mem_map_params_t)); + mem_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_FLAGS; + mem_params.length = size; + if (mem_type == OPAL_COMMON_UCX_MEM_ALLOCATE_MAP) { + mem_params.address = NULL; + mem_params.flags = UCP_MEM_MAP_ALLOCATE; + } else { + mem_params.address = (*base); + } + + status = ucp_mem_map(wpool->ucp_ctx, &mem_params, memh_ptr); + if (status != UCS_OK) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_mem_map failed: %d", status); + ret = OPAL_ERROR; + return ret; + } + DBG_OUT("_comm_ucx_mem_map(after ucp_mem_map): memh = %p\n", (void *)(*memh_ptr)); + + mem_attrs.field_mask = UCP_MEM_ATTR_FIELD_ADDRESS | UCP_MEM_ATTR_FIELD_LENGTH; + status = ucp_mem_query((*memh_ptr), &mem_attrs); + if (status != UCS_OK) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_mem_query failed: %d", status); + ret = OPAL_ERROR; + goto error; + } + DBG_OUT("_comm_ucx_mem_map(after ucp_mem_query): memh = %p\n", (void *)(*memh_ptr)); + + assert(mem_attrs.length >= size); + if (mem_type != OPAL_COMMON_UCX_MEM_ALLOCATE_MAP) { + assert(mem_attrs.address == (*base)); + } else { + (*base) = mem_attrs.address; + } + + DBG_OUT("_comm_ucx_mem_map(end): wpool = %p, addr = %p size = %d memh = %p\n", + (void *)wpool, (void *)(*base), (int)size, (void *)(*memh_ptr)); + return ret; + error: + ucp_mem_unmap(wpool->ucp_ctx, (*memh_ptr)); + return ret; +} + + +OPAL_DECLSPEC +int opal_common_ucx_mem_create(opal_common_ucx_ctx_t *ctx, int comm_size, + void **mem_base, size_t mem_size, + opal_common_ucx_mem_type_t mem_type, + opal_common_ucx_exchange_func_t exchange_func, + void *exchange_metadata, + opal_common_ucx_mem_t **mem_ptr) +{ + opal_common_ucx_mem_t *mem = calloc(1, sizeof(*mem)); + void *rkey_addr = NULL; + size_t rkey_addr_len; + ucs_status_t status; + int ret = OPAL_SUCCESS; + + mem->mem_id = OPAL_ATOMIC_ADD_FETCH32(&ctx->wpool->cur_memid, 1); + + DBG_OUT("mem_create: mem_id = %d\n", (int)mem->mem_id); + + OBJ_CONSTRUCT(&mem->mutex, opal_mutex_t); + OBJ_CONSTRUCT(&mem->registrations, opal_list_t); + mem->ctx = ctx; + mem->mem_addrs = NULL; + mem->mem_displs = NULL; + + ret = _comm_ucx_mem_map(ctx->wpool, mem_base, mem_size, &mem->memh, mem_type); + if (ret != OPAL_SUCCESS) { + MCA_COMMON_UCX_VERBOSE(1, "_comm_ucx_mem_map failed: %d", ret); + goto error_mem_map; + } + DBG_OUT("opal_common_ucx_mem_create(after _comm_ucx_mem_map): base = %p, memh = %p\n", + (void *)(*mem_base), (void *)(mem->memh)); + + status = ucp_rkey_pack(ctx->wpool->ucp_ctx, mem->memh, + &rkey_addr, &rkey_addr_len); 
+ if (status != UCS_OK) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_rkey_pack failed: %d", status); + ret = OPAL_ERROR; + goto error_rkey_pack; + } + DBG_OUT("opal_common_ucx_mem_create(after ucp_rkey_pack): rkey_addr = %p, rkey_addr_len = %d\n", + (void *)rkey_addr, (int)rkey_addr_len); + + ret = exchange_func(rkey_addr, rkey_addr_len, + &mem->mem_addrs, &mem->mem_displs, exchange_metadata); + DBG_OUT("opal_common_ucx_mem_create(after exchange_func): rkey_addr = %p, rkey_addr_len = %d mem_addrs = %p mem_displs = %p\n", + (void *)rkey_addr, (int)rkey_addr_len, (void *)mem->mem_addrs, (void *)mem->mem_displs); + + ucp_rkey_buffer_release(rkey_addr); + if (ret != OPAL_SUCCESS) { + goto error_rkey_pack; + } + + (*mem_ptr) = mem; + + DBG_OUT("opal_common_ucx_mem_create(end): mem = %p\n", (void *)mem); + return ret; + + error_rkey_pack: + ucp_mem_unmap(ctx->wpool->ucp_ctx, mem->memh); + error_mem_map: + OBJ_DESTRUCT(&mem->mutex); + OBJ_DESTRUCT(&mem->registrations); + free(mem); + (*mem_ptr) = NULL; + return ret; +} + +static void _common_ucx_mem_free(opal_common_ucx_mem_t *mem) +{ + free(mem->mem_addrs); + free(mem->mem_displs); + ucp_mem_unmap(mem->ctx->wpool->ucp_ctx, mem->memh); + OBJ_DESTRUCT(&mem->mutex); + OBJ_DESTRUCT(&mem->registrations); + DBG_OUT("_common_ucx_mem_free: mem = %p\n", (void *)mem); + free(mem); +} + +static int +_common_ucx_mem_append(opal_common_ucx_mem_t *mem, + _tlocal_mem_t *mem_rec) +{ + _mem_region_list_item_t *item = OBJ_NEW(_mem_region_list_item_t); + if (NULL == item) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + item->ptr = mem_rec; + opal_mutex_lock(&mem->mutex); + opal_list_append(&mem->registrations, &item->super); + opal_mutex_unlock(&mem->mutex); + DBG_OUT("_common_ucx_mem_append: mem = %p, mem_rec = %p\n", (void *)mem, (void *)mem_rec); + return OPAL_SUCCESS; +} + +static void +_common_ucx_mem_remove(opal_common_ucx_mem_t *mem, _tlocal_mem_t *mem_rec) +{ + int can_free = 0; + _mem_region_list_item_t *item = NULL, *next; + + opal_mutex_lock(&mem->mutex); + OPAL_LIST_FOREACH_SAFE(item, next, &mem->registrations, _mem_region_list_item_t) { + if (mem_rec == item->ptr) { + opal_list_remove_item(&mem->registrations, &item->super); + OBJ_RELEASE(item); + break; + } + } + if (0 == opal_list_get_size(&mem->registrations)) { + can_free = 1; + } + opal_mutex_unlock(&mem->mutex); + + if (can_free) { + /* All references to this data structure are removed + * we can safely release communication context structure */ + _common_ucx_mem_free(mem); + } + DBG_OUT("_common_ucx_mem_remove(end): mem = %p mem_rec = %p\n", (void *)mem, (void *)mem_rec); + return; +} + + +// TODO: don't want to inline this function +static _tlocal_table_t* _common_ucx_tls_init(opal_common_ucx_wpool_t *wpool) +{ + _tlocal_table_t *tls = OBJ_NEW(_tlocal_table_t); + + if (tls == NULL) { + // return OPAL_ERR_OUT_OF_RESOURCE + return NULL; + } + + tls->ctx_tbl = NULL; + tls->ctx_tbl_size = 0; + tls->mem_tbl = NULL; + tls->mem_tbl_size = 0; + + /* Add this TLS to the global wpool structure for future + * cleanup purposes */ + tls->wpool = wpool; + opal_mutex_lock(&wpool->mutex); + opal_list_append(&wpool->tls_list, &tls->super); + opal_mutex_unlock(&wpool->mutex); + + if(_tlocal_tls_ctxtbl_extend(tls, 4)){ + DBG_OUT("_tlocal_tls_ctxtbl_extend failed\n"); + // TODO: handle error + } + if(_tlocal_tls_memtbl_extend(tls, 4)) { + DBG_OUT("_tlocal_tls_memtbl_extend failed\n"); + // TODO: handle error + } + + pthread_setspecific(_tlocal_key, tls); + DBG_OUT("_common_ucx_tls_init(end): wpool = %p\n", (void *)wpool); + 
return tls; +} + +static inline _tlocal_table_t * +_tlocal_get_tls(opal_common_ucx_wpool_t *wpool){ + _tlocal_table_t *tls = pthread_getspecific(_tlocal_key); + if( OPAL_UNLIKELY(NULL == tls) ) { + tls = _common_ucx_tls_init(wpool); + } + DBG_OUT("_tlocal_get_tls(end): wpool = %p tls = %p\n", (void *)wpool, (void *)tls); + return tls; +} + +// TODO: don't want to inline this function +static void _common_ucx_tls_cleanup(_tlocal_table_t *tls) +{ + size_t i, size; + + // Cleanup memory table + size = tls->mem_tbl_size; + for (i = 0; i < size; i++) { + if (!tls->mem_tbl[i]->mem_id){ + continue; + } + _tlocal_mem_record_cleanup(tls->mem_tbl[i]); + free(tls->mem_tbl[i]); + } + + // Cleanup ctx table + size = tls->ctx_tbl_size; + for (i = 0; i < size; i++) { + if (!tls->ctx_tbl[i]->ctx_id){ + continue; + } + _tlocal_ctx_record_cleanup(tls->ctx_tbl[i]); + free(tls->ctx_tbl[i]); + } + + pthread_setspecific(_tlocal_key, NULL); + DBG_OUT("_common_ucx_tls_cleanup(end): tls = %p\n", (void *)tls); + + OBJ_RELEASE(tls); + + return; +} + +static int +_tlocal_tls_get_worker(_tlocal_table_t *tls, _worker_info_t **_winfo) +{ + _worker_info_t *winfo; + *_winfo = NULL; + winfo = _wpool_remove_from_idle(tls->wpool); + if (!winfo) { + winfo = calloc(1, sizeof(*winfo)); + if (!winfo) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + OBJ_CONSTRUCT(&winfo->mutex, opal_mutex_t); + winfo->worker = _create_ctx_worker(tls->wpool); + winfo->endpoints = NULL; + winfo->comm_size = 0; + } + *_winfo = winfo; + DBG_OUT("_tlocal_tls_get_worker(end): tls = %p winfo = %p\n", (void *)tls, (void *)winfo); + + return OPAL_SUCCESS; +} + +static int +_tlocal_tls_ctxtbl_extend(_tlocal_table_t *tbl, size_t append) +{ + size_t i; + size_t newsize = (tbl->ctx_tbl_size + append); + tbl->ctx_tbl = realloc(tbl->ctx_tbl, newsize * sizeof(*tbl->ctx_tbl)); + for (i = tbl->ctx_tbl_size; i < newsize; i++) { + tbl->ctx_tbl[i] = calloc(1, sizeof(*tbl->ctx_tbl[i])); + if (NULL == tbl->ctx_tbl[i]) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + + } + tbl->ctx_tbl_size = newsize; + DBG_OUT("_tlocal_tls_ctxtbl_extend(end): tbl = %p\n", (void *)tbl); + return OPAL_SUCCESS; +} +static int +_tlocal_tls_memtbl_extend(_tlocal_table_t *tbl, size_t append) +{ + size_t i; + size_t newsize = (tbl->mem_tbl_size + append); + + tbl->mem_tbl = realloc(tbl->mem_tbl, newsize * sizeof(*tbl->mem_tbl)); + for (i = tbl->mem_tbl_size; i < tbl->mem_tbl_size + append; i++) { + tbl->mem_tbl[i] = calloc(1, sizeof(*tbl->mem_tbl[i])); + if (NULL == tbl->mem_tbl[i]) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + } + tbl->mem_tbl_size = newsize; + DBG_OUT("_tlocal_tls_memtbl_extend(end): tbl = %p\n", (void *)tbl); + return OPAL_SUCCESS; +} + + +static inline _tlocal_ctx_t * +_tlocal_ctx_search(_tlocal_table_t *tls, int ctx_id) +{ + size_t i; + for(i=0; i<tls->ctx_tbl_size; i++) { + if( tls->ctx_tbl[i]->ctx_id == ctx_id){ + return tls->ctx_tbl[i]; + } + } + DBG_OUT("_tlocal_ctx_search: tls = %p ctx_id = %d\n", (void *)tls, ctx_id); + return NULL; +} + +static int +_tlocal_ctx_record_cleanup(_tlocal_ctx_t *ctx_rec) +{ + int rc; + if (0 == ctx_rec->ctx_id) { + return OPAL_SUCCESS; + } + /* Remove myself from the communication context structure + * This may result in context release as we are using + * delayed cleanup */ + _common_ucx_ctx_remove(ctx_rec->gctx, ctx_rec); + + /* Return the worker back to the idle pool. + * This may result in context release as we are using + * delayed cleanup */ + rc = _wpool_add_to_idle(ctx_rec->gctx->wpool, ctx_rec->winfo); + if (rc) { 
+ return rc; + } + memset(ctx_rec, 0, sizeof(*ctx_rec)); + DBG_OUT("_tlocal_cleanup_ctx_record(end): ctx_rec = %p\n", (void *)ctx_rec); + return OPAL_SUCCESS; +} + +// TODO: Don't want to inline this (slow path) +static _tlocal_ctx_t * +_tlocal_add_ctx(_tlocal_table_t *tls, opal_common_ucx_ctx_t *ctx) +{ + size_t i; + int rc; + + /* Try to find available spot in the table */ + for (i=0; i<tls->ctx_tbl_size; i++) { + if (0 == tls->ctx_tbl[i]->ctx_id) { + /* Found clean record */ + break; + } + if (tls->ctx_tbl[i]->is_freed ) { + /* Found dirty record, need to clean first */ + _tlocal_ctx_record_cleanup(tls->ctx_tbl[i]); + break; + } + } + + if( i >= tls->ctx_tbl_size ){ + i = tls->ctx_tbl_size; + rc = _tlocal_tls_ctxtbl_extend(tls, 4); + if (rc) { + //TODO: error out + return NULL; + } + } + + tls->ctx_tbl[i]->ctx_id = ctx->ctx_id; + tls->ctx_tbl[i]->gctx = ctx; + rc = _tlocal_tls_get_worker(tls, &tls->ctx_tbl[i]->winfo); + if (rc) { + //TODO: error out + return NULL; + } + DBG_OUT("_tlocal_add_ctx(after _tlocal_tls_get_worker): tls = %p winfo = %p\n", + (void *)tls, (void *)tls->ctx_tbl[i]->winfo); + tls->ctx_tbl[i]->winfo->endpoints = calloc(ctx->comm_size, sizeof(ucp_ep_h)); + tls->ctx_tbl[i]->winfo->comm_size = ctx->comm_size; + + + /* Make sure that we completed all the data structures before + * placing the item to the list + * NOTE: essentially we don't need this as list append is an + * operation protected by mutex + */ + opal_atomic_wmb(); + + /* add this worker into the context list */ + rc = _common_ucx_ctx_append(ctx, tls->ctx_tbl[i]); + if (rc) { + //TODO: error out + return NULL; + } + DBG_OUT("_tlocal_add_ctx(after _common_ucx_ctx_append): ctx = %p tls->ctx_tbl = %p\n", + (void *)ctx, (void *)tls->ctx_tbl); + + /* All good - return the record */ + return tls->ctx_tbl[i]; +} + +static int _tlocal_ctx_connect(_tlocal_ctx_t *ctx_rec, int target) +{ + ucp_ep_params_t ep_params; + _worker_info_t *winfo = ctx_rec->winfo; + opal_common_ucx_ctx_t *gctx = ctx_rec->gctx; + ucs_status_t status; + int displ; + + memset(&ep_params, 0, sizeof(ucp_ep_params_t)); + ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; + + opal_mutex_lock(&winfo->mutex); + displ = gctx->recv_worker_displs[target]; + ep_params.address = (ucp_address_t *)&(gctx->recv_worker_addrs[displ]); + status = ucp_ep_create(winfo->worker, &ep_params, &winfo->endpoints[target]); + if (status != UCS_OK) { + opal_mutex_unlock(&winfo->mutex); + MCA_COMMON_UCX_VERBOSE(1, "ucp_ep_create failed: %d", status); + return OPAL_ERROR; + } + DBG_OUT("_tlocal_ctx_connect(after ucp_ep_create): worker = %p ep = %p\n", + (void *)winfo->worker, (void *)winfo->endpoints[target]); + opal_mutex_unlock(&winfo->mutex); + return OPAL_SUCCESS; +} + +static int _tlocal_ctx_release(opal_common_ucx_ctx_t *ctx) +{ + _tlocal_table_t * tls = _tlocal_get_tls(ctx->wpool); + _tlocal_ctx_t *ctx_rec = _tlocal_ctx_search(tls, ctx->ctx_id); + int rc = OPAL_SUCCESS; + + if (NULL == ctx_rec) { + /* we haven't participated in this context */ + return OPAL_SUCCESS; + } + + /* May free the ctx structure. 
Do not use it */ + _common_ucx_ctx_remove(ctx, ctx_rec); + DBG_OUT("_tlocal_ctx_release(after _common_ucx_ctx_remove): ctx = %p ctx_rec = %p\n", + (void *)ctx, (void *)ctx_rec); + rc = _wpool_add_to_idle(tls->wpool, ctx_rec->winfo); + DBG_OUT("_tlocal_ctx_release(after _wpool_add_to_idle): wpool = %p winfo = %p\n", + (void *)tls->wpool, (void *)ctx_rec->winfo); + + ctx_rec->ctx_id = 0; + ctx_rec->is_freed = 0; + ctx_rec->gctx = NULL; + ctx_rec->winfo = NULL; + + return rc; +} + +static inline _tlocal_mem_t * +_tlocal_search_mem(_tlocal_table_t *tls, int mem_id) +{ + size_t i; + DBG_OUT("_tlocal_search_mem(begin): tls = %p mem_id = %d\n", + (void *)tls, (int)mem_id); + for(i=0; i<tls->mem_tbl_size; i++) { + if( tls->mem_tbl[i]->mem_id == mem_id){ + return tls->mem_tbl[i]; + } + } + return NULL; +} + + +static void +_tlocal_mem_record_cleanup(_tlocal_mem_t *mem_rec) +{ + size_t i; + DBG_OUT("_tlocal_mem_record_cleanup: record=%p, is_freed = %d\n", + (void *)mem_rec, mem_rec->is_freed); + if (mem_rec->is_freed) { + return; + } + /* Remove myself from the memory context structure + * This may result in context release as we are using + * delayed cleanup */ + _common_ucx_mem_remove(mem_rec->gmem, mem_rec); + DBG_OUT("_tlocal_mem_record_cleanup(_common_ucx_mem_remove): gmem = %p mem_rec = %p\n", + (void *)mem_rec->gmem, (void *)mem_rec); + + for(i = 0; i < mem_rec->gmem->ctx->comm_size; i++) { + if (mem_rec->mem->rkeys[i]) { + ucp_rkey_destroy(mem_rec->mem->rkeys[i]); + DBG_OUT("_tlocal_mem_record_cleanup(after ucp_rkey_destroy): rkey_entry = %p\n", + (void *)mem_rec->mem->rkeys[i]); + } + } + + free(mem_rec->mem->rkeys); + free(mem_rec->mem); + + memset(mem_rec, 0, sizeof(*mem_rec)); +} + + +// TODO: Don't want to inline this (slow path) +static _tlocal_mem_t *_tlocal_add_mem(_tlocal_table_t *tls, + opal_common_ucx_mem_t *mem) +{ + size_t i; + _tlocal_ctx_t *ctx_rec = NULL; + int rc = OPAL_SUCCESS; + + /* Try to find available spot in the table */ + for (i=0; i<tls->mem_tbl_size; i++) { + if (0 == tls->mem_tbl[i]->mem_id) { + /* Found a clear record */ + break; + } + if (tls->mem_tbl[i]->is_freed) { + /* Found a dirty record. 
Need to clean it first */ + _tlocal_mem_record_cleanup(tls->mem_tbl[i]); + DBG_OUT("_tlocal_add_mem(after _tlocal_mem_record_cleanup): tls = %p mem_tbl_entry = %p\n", + (void *)tls, (void *)tls->mem_tbl[i]); + break; + } + } + + if( i >= tls->mem_tbl_size ){ + i = tls->mem_tbl_size; + rc = _tlocal_tls_memtbl_extend(tls, 4); + if (rc != OPAL_SUCCESS) { + //TODO: error out + return NULL; + } + DBG_OUT("_tlocal_add_mem(after _tlocal_tls_memtbl_extend): tls = %p\n", + (void *)tls); + } + tls->mem_tbl[i]->mem_id = mem->mem_id; + tls->mem_tbl[i]->gmem = mem; + tls->mem_tbl[i]->is_freed = 0; + tls->mem_tbl[i]->mem = calloc(1, sizeof(*tls->mem_tbl[i]->mem)); + ctx_rec = _tlocal_ctx_search(tls, mem->ctx->ctx_id); + if (NULL == ctx_rec) { + // TODO: act accordingly - cleanup + return NULL; + } + DBG_OUT("_tlocal_add_mem(after _tlocal_ctx_search): tls = %p, ctx_id = %d\n", + (void *)tls, (int)mem->ctx->ctx_id); + + tls->mem_tbl[i]->mem->worker = ctx_rec->winfo; + tls->mem_tbl[i]->mem->rkeys = calloc(mem->ctx->comm_size, + sizeof(*tls->mem_tbl[i]->mem->rkeys)); + + + /* Make sure that we completed all the data structures before + * placing the item to the list + * NOTE: essentially we don't need this as list append is an + * operation protected by mutex + */ + opal_atomic_wmb(); + + rc = _common_ucx_mem_append(mem, tls->mem_tbl[i]); + if (rc) { + // TODO: error handling + return NULL; + } + DBG_OUT("_tlocal_add_mem(after _common_ucx_mem_append): mem = %p, mem_tbl_entry = %p\n", + (void *)mem, (void *)tls->mem_tbl[i]); + + return tls->mem_tbl[i]; +} + +static int _tlocal_mem_create_rkey(_tlocal_mem_t *mem_rec, ucp_ep_h ep, int target) +{ + _mem_info_t *minfo = mem_rec->mem; + opal_common_ucx_mem_t *gmem = mem_rec->gmem; + int displ = gmem->mem_displs[target]; + ucs_status_t status; + + status = ucp_ep_rkey_unpack(ep, &gmem->mem_addrs[displ], + &minfo->rkeys[target]); + if (status != UCS_OK) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status); + return OPAL_ERROR; + } + DBG_OUT("_tlocal_mem_create_rkey(after ucp_ep_rkey_unpack): mem_rec = %p ep = %p target = %d\n", + (void *)mem_rec, (void *)ep, target); + return OPAL_SUCCESS; +} + +static inline int _tlocal_fetch(opal_common_ucx_mem_t *mem, int target, + ucp_ep_h *_ep, ucp_rkey_h *_rkey, + _worker_info_t **_winfo) +{ + _tlocal_table_t *tls = NULL; + _tlocal_ctx_t *ctx_rec = NULL; + _worker_info_t *winfo = NULL; + _tlocal_mem_t *mem_rec = NULL; + _mem_info_t *mem_info = NULL; + ucp_ep_h ep; + ucp_rkey_h rkey; + int rc = OPAL_SUCCESS; + + DBG_OUT("_tlocal_fetch: starttls \n"); + + tls = _tlocal_get_tls(mem->ctx->wpool); + + DBG_OUT("_tlocal_fetch: tls = %p\n",(void*)tls); + + /* Obtain the worker structure */ + ctx_rec = _tlocal_ctx_search(tls, mem->ctx->ctx_id); + + DBG_OUT("_tlocal_fetch(after _tlocal_ctx_search): ctx_id = %d, ctx_rec=%p\n", + (int)mem->ctx->ctx_id, (void *)ctx_rec); + if (OPAL_UNLIKELY(NULL == ctx_rec)) { + ctx_rec = _tlocal_add_ctx(tls, mem->ctx); + if (NULL == ctx_rec) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + DBG_OUT("_tlocal_fetch(after _tlocal_add_ctx): tls = %p ctx = %p\n", (void *)tls, (void *)mem->ctx); + } + winfo = ctx_rec->winfo; + DBG_OUT("_tlocal_fetch: winfo = %p ctx=%p\n", (void *)winfo, (void *)mem->ctx); + + /* Obtain the endpoint */ + if (OPAL_UNLIKELY(NULL == winfo->endpoints[target])) { + rc = _tlocal_ctx_connect(ctx_rec, target); + if (rc != OPAL_SUCCESS) { + return rc; + } + DBG_OUT("_tlocal_fetch(after _tlocal_ctx_connect): ctx_rec = %p target = %d\n", (void *)ctx_rec, target); + } + ep = 
winfo->endpoints[target]; + DBG_OUT("_tlocal_fetch: ep = %p\n", (void *)ep); + + /* Obtain the memory region info */ + mem_rec = _tlocal_search_mem(tls, mem->mem_id); + DBG_OUT("_tlocal_fetch: tls = %p mem_rec = %p mem_id = %d\n", (void *)tls, (void *)mem_rec, (int)mem->mem_id); + if (OPAL_UNLIKELY(mem_rec == NULL)) { + mem_rec = _tlocal_add_mem(tls, mem); + DBG_OUT("_tlocal_fetch(after _tlocal_add_mem): tls = %p mem = %p\n", (void *)tls, (void *)mem); + if (NULL == mem_rec) { + return OPAL_ERR_OUT_OF_RESOURCE; + } + } + mem_info = mem_rec->mem; + DBG_OUT("_tlocal_fetch: mem_info = %p\n", (void *)mem_info); + + /* Obtain the rkey */ + if (OPAL_UNLIKELY(NULL == mem_info->rkeys[target])) { + /* Create the rkey */ + rc = _tlocal_mem_create_rkey(mem_rec, ep, target); + if (rc) { + return rc; + } + DBG_OUT("_tlocal_fetch: creating rkey ...\n"); + } + + *_ep = ep; + *_rkey = rkey = mem_info->rkeys[target]; + *_winfo = winfo; + + DBG_OUT("_tlocal_fetch(end): ep = %p, rkey = %p, winfo = %p\n", + (void *)ep, (void *)rkey, (void *)winfo); + + return OPAL_SUCCESS; +} + + + +OPAL_DECLSPEC int +opal_common_ucx_mem_putget(opal_common_ucx_mem_t *mem, + opal_common_ucx_op_t op, + int target, void *buffer, size_t len, + uint64_t rem_addr) +{ + ucp_ep_h ep; + ucp_rkey_h rkey; + ucs_status_t status; + _worker_info_t *winfo; + int rc = OPAL_SUCCESS; + + rc =_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + if(OPAL_SUCCESS != rc){ + MCA_COMMON_UCX_VERBOSE(1, "tlocal_fetch failed: %d", rc); + return rc; + } + DBG_OUT("opal_common_ucx_mem_putget(after _tlocal_fetch): mem = %p, ep = %p, rkey = %p, winfo = %p\n", + (void *)mem, (void *)ep, (void *)rkey, (void *)winfo); + + /* Perform the operation */ + opal_mutex_lock(&winfo->mutex); + switch(op){ + case OPAL_COMMON_UCX_PUT: + status = ucp_put_nbi(ep, buffer,len, rem_addr, rkey); + // TODO: movethis duplicated if-else out of switch + // char *func = "ucp_put_nbi"; + // verbose("... 
func = %s...", func); + if (status != UCS_OK && status != UCS_INPROGRESS) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_put_nbi failed: %d", status); + rc = OPAL_ERROR; + } else { + DBG_OUT("opal_common_ucx_mem_putget(after ucp_put_nbi): ep = %p, rkey = %p\n", + (void *)ep, (void *)rkey); + } + break; + case OPAL_COMMON_UCX_GET: + status = ucp_get_nbi(ep, buffer,len, rem_addr, rkey); + if (status != UCS_OK && status != UCS_INPROGRESS) { + MCA_COMMON_UCX_VERBOSE(1, "ucp_get_nbi failed: %d", status); + rc = OPAL_ERROR; + } else { + DBG_OUT("opal_common_ucx_mem_putget(after ucp_get_nbi): ep = %p, rkey = %p\n", + (void *)ep, (void *)rkey); + } + break; + } + opal_mutex_unlock(&winfo->mutex); + return rc; +} + + +OPAL_DECLSPEC +int opal_common_ucx_mem_cmpswp(opal_common_ucx_mem_t *mem, + uint64_t compare, uint64_t value, + int target, void *buffer, size_t len, + uint64_t rem_addr) +{ + ucp_ep_h ep; + ucp_rkey_h rkey; + _worker_info_t *winfo = NULL; + ucs_status_t status; + int rc = OPAL_SUCCESS; + + rc =_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + if(OPAL_SUCCESS != rc){ + MCA_COMMON_UCX_VERBOSE(1, "tlocal_fetch failed: %d", rc); + return rc; + } + DBG_OUT("opal_common_ucx_mem_cmpswp(after _tlocal_fetch): mem = %p, ep = %p, rkey = %p, winfo = %p\n", + (void *)mem, (void *)ep, (void *)rkey, (void *)winfo); + + /* Perform the operation */ + opal_mutex_lock(&winfo->mutex); + status = opal_common_ucx_atomic_cswap(ep, compare, value, + buffer, len, + rem_addr, rkey, + winfo->worker); + if (status != UCS_OK) { + MCA_COMMON_UCX_VERBOSE(1, "opal_common_ucx_atomic_cswap failed: %d", status); + rc = OPAL_ERROR; + } else { + DBG_OUT("opal_common_ucx_mem_cmpswp(after opal_common_ucx_atomic_cswap): ep = %p, rkey = %p\n", + (void *)ep, (void *)rkey); + } + opal_mutex_unlock(&winfo->mutex); + + return rc; +} + +OPAL_DECLSPEC +int opal_common_ucx_mem_fetch(opal_common_ucx_mem_t *mem, + ucp_atomic_fetch_op_t opcode, uint64_t value, + int target, void *buffer, size_t len, + uint64_t rem_addr) +{ + ucp_ep_h ep = NULL; + ucp_rkey_h rkey = NULL; + _worker_info_t *winfo = NULL; + ucs_status_t status; + int rc = OPAL_SUCCESS; + + rc =_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + if(OPAL_SUCCESS != rc){ + MCA_COMMON_UCX_VERBOSE(1, "tlocal_fetch failed: %d", rc); + return rc; + } + DBG_OUT("opal_common_ucx_mem_fetch(after _tlocal_fetch): mem = %p, ep = %p, rkey = %p, winfo = %p\n", + (void *)mem, (void *)ep, (void *)rkey, (void *)winfo); + + /* Perform the operation */ + opal_mutex_lock(&winfo->mutex); + status = opal_common_ucx_atomic_fetch(ep, opcode, value, + buffer, len, + rem_addr, rkey, + winfo->worker); + if (status != UCS_OK) { + MCA_COMMON_UCX_VERBOSE(1, "opal_common_ucx_atomic_fetch failed: %d", status); + rc = OPAL_ERROR; + } else { + DBG_OUT("opal_common_ucx_mem_fetch(after opal_common_ucx_atomic_fetch): ep = %p, rkey = %p\n", + (void *)ep, (void *)rkey); + } + opal_mutex_unlock(&winfo->mutex); + + return rc; +} + +OPAL_DECLSPEC +int opal_common_ucx_mem_fetch_nb(opal_common_ucx_mem_t *mem, + ucp_atomic_fetch_op_t opcode, + uint64_t value, + int target, void *buffer, size_t len, + uint64_t rem_addr, ucs_status_ptr_t *ptr) +{ + ucp_ep_h ep = NULL; + ucp_rkey_h rkey = NULL; + _worker_info_t *winfo = NULL; + int rc = OPAL_SUCCESS; + + rc =_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + if(OPAL_SUCCESS != rc){ + MCA_COMMON_UCX_VERBOSE(1, "tlocal_fetch failed: %d", rc); + return rc; + } + + /* Perform the operation */ + opal_mutex_lock(&winfo->mutex); + (*ptr) = opal_common_ucx_atomic_fetch_nb(ep, opcode, value, 
+ buffer, len, + rem_addr, rkey, + winfo->worker); + opal_mutex_unlock(&winfo->mutex); + + return rc; +} + + +OPAL_DECLSPEC +int opal_common_ucx_mem_post(opal_common_ucx_mem_t *mem, + ucp_atomic_post_op_t opcode, + uint64_t value, int target, size_t len, + uint64_t rem_addr) +{ + ucp_ep_h ep; + ucp_rkey_h rkey; + _worker_info_t *winfo = NULL; + ucs_status_t status; + int rc = OPAL_SUCCESS; + + + rc =_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + if(OPAL_SUCCESS != rc){ + MCA_COMMON_UCX_VERBOSE(1, "tlocal_fetch failed: %d", rc); + return rc; + } + DBG_OUT("opal_common_ucx_mem_post(after _tlocal_fetch): mem = %p, ep = %p, rkey = %p, winfo = %p\n", + (void *)mem, (void *)ep, (void *)rkey, (void *)winfo); + + /* Perform the operation */ + opal_mutex_lock(&winfo->mutex); + status = ucp_atomic_post(ep, opcode, value, + len, rem_addr, rkey); + if (status != UCS_OK) { + MCA_COMMON_UCX_ERROR("ucp_atomic_post failed: %d", status); + rc = OPAL_ERROR; + } else { + DBG_OUT("opal_common_ucx_mem_post(after ucp_atomic_post): ep = %p, rkey = %p\n", (void *)ep, (void *)rkey); + } + opal_mutex_unlock(&winfo->mutex); + return rc; +} + +OPAL_DECLSPEC int +opal_common_ucx_mem_flush(opal_common_ucx_mem_t *mem, + opal_common_ucx_flush_scope_t scope, + int target) +{ + _worker_list_item_t *item; + opal_common_ucx_ctx_t *ctx = mem->ctx; + int rc = OPAL_SUCCESS; + + DBG_OUT("opal_common_ucx_mem_flush: mem = %p, target = %d\n", (void *)mem, target); + + // TODO: make this as a read lock + opal_mutex_lock(&ctx->mutex); + OPAL_LIST_FOREACH(item, &ctx->workers, _worker_list_item_t) { + switch (scope) { + case OPAL_COMMON_UCX_SCOPE_WORKER: + opal_mutex_lock(&item->ptr->winfo->mutex); + rc = opal_common_ucx_worker_flush(item->ptr->winfo->worker); + if (rc != OPAL_SUCCESS) { + MCA_COMMON_UCX_VERBOSE(1, "opal_common_ucx_worker_flush failed: %d", rc); + rc = OPAL_ERROR; + } + DBG_OUT("opal_common_ucx_mem_flush(after opal_common_ucx_worker_flush): worker = %p\n", + (void *)item->ptr->winfo->worker); + opal_mutex_unlock(&item->ptr->winfo->mutex); + break; + case OPAL_COMMON_UCX_SCOPE_EP: + if (NULL != item->ptr->winfo->endpoints[target] ) { + opal_mutex_lock(&item->ptr->winfo->mutex); + rc = opal_common_ucx_ep_flush(item->ptr->winfo->endpoints[target], + item->ptr->winfo->worker); + if (rc != OPAL_SUCCESS) { + MCA_COMMON_UCX_VERBOSE(1, "opal_common_ucx_ep_flush failed: %d", rc); + rc = OPAL_ERROR; + } + DBG_OUT("opal_common_ucx_mem_flush(after opal_common_ucx_worker_flush): ep = %p worker = %p\n", + (void *)item->ptr->winfo->endpoints[target], + (void *)item->ptr->winfo->worker); + opal_mutex_unlock(&item->ptr->winfo->mutex); + } + } + } + opal_mutex_unlock(&ctx->mutex); + + return rc; +} + +OPAL_DECLSPEC +int opal_common_ucx_workers_progress(opal_common_ucx_wpool_t *wpool) { + // TODO + static int enter = 0; + if (enter == 0) { + DBG_OUT("opal_common_ucx_workres_progress: wpool = %p\n", (void *)wpool); + } + + enter++; + return OPAL_SUCCESS; +} + + +OPAL_DECLSPEC int +opal_common_ucx_mem_fence(opal_common_ucx_mem_t *mem) { + /* TODO */ + return OPAL_SUCCESS; +} diff --git a/opal/mca/common/ucx/common_ucx.h b/opal/mca/common/ucx/common_ucx.h index e25dd23b821..e264e315fca 100644 --- a/opal/mca/common/ucx/common_ucx.h +++ b/opal/mca/common/ucx/common_ucx.h @@ -18,6 +18,7 @@ #include #include +#include #include "opal/mca/mca.h" #include "opal/util/output.h" @@ -96,6 +97,150 @@ typedef struct opal_common_ucx_del_proc { extern opal_common_ucx_module_t opal_common_ucx; +typedef struct { + int refcnt; + ucp_context_h ucp_ctx; + 
opal_mutex_t mutex; + opal_list_t idle_workers; + ucp_worker_h recv_worker; + ucp_address_t *recv_waddr; + size_t recv_waddr_len; + opal_atomic_int32_t cur_ctxid, cur_memid; + opal_list_t tls_list; +} opal_common_ucx_wpool_t; + +typedef struct { + opal_atomic_int32_t ctx_id; + opal_mutex_t mutex; + opal_common_ucx_wpool_t *wpool; /* which wpool this ctx belongs to */ + opal_list_t workers; /* active worker lists */ + char *recv_worker_addrs; + int *recv_worker_displs; + size_t comm_size; +} opal_common_ucx_ctx_t; + +typedef struct { + opal_atomic_int32_t mem_id; + opal_mutex_t mutex; + opal_common_ucx_ctx_t *ctx; /* which ctx this mem_reg belongs to */ + ucp_mem_h memh; + opal_list_t registrations; /* mem region lists */ + char *mem_addrs; + int *mem_displs; +} opal_common_ucx_mem_t; + +typedef enum { + OPAL_COMMON_UCX_PUT, + OPAL_COMMON_UCX_GET +} opal_common_ucx_op_t; + +typedef enum { + OPAL_COMMON_UCX_SCOPE_EP, + OPAL_COMMON_UCX_SCOPE_WORKER +} opal_common_ucx_flush_scope_t; + +typedef enum { + OPAL_COMMON_UCX_MEM_ALLOCATE_MAP, + OPAL_COMMON_UCX_MEM_MAP +} opal_common_ucx_mem_type_t; + +typedef int (*opal_common_ucx_exchange_func_t)(void *my_info, size_t my_info_len, + char **recv_info, int **disps, + void *metadata); + +OPAL_DECLSPEC opal_common_ucx_wpool_t * opal_common_ucx_wpool_allocate(void); +OPAL_DECLSPEC void opal_common_ucx_wpool_free(opal_common_ucx_wpool_t *wpool); +OPAL_DECLSPEC int opal_common_ucx_wpool_init(opal_common_ucx_wpool_t *wpool, + int proc_world_size, + ucp_request_init_callback_t req_init_ptr, + size_t req_size, bool enable_mt); +OPAL_DECLSPEC void opal_common_ucx_wpool_finalize(opal_common_ucx_wpool_t *wpool); +OPAL_DECLSPEC int opal_common_ucx_ctx_create(opal_common_ucx_wpool_t *wpool, int comm_size, + opal_common_ucx_exchange_func_t exchange_func, + void *exchange_metadata, + opal_common_ucx_ctx_t **ctx_ptr); +OPAL_DECLSPEC void opal_common_ucx_ctx_release(opal_common_ucx_ctx_t *ctx); +OPAL_DECLSPEC int opal_common_ucx_mem_create(opal_common_ucx_ctx_t *ctx, int comm_size, + void **mem_base, size_t mem_size, + opal_common_ucx_mem_type_t mem_type, + opal_common_ucx_exchange_func_t exchange_func, + void *exchange_metadata, + opal_common_ucx_mem_t **mem_ptr); +OPAL_DECLSPEC int opal_common_ucx_mem_flush(opal_common_ucx_mem_t *mem, + opal_common_ucx_flush_scope_t scope, + int target); +OPAL_DECLSPEC int opal_common_ucx_mem_fetch_nb(opal_common_ucx_mem_t *mem, + ucp_atomic_fetch_op_t opcode, + uint64_t value, + int target, void *buffer, size_t len, + uint64_t rem_addr, ucs_status_ptr_t *ptr); +OPAL_DECLSPEC int opal_common_ucx_mem_fence(opal_common_ucx_mem_t *mem); +OPAL_DECLSPEC int opal_common_ucx_workers_progress(opal_common_ucx_wpool_t *wpool); +OPAL_DECLSPEC int opal_common_ucx_mem_cmpswp(opal_common_ucx_mem_t *mem, + uint64_t compare, uint64_t value, + int target, + void *buffer, size_t len, + uint64_t rem_addr); +OPAL_DECLSPEC int opal_common_ucx_mem_putget(opal_common_ucx_mem_t *mem, + opal_common_ucx_op_t op, + int target, + void *buffer, size_t len, + uint64_t rem_addr); +OPAL_DECLSPEC int opal_common_ucx_mem_fetch(opal_common_ucx_mem_t *mem, + ucp_atomic_fetch_op_t opcode, uint64_t value, + int target, + void *buffer, size_t len, + uint64_t rem_addr); +OPAL_DECLSPEC int opal_common_ucx_mem_post(opal_common_ucx_mem_t *mem, + ucp_atomic_post_op_t opcode, + uint64_t value, + int target, + size_t len, + uint64_t rem_addr); + +#define FDBG +#ifdef FDBG +extern __thread FILE *tls_pf; +extern __thread int initialized; + +#include +#include +#include +#include 
+ +static inline void init_tls_dbg(void) +{ + if( !initialized ) { + int tid = syscall(__NR_gettid); + char hname[128]; + gethostname(hname, 127); + char fname[128]; + + sprintf(fname, "%s.%d.log", hname, tid); + tls_pf = fopen(fname, "w"); + initialized = 1; + } +} + +#define DBG_OUT(...) \ +{ \ + struct timeval start_; \ + time_t nowtime_; \ + struct tm *nowtm_; \ + char tmbuf_[64]; \ + gettimeofday(&start_, NULL); \ + nowtime_ = start_.tv_sec; \ + nowtm_ = localtime(&nowtime_); \ + strftime(tmbuf_, sizeof(tmbuf_), "%H:%M:%S", nowtm_); \ + init_tls_dbg(); \ + fprintf(tls_pf, "[%s.%06ld] ", tmbuf_, start_.tv_usec);\ + fprintf(tls_pf, __VA_ARGS__); \ +} + +#else +#define DBG_OUT(...) +#endif + OPAL_DECLSPEC void opal_common_ucx_mca_register(void); OPAL_DECLSPEC void opal_common_ucx_mca_deregister(void); OPAL_DECLSPEC void opal_common_ucx_empty_complete_cb(void *request, ucs_status_t status); @@ -165,6 +310,16 @@ int opal_common_ucx_worker_flush(ucp_worker_h worker) #endif } +static inline +ucs_status_ptr_t opal_common_ucx_atomic_fetch_nb(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, + uint64_t value, void *result, size_t op_size, + uint64_t remote_addr, ucp_rkey_h rkey, + ucp_worker_h worker) +{ + return ucp_atomic_fetch_nb(ep, opcode, value, result, op_size, + remote_addr, rkey, opal_common_ucx_empty_complete_cb); +} + static inline int opal_common_ucx_atomic_fetch(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, uint64_t value, void *result, size_t op_size, @@ -173,8 +328,8 @@ int opal_common_ucx_atomic_fetch(ucp_ep_h ep, ucp_atomic_fetch_op_t opcode, { ucs_status_ptr_t request; - request = ucp_atomic_fetch_nb(ep, opcode, value, result, op_size, - remote_addr, rkey, opal_common_ucx_empty_complete_cb); + request = opal_common_ucx_atomic_fetch_nb(ep, opcode, value, result, op_size, + remote_addr, rkey, worker); return opal_common_ucx_wait_request(request, worker, "ucp_atomic_fetch_nb"); }
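Illustrative usage note (not part of the patch): below is a minimal, hedged sketch of how an OSC-style component is expected to drive the wpool API introduced by this change, following the declarations added to opal/mca/common/ucx/common_ucx.h (wpool -> ctx -> mem, then RMA through the mem handle and a flush for completion). The my_exchange() callback, the world/communicator sizes, the NULL request-init hook and the remote address are hypothetical placeholders; only the opal_common_ucx_* calls and enums that appear in this patch are assumed.

#include "opal/mca/common/ucx/common_ucx.h"   /* path as added in this patch */

/* Hypothetical address-exchange callback: a real component would allgather
 * my_info over the communicator carried in metadata and return the packed
 * buffer plus per-rank displacements. */
static int my_exchange(void *my_info, size_t my_info_len,
                       char **recv_info, int **disps, void *metadata)
{
    (void)my_info; (void)my_info_len; (void)recv_info; (void)disps; (void)metadata;
    return OPAL_SUCCESS;
}

static int example_window_setup(int world_size, int comm_size, void *meta)
{
    opal_common_ucx_wpool_t *wpool;
    opal_common_ucx_ctx_t *ctx;
    opal_common_ucx_mem_t *mem;
    void *base = NULL;
    uint64_t payload = 42;
    uint64_t rem_addr = 0;          /* placeholder: exchanged by the component */
    int ret;

    /* One refcounted worker pool per process. */
    wpool = opal_common_ucx_wpool_allocate();
    ret = opal_common_ucx_wpool_init(wpool, world_size, NULL, 0, true);
    if (OPAL_SUCCESS != ret) return ret;

    /* A ctx roughly maps to a communicator: it exchanges recv-worker
     * addresses so per-thread workers can connect endpoints lazily. */
    ret = opal_common_ucx_ctx_create(wpool, comm_size, my_exchange, meta, &ctx);
    if (OPAL_SUCCESS != ret) return ret;

    /* A mem maps to a window: register (or allocate) the region and
     * exchange the packed rkeys. */
    ret = opal_common_ucx_mem_create(ctx, comm_size, &base, 4096,
                                     OPAL_COMMON_UCX_MEM_ALLOCATE_MAP,
                                     my_exchange, meta, &mem);
    if (OPAL_SUCCESS != ret) return ret;

    /* RMA goes through the mem handle; the calling thread's worker, endpoint
     * and rkey are resolved lazily inside the common code. */
    ret = opal_common_ucx_mem_putget(mem, OPAL_COMMON_UCX_PUT, 1 /* target */,
                                     &payload, sizeof(payload), rem_addr);
    if (OPAL_SUCCESS != ret) return ret;

    /* Completion: flush one endpoint or every worker attached to the ctx. */
    ret = opal_common_ucx_mem_flush(mem, OPAL_COMMON_UCX_SCOPE_EP, 1);
    if (OPAL_SUCCESS != ret) return ret;
    ret = opal_common_ucx_mem_flush(mem, OPAL_COMMON_UCX_SCOPE_WORKER, 0 /* ignored */);

    opal_common_ucx_ctx_release(ctx);
    opal_common_ucx_wpool_finalize(wpool);
    opal_common_ucx_wpool_free(wpool);
    return ret;
}

This mirrors the sequence the osc/ucx hunks above follow: the component keeps a single wpool in its component structure, creates one ctx plus data and state mem regions per window, and replaces the old direct ucp_worker/ucp_ep flushes with opal_common_ucx_mem_flush at EP or WORKER scope.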