Skip to content

Commit 7dd3cef

Browse files
committed
Allow MPI_WIN_SHARED_QUERY on regular windows
MPI 4.0 allows applications to query regular (non-shared) windows for shared memory. This patch enables that for osc/rdma and osc/ucx, and otherwise makes sure we fail gracefully if the component does not provide the query callback. For osc/rdma, this is currently supported only for allocated windows, but it could later be extended to windows with application-provided memory through xpmem. Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
1 parent 0ff5e81 commit 7dd3cef

File tree

3 files changed

+122
-13
lines changed

3 files changed

+122
-13
lines changed

ompi/mca/osc/rdma/osc_rdma_component.c

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
#include "ompi/mca/bml/base/base.h"
7070
#include "ompi/mca/mtl/base/base.h"
7171

72+
static int ompi_osc_rdma_shared_query(struct ompi_win_t *win, int rank, size_t *size,
73+
ptrdiff_t *disp_unit, void *baseptr);
7274
static int ompi_osc_rdma_component_register (void);
7375
static int ompi_osc_rdma_component_init (bool enable_progress_threads, bool enable_mpi_threads);
7476
static int ompi_osc_rdma_component_finalize (void);
@@ -113,6 +115,7 @@ ompi_osc_rdma_component_t mca_osc_rdma_component = {
113115
MCA_BASE_COMPONENT_INIT(ompi, osc, rdma)
114116

115117
ompi_osc_base_module_t ompi_osc_rdma_module_rdma_template = {
118+
.osc_win_shared_query = ompi_osc_rdma_shared_query,
116119
.osc_win_attach = ompi_osc_rdma_attach,
117120
.osc_win_detach = ompi_osc_rdma_detach,
118121
.osc_free = ompi_osc_rdma_free,
@@ -898,7 +901,7 @@ static void ompi_osc_rdma_ensure_local_add_procs (void)
898901
/* this will cause add_proc to get called if it has not already been called */
899902
(void) mca_bml_base_get_endpoint (proc);
900903
}
901-
}
904+
}
902905

903906
free(procs);
904907
}
@@ -1632,3 +1635,58 @@ ompi_osc_rdma_set_no_lock_info(opal_infosubscriber_t *obj, const char *key, cons
16321635
*/
16331636
return module->no_locks ? "true" : "false";
16341637
}
1638+
1639+
int ompi_osc_rdma_shared_query(
1640+
struct ompi_win_t *win, int rank, size_t *size,
1641+
ptrdiff_t *disp_unit, void *baseptr)
1642+
{
1643+
ompi_osc_rdma_peer_t *peer;
1644+
int actual_rank = rank;
1645+
ompi_osc_rdma_module_t *module = GET_MODULE(win);
1646+
1647+
peer = ompi_osc_rdma_module_peer (module, actual_rank);
1648+
if (NULL == peer) {
1649+
return OMPI_ERR_BAD_PARAM;
1650+
}
1651+
1652+
/* currently only supported for allocated windows */
1653+
if (MPI_WIN_FLAVOR_ALLOCATE != module->flavor) {
1654+
return OMPI_ERR_NOT_SUPPORTED;
1655+
}
1656+
1657+
if (!ompi_osc_rdma_peer_local_base(peer)) {
1658+
return OMPI_ERR_NOT_SUPPORTED;
1659+
}
1660+
1661+
if (MPI_PROC_NULL == rank) {
1662+
/* iterate until we find a rank that has a non-zero size */
1663+
for (int i = 0 ; i < ompi_comm_size(module->comm) ; ++i) {
1664+
peer = ompi_osc_rdma_module_peer (module, i);
1665+
ompi_osc_rdma_peer_extended_t *ex_peer = (ompi_osc_rdma_peer_extended_t *) peer;
1666+
if (!ompi_osc_rdma_peer_local_base(peer)) {
1667+
continue;
1668+
} else if (module->same_size && ex_peer->super.base) {
1669+
break;
1670+
} else if (ex_peer->size > 0) {
1671+
break;
1672+
}
1673+
}
1674+
}
1675+
1676+
if (module->same_size && module->same_disp_unit) {
1677+
*size = module->size;
1678+
*disp_unit = module->disp_unit;
1679+
ompi_osc_rdma_peer_basic_t *ex_peer = (ompi_osc_rdma_peer_basic_t *) peer;
1680+
*((void**) baseptr) = (void *) (intptr_t)ex_peer->base;
1681+
} else {
1682+
ompi_osc_rdma_peer_extended_t *ex_peer = (ompi_osc_rdma_peer_extended_t *) peer;
1683+
if (ex_peer->super.base != 0) {
1684+
/* we know the base of the peer */
1685+
*((void**) baseptr) = (void *) (intptr_t)ex_peer->super.base;
1686+
*size = ex_peer->size;
1687+
*disp_unit = ex_peer->disp_unit;
1688+
return OMPI_SUCCESS;
1689+
}
1690+
}
1691+
return OMPI_ERR_NOT_SUPPORTED;
1692+
}

ompi/mca/osc/ucx/osc_ucx_component.c

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -468,30 +468,70 @@ static const char* ompi_osc_ucx_set_no_lock_info(opal_infosubscriber_t *obj, con
468468
return module->no_locks ? "true" : "false";
469469
}
470470

471+
/**
 * Try to obtain a directly addressable pointer to the window memory of one
 * peer in a window that was not created as a shared-memory window.
 *
 * @param[in]  module     osc/ucx module of the window
 * @param[in]  rank       peer whose segment is queried (used as an index into
 *                        the module's per-peer arrays)
 * @param[out] size       size recorded for the peer's segment
 * @param[out] disp_unit  displacement unit of the peer's segment
 * @param[out] baseptr    address of a pointer (passed as void *) that receives
 *                        the local address of the peer's segment
 *
 * @return OMPI_SUCCESS when the outputs were filled in; the error code of
 *         opal_common_ucx_tlocal_fetch() if the remote key cannot be fetched;
 *         OMPI_ERR_NOT_AVAILABLE if ucp_rkey_ptr() cannot provide a pointer
 *         for the peer's memory. The outputs are left untouched on error.
 */
