Skip to content

Topic/osc mt (modified) #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contrib/platform/mellanox/optimized
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
enable_mca_no_build=coll-ml,btl-uct
enable_mca_no_build=coll-ml
enable_debug_symbols=yes
enable_orterun_prefix_by_default=yes
with_verbs=no
Expand Down
6 changes: 6 additions & 0 deletions ompi/mca/osc/ucx/osc_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ typedef struct ompi_osc_ucx_component {
int num_incomplete_req_ops;
int num_modules;
unsigned int priority;
pthread_mutex_t worker_mutex;
char *worker_addr_buf;
int *worker_addr_disps;
char *mem_addr_buf;
int *mem_addr_disps;
pthread_t main_tid;
} ompi_osc_ucx_component_t;

OMPI_DECLSPEC extern ompi_osc_ucx_component_t mca_osc_ucx_component;
Expand Down
75 changes: 66 additions & 9 deletions ompi/mca/osc/ucx/osc_ucx_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ typedef struct ucx_iovec {
size_t len;
} ucx_iovec_t;

OBJ_CLASS_INSTANCE(thread_local_info_t, opal_list_item_t, NULL, NULL);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • This has to be in a common/ucx *.c file
  • Let's create a constructor/destructor


pthread_key_t my_thread_key = {0};

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • This has to be declared in a common/ucx/*.c file
  • Use the opal_common_ucx_ prefix
  • The extern declaration has to be in common/ucx


static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target,
bool is_req_ops) {
if (is_req_ops == false) {
Expand Down Expand Up @@ -83,6 +87,40 @@ static inline int incr_and_check_ops_num(ompi_osc_ucx_module_t *module, int targ
return OMPI_SUCCESS;
}

static inline int get_osc_metadata(ompi_osc_ucx_module_t *module, int target,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the next step:

  • We will need this function to be in the common code so that both OSC and OSHMEM can use it.
  • We will need to unify the handling of the main thread and the other threads.
  • We need to record/recognize the main thread in the common code.

ucp_ep_h *ep_ptr, ucp_rkey_h *rkey_ptr,
pthread_mutex_t **mutex_pptr) {
pthread_t tid = pthread_self();
int ret = OMPI_SUCCESS;

if (pthread_equal(tid, mca_osc_ucx_component.main_tid)) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now: move main_tid into the common code.

(*ep_ptr) = OSC_UCX_GET_EP(module->comm, target);
(*rkey_ptr) = (module->win_info_array[target]).rkey;
(*mutex_pptr) = &mca_osc_ucx_component.worker_mutex;
} else {
thread_local_info_t *curr_thread_info = NULL;
if ((curr_thread_info = pthread_getspecific(my_thread_key)) == NULL) {
ret = opal_common_ucx_create_local_worker(mca_osc_ucx_component.ucp_context,
ompi_comm_size(module->comm),
mca_osc_ucx_component.worker_addr_buf,
mca_osc_ucx_component.worker_addr_disps,
mca_osc_ucx_component.mem_addr_buf,
mca_osc_ucx_component.mem_addr_disps);
if (ret != OMPI_SUCCESS) {
return ret;
}
curr_thread_info = pthread_getspecific(my_thread_key);
}

assert(curr_thread_info != NULL);
(*rkey_ptr) = curr_thread_info->rkeys[target];
(*ep_ptr) = curr_thread_info->eps[target];
(*mutex_pptr) = &curr_thread_info->lock;
}

return ret;
}

static inline int create_iov_list(const void *addr, int count, ompi_datatype_t *datatype,
ucx_iovec_t **ucx_iov, uint32_t *ucx_iov_count) {
int ret = OMPI_SUCCESS;
Expand Down Expand Up @@ -139,7 +177,8 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
bool is_origin_contig, ptrdiff_t origin_lb,
int target, ucp_ep_h ep, uint64_t remote_addr, ucp_rkey_h rkey,
int target_count, struct ompi_datatype_t *target_dt,
bool is_target_contig, ptrdiff_t target_lb, bool is_get) {
bool is_target_contig, ptrdiff_t target_lb, bool is_get,
pthread_mutex_t *mutex_ptr) {
ucx_iovec_t *origin_ucx_iov = NULL, *target_ucx_iov = NULL;
uint32_t origin_ucx_iov_count = 0, target_ucx_iov_count = 0;
uint32_t origin_ucx_iov_idx = 0, target_ucx_iov_idx = 0;
Expand Down Expand Up @@ -168,6 +207,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
curr_len = MIN(origin_ucx_iov[origin_ucx_iov_idx].len,
target_ucx_iov[target_ucx_iov_idx].len);

pthread_mutex_lock(mutex_ptr);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • This block goes into the common code
  • Make sure to always unlock the mutex, including on the error path (the early `return OMPI_ERROR` currently leaves it locked)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has to be an inline function with a put/get argument, which will allow compile-time optimization.

if (!is_get) {
status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len,
remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), rkey);
Expand All @@ -183,6 +223,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
return OMPI_ERROR;
}
}
pthread_mutex_unlock(mutex_ptr);

ret = incr_and_check_ops_num(module, target, ep);
if (ret != OMPI_SUCCESS) {
Expand All @@ -208,6 +249,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
} else if (!is_origin_contig) {
size_t prev_len = 0;
while (origin_ucx_iov_idx < origin_ucx_iov_count) {
pthread_mutex_lock(mutex_ptr);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • This block goes into the common code
  • See the earlier comment about always unlocking the mutex

if (!is_get) {
status = ucp_put_nbi(ep, origin_ucx_iov[origin_ucx_iov_idx].addr,
origin_ucx_iov[origin_ucx_iov_idx].len,
Expand All @@ -225,6 +267,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
return OMPI_ERROR;
}
}
pthread_mutex_unlock(mutex_ptr);

ret = incr_and_check_ops_num(module, target, ep);
if (ret != OMPI_SUCCESS) {
Expand All @@ -237,6 +280,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
} else {
size_t prev_len = 0;
while (target_ucx_iov_idx < target_ucx_iov_count) {
pthread_mutex_lock(mutex_ptr);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • This block also goes into the common code
  • See the earlier comment about always unlocking the mutex

if (!is_get) {
status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb + prev_len),
target_ucx_iov[target_ucx_iov_idx].len,
Expand All @@ -254,6 +298,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module,
return OMPI_ERROR;
}
}
pthread_mutex_unlock(mutex_ptr);

ret = incr_and_check_ops_num(module, target, ep);
if (ret != OMPI_SUCCESS) {
Expand Down Expand Up @@ -367,19 +412,25 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
int target, ptrdiff_t target_disp, int target_count,
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
ucp_ep_h ep;
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
ucp_rkey_h rkey;
bool is_origin_contig = false, is_target_contig = false;
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
ucs_status_t status;
pthread_mutex_t *mutex_ptr = NULL;
int ret = OMPI_SUCCESS;

ret = check_sync_state(module, target, false);
if (ret != OMPI_SUCCESS) {
return ret;
}

ret = get_osc_metadata(module, target, &ep, &rkey, &mutex_ptr);
if (ret != OMPI_SUCCESS) {
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
Expand All @@ -393,8 +444,6 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
return OMPI_SUCCESS;
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);

Expand All @@ -408,17 +457,19 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data
ompi_datatype_type_size(origin_dt, &origin_len);
origin_len *= origin_count;

pthread_mutex_lock(mutex_ptr);
status = ucp_put_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb), origin_len,
remote_addr + target_lb, rkey);
if (status != UCS_OK && status != UCS_INPROGRESS) {
OSC_UCX_VERBOSE(1, "ucp_put_nbi failed: %d", status);
return OMPI_ERROR;
}
pthread_mutex_unlock(mutex_ptr);
return incr_and_check_ops_num(module, target, ep);
} else {
return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig,
origin_lb, target, ep, remote_addr, rkey, target_count, target_dt,
is_target_contig, target_lb, false);
is_target_contig, target_lb, false, mutex_ptr);
}
}

Expand All @@ -427,19 +478,25 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
int target, ptrdiff_t target_disp, int target_count,
struct ompi_datatype_t *target_dt, struct ompi_win_t *win) {
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
ucp_ep_h ep = OSC_UCX_GET_EP(module->comm, target);
ucp_ep_h ep;
uint64_t remote_addr = (module->win_info_array[target]).addr + target_disp * OSC_UCX_GET_DISP(module, target);
ucp_rkey_h rkey;
ptrdiff_t origin_lb, origin_extent, target_lb, target_extent;
bool is_origin_contig = false, is_target_contig = false;
ucs_status_t status;
pthread_mutex_t *mutex_ptr = NULL;
int ret = OMPI_SUCCESS;

ret = check_sync_state(module, target, false);
if (ret != OMPI_SUCCESS) {
return ret;
}

ret = get_osc_metadata(module, target, &ep, &rkey, &mutex_ptr);
if (ret != OMPI_SUCCESS) {
return ret;
}

if (module->flavor == MPI_WIN_FLAVOR_DYNAMIC) {
status = get_dynamic_win_info(remote_addr, module, ep, target);
if (status != UCS_OK) {
Expand All @@ -453,8 +510,6 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
return OMPI_SUCCESS;
}

rkey = (module->win_info_array[target]).rkey;

ompi_datatype_get_true_extent(origin_dt, &origin_lb, &origin_extent);
ompi_datatype_get_true_extent(target_dt, &target_lb, &target_extent);

Expand All @@ -468,18 +523,20 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count,
ompi_datatype_type_size(origin_dt, &origin_len);
origin_len *= origin_count;

pthread_mutex_lock(mutex_ptr);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code block goes into the common code.

status = ucp_get_nbi(ep, (void *)((intptr_t)origin_addr + origin_lb), origin_len,
remote_addr + target_lb, rkey);
if (status != UCS_OK && status != UCS_INPROGRESS) {
OSC_UCX_VERBOSE(1, "ucp_get_nbi failed: %d", status);
return OMPI_ERROR;
}
pthread_mutex_unlock(mutex_ptr);

return incr_and_check_ops_num(module, target, ep);
} else {
return ddt_put_get(module, origin_addr, origin_count, origin_dt, is_origin_contig,
origin_lb, target, ep, remote_addr, rkey, target_count, target_dt,
is_target_contig, target_lb, true);
is_target_contig, target_lb, true, mutex_ptr);
}
}

Expand Down
Loading