-
Notifications
You must be signed in to change notification settings - Fork 0
Topic/osc mt (modified) #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,10 @@ typedef struct ucx_iovec { | |
size_t len; | ||
} ucx_iovec_t; | ||
|
||
OBJ_CLASS_INSTANCE(thread_local_info_t, opal_list_item_t, NULL, NULL); | ||
|
||
pthread_key_t my_thread_key = {0}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target, | ||
bool is_req_ops) { | ||
if (is_req_ops == false) { | ||
|
@@ -83,6 +87,40 @@ static inline int incr_and_check_ops_num(ompi_osc_ucx_module_t *module, int targ | |
return OMPI_SUCCESS; | ||
} | ||
|
||
static inline int get_osc_metadata(ompi_osc_ucx_module_t *module, int target, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the next step:
|
||
ucp_ep_h *ep_ptr, ucp_rkey_h *rkey_ptr, | ||
pthread_mutex_t **mutex_pptr) { | ||
pthread_t tid = pthread_self(); | ||
int ret = OMPI_SUCCESS; | ||
|
||
if (pthread_equal(tid, mca_osc_ucx_component.main_tid)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for now: move main_tid into the common code |
||
(*ep_ptr) = OSC_UCX_GET_EP(module->comm, target); | ||
(*rkey_ptr) = (module->win_info_array[target]).rkey; | ||
(*mutex_pptr) = &mca_osc_ucx_component.worker_mutex; | ||
} else { | ||
thread_local_info_t *curr_thread_info = NULL; | ||
if ((curr_thread_info = pthread_getspecific(my_thread_key)) == NULL) { | ||
ret = opal_common_ucx_create_local_worker(mca_osc_ucx_component.ucp_context, | ||
ompi_comm_size(module->comm), | ||
mca_osc_ucx_component.worker_addr_buf, | ||
mca_osc_ucx_component.worker_addr_disps, | ||
mca_osc_ucx_component.mem_addr_buf, | ||
mca_osc_ucx_component.mem_addr_disps); | ||
if (ret != OMPI_SUCCESS) { | ||
return ret; | ||
} | ||
curr_thread_info = pthread_getspecific(my_thread_key); | ||
} | ||
|
||
assert(curr_thread_info != NULL); | ||
(*rkey_ptr) = curr_thread_info->rkeys[target]; | ||
(*ep_ptr) = curr_thread_info->eps[target]; | ||
(*mutex_pptr) = &curr_thread_info->lock; | ||
} | ||
|
||
return ret; | ||
} | ||
|
||
static inline int create_iov_list(const void *addr, int count, ompi_datatype_t *datatype, | ||
ucx_iovec_t **ucx_iov, uint32_t *ucx_iov_count) { | ||
int ret = OMPI_SUCCESS; | ||
|
@@ -139,7 +177,8 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, | |
bool is_origin_contig, ptrdiff_t origin_lb, | ||
int target, ucp_ep_h ep, uint64_t remote_addr, ucp_rkey_h rkey, | ||
int target_count, struct ompi_datatype_t *target_dt, | ||
bool is_target_contig, ptrdiff_t target_lb, bool is_get) { | ||
bool is_target_contig, ptrdiff_t target_lb, bool is_get, | ||
pthread_mutex_t *mutex_ptr) { | ||
ucx_iovec_t *origin_ucx_iov = NULL, *target_ucx_iov = NULL; | ||
uint32_t origin_ucx_iov_count = 0, target_ucx_iov_count = 0; | ||
uint32_t origin_ucx_iov_idx = 0, target_ucx_iov_idx = 0; | ||
|
@@ -168,6 +207,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, | |
curr_len = MIN(origin_ucx_iov[origin_ucx_iov_idx].len, | ||
target_ucx_iov[target_ucx_iov_idx].len); | ||
|
||
pthread_mutex_lock(mutex_ptr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this has to be an inline function with argument put/get that will allow static compile optimization |
||
if (!is_get) { | ||
status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len, | ||
remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), rkey); | ||
|
@@ -183,6 +223,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, | |
return OMPI_ERROR; | ||
} | ||
} | ||
pthread_mutex_unlock(mutex_ptr); | ||
|
||
ret = incr_and_check_ops_num(module, target, ep); | ||
if (ret != OMPI_SUCCESS) { | ||
|
@@ -208,6 +249,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, | |
} else if (!is_origin_contig) { | ||
size_t prev_len = 0; | ||
while (origin_ucx_iov_idx < origin_ucx_iov_count) { | ||
pthread_mutex_lock(mutex_ptr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if (!is_get) { | ||
status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, | ||
origin_ucx_iov[origin_ucx_iov_idx].len, | ||
|
@@ -225,6 +267,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, | |
return OMPI_ERROR; | ||
} | ||
} | ||
pthread_mutex_unlock(mutex_ptr); | ||
|
||
ret = incr_and_check_ops_num(module, target, ep); | ||
if (ret != OMPI_SUCCESS) { | ||
|
@@ -237,6 +280,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, | |
} else { | ||
size_t prev_len = 0; | ||
while (target_ucx_iov_idx < target_ucx_iov_count) { | ||
pthread_mutex_lock(mutex_ptr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if (!is_get) { | ||
status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb + prev_len), | ||
target_ucx_iov[target_ucx_iov_idx].len, | ||
|
@@ -254,6 +298,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, | |
return OMPI_ERROR; | ||
} | ||
} | ||
pthread_mutex_unlock(mutex_ptr); | ||
|
||
ret = incr_and_check_ops_num(module, target, ep); | ||
if (ret != OMPI_SUCCESS) { | ||
|
@@ -367,19 +412,25 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data | |
int target, ptrdiff_t target_disp, int target_count, | ||
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) { | ||
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; | ||
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); | ||
ucp_ep_h ep; | ||
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target); | ||
ucp_rkey_h rkey; | ||
bool is_origin_contig = false, is_target_contig = false; | ||
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent; | ||
ucs_status_t status; | ||
pthread_mutex_t *mutex_ptr = NULL; | ||
int ret = OMPI_SUCCESS; | ||
|
||
ret = check_sync_state(module, target, false); | ||
if (ret != OMPI_SUCCESS) { | ||
return ret; | ||
} | ||
|
||
ret = get_osc_metadata(module, target, &ep, &rkey, &mutex_ptr); | ||
if (ret != OMPI_SUCCESS) { | ||
return ret; | ||
} | ||
|
||
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { | ||
status = get_dynamic_win_info(remote_addr, module, ep, target); | ||
if (status != UCS_OK) { | ||
|
@@ -393,8 +444,6 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data | |
return OMPI_SUCCESS; | ||
} | ||
|
||
rkey = (module->win_info_array[target]).rkey; | ||
|
||
ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent); | ||
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent); | ||
|
||
|
@@ -408,17 +457,19 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data | |
ompi_datatype_type_size(origin_dt, &origin_len); | ||
origin_len *= origin_count; | ||
|
||
pthread_mutex_lock(mutex_ptr); | ||
status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb), origin_len, | ||
remote_addr + target_lb, rkey); | ||
if (status != UCS_OK && status != UCS_INPROGRESS) { | ||
OSC_UCX_VERBOSE(1, "ucp_put_nbi failed: %d", status); | ||
return OMPI_ERROR; | ||
} | ||
pthread_mutex_unlock(mutex_ptr); | ||
return incr_and_check_ops_num(module, target, ep); | ||
} else { | ||
return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig, | ||
origin_lb, target, ep, remote_addr, rkey, target_count, target_dt, | ||
is_target_contig, target_lb, false); | ||
is_target_contig, target_lb, false, mutex_ptr); | ||
} | ||
} | ||
|
||
|
@@ -427,19 +478,25 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, | |
int target, ptrdiff_t target_disp, int target_count, | ||
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) { | ||
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; | ||
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target); | ||
ucp_ep_h ep; | ||
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target); | ||
ucp_rkey_h rkey; | ||
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent; | ||
bool is_origin_contig = false, is_target_contig = false; | ||
ucs_status_t status; | ||
pthread_mutex_t *mutex_ptr = NULL; | ||
int ret = OMPI_SUCCESS; | ||
|
||
ret = check_sync_state(module, target, false); | ||
if (ret != OMPI_SUCCESS) { | ||
return ret; | ||
} | ||
|
||
ret = get_osc_metadata(module, target, &ep, &rkey, &mutex_ptr); | ||
if (ret != OMPI_SUCCESS) { | ||
return ret; | ||
} | ||
|
||
if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { | ||
status = get_dynamic_win_info(remote_addr, module, ep, target); | ||
if (status != UCS_OK) { | ||
|
@@ -453,8 +510,6 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, | |
return OMPI_SUCCESS; | ||
} | ||
|
||
rkey = (module->win_info_array[target]).rkey; | ||
|
||
ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent); | ||
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent); | ||
|
||
|
@@ -468,18 +523,20 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, | |
ompi_datatype_type_size(origin_dt, &origin_len); | ||
origin_len *= origin_count; | ||
|
||
pthread_mutex_lock(mutex_ptr); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code block goes into the common code. |
||
status = ucp_get_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb), origin_len, | ||
remote_addr + target_lb, rkey); | ||
if (status != UCS_OK && status != UCS_INPROGRESS) { | ||
OSC_UCX_VERBOSE(1, "ucp_get_nbi failed: %d", status); | ||
return OMPI_ERROR; | ||
} | ||
pthread_mutex_unlock(mutex_ptr); | ||
|
||
return incr_and_check_ops_num(module, target, ep); | ||
} else { | ||
return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig, | ||
origin_lb, target, ep, remote_addr, rkey, target_count, target_dt, | ||
is_target_contig, target_lb, true); | ||
is_target_contig, target_lb, true, mutex_ptr); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.