static int ompi_osc_ucx_shared_query_peer(ompi_osc_ucx_module_t *module, int rank, size_t *size,
                                          ptrdiff_t *disp_unit, void *baseptr)
{
    ucp_ep_h *dflt_ep;
    ucp_ep_h ep;                     /* returned by the fetch, not needed here */
    ucp_rkey_h rkey;
    opal_common_ucx_winfo_t *winfo;  /* returned by the fetch, not needed here */
    void *addr_p = NULL;
    int rc;

    OSC_UCX_GET_DEFAULT_EP(dflt_ep, module, rank);

    rc = opal_common_ucx_tlocal_fetch(module->mem, rank, &ep, &rkey, &winfo, dflt_ep);
    if (OMPI_SUCCESS != rc) {
        return rc;
    }

    /* ask UCX for a pointer through which the peer's memory can be accessed
     * directly; this fails if the memory is not reachable that way */
    if (UCS_OK != ucp_rkey_ptr(rkey, module->addrs[rank], &addr_p)) {
        return OMPI_ERR_NOT_AVAILABLE;
    }

    *size = module->sizes[rank];
    /* report the pointer UCX resolved for this peer */
    *((void**) baseptr) = addr_p;
    /* a module-wide disp_unit of -1 means the per-peer values must be used */
    if (module->disp_unit == -1) {
        *disp_unit = module->disp_units[rank];
    } else {
        *disp_unit = module->disp_unit;
    }

    return OMPI_SUCCESS;
}
494+
471495
int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size,
472496
ptrdiff_t *disp_unit, void *baseptr)
473497
{
474498
ompi_osc_ucx_module_t *module =
475499
(ompi_osc_ucx_module_t*) win->w_osc_module;
476500

501+
*size = 0;
502+
*((void**) baseptr) = NULL;
503+
*disp_unit = 0;
504+
477505
if (module->flavor != MPI_WIN_FLAVOR_SHARED) {
478-
return MPI_ERR_WIN;
479-
}
480506

481-
if (MPI_PROC_NULL != rank) {
507+
if (MPI_PROC_NULL == rank) {
508+
for (int i = 0 ; i < ompi_comm_size(module->comm) ; ++i) {
509+
if (0 != module->sizes[i]) {
510+
if (OMPI_SUCCESS == ompi_osc_ucx_shared_query_peer(module, i, size, disp_unit, baseptr)) {
511+
return OMPI_SUCCESS;
512+
}
513+
}
514+
}
515+
} else {
516+
if (0 != module->sizes[i]) {
517+
if (OMPI_SUCCESS == ompi_osc_ucx_shared_query_peer(module, i, size, disp_unit, baseptr)) {
518+
return OMPI_SUCCESS;
519+
}
520+
}
521+
}
522+
return OMPI_ERR_NOT_SUPPORTED;
523+
524+
} else if (MPI_PROC_NULL != rank) { // shared memory window with given rank
482525
*size = module->sizes[rank];
483526
*((void**) baseptr) = (void *)module->shmem_addrs[rank];
484527
if (module->disp_unit == -1) {
485528
*disp_unit = module->disp_units[rank];
486529
} else {
487530
*disp_unit = module->disp_unit;
488531
}
489-
} else {
532+
} else { // shared memory window with MPI_PROC_NULL
490533
int i = 0;
491534

492-
*size = 0;
493-
*((void**) baseptr) = NULL;
494-
*disp_unit = 0;
495535
for (i = 0 ; i < ompi_comm_size(module->comm) ; ++i) {
496536
if (0 != module->sizes[i]) {
497537
*size = module->sizes[i];

ompi/mpi/c/win_shared_query.c.in

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626

2727
PROTOTYPE ERROR_CLASS win_shared_query(WIN win, INT rank, AINT_OUT size, INT_AINT_OUT disp_unit, BUFFER_OUT baseptr)
2828
{
29-
int rc;
3029
size_t tsize;
3130
ptrdiff_t du;
31+
int rc = OMPI_SUCCESS;
3232

3333
if (MPI_PARAM_CHECK) {
3434
OMPI_ERR_INIT_FINALIZE(FUNC_NAME);
@@ -40,12 +40,23 @@ PROTOTYPE ERROR_CLASS win_shared_query(WIN win, INT rank, AINT_OUT size, INT_AIN
4040
}
4141
}
4242

43+
rc = OMPI_ERR_NOT_SUPPORTED;
44+
4345
if (NULL != win->w_osc_module->osc_win_shared_query) {
4446
rc = win->w_osc_module->osc_win_shared_query(win, rank, &tsize, &du, baseptr);
45-
*size = tsize;
46-
*disp_unit = du;
47-
} else {
48-
rc = MPI_ERR_RMA_FLAVOR;
47+
if (OMPI_SUCCESS == rc) {
48+
*size = tsize;
49+
*disp_unit = du;
50+
}
51+
}
52+
53+
if (OMPI_ERR_NOT_SUPPORTED == rc) {
54+
/* gracefully bail out */
55+
*size = 0;
56+
*disp_unit = 0;
57+
*(void**) baseptr = NULL;
58+
rc = MPI_SUCCESS; // don't raise an error if the function is not supported
4959
}
60+
5061
OMPI_ERRHANDLER_RETURN(rc, win, rc, FUNC_NAME);
5162
}

0 commit comments

Comments
 (0)