Skip to content

Commit b61bf9a

Browse files
authored
Merge pull request #7349 from hoopoepg/topic/ucx-new-api-nbx
OPAL/UCX: enabling new API provided by UCX
2 parents 0e17e5b + 75bda25 commit b61bf9a

File tree

6 files changed

+284
-43
lines changed

6 files changed

+284
-43
lines changed

config/ompi_check_ucx.m4

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ AC_DEFUN([OMPI_CHECK_UCX],[
128128
[AC_DEFINE([HAVE_UCP_WORKER_ADDRESS_FLAGS], [1],
129129
[have worker address attribute])], [],
130130
[#include <ucp/api/ucp.h>])
131+
AC_CHECK_DECLS([ucp_tag_send_nbx,
132+
ucp_tag_send_sync_nbx,
133+
ucp_tag_recv_nbx],
134+
[], [],
135+
[#include <ucp/api/ucp.h>])
136+
AC_CHECK_TYPES([ucp_request_param_t],
137+
[], [],
138+
[[#include <ucp/api/ucp.h>]])
131139
CPPFLAGS=$old_CPPFLAGS
132140

133141
OPAL_SUMMARY_ADD([[Transports]],[[Open UCX]],[$1],[$ompi_check_ucx_happy])])])

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 97 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,10 @@ int mca_pml_ucx_cleanup(void)
363363

364364
static ucp_ep_h mca_pml_ucx_add_proc_common(ompi_proc_t *proc)
365365
{
366+
size_t addrlen = 0;
366367
ucp_ep_params_t ep_params;
367368
ucp_address_t *address;
368369
ucs_status_t status;
369-
size_t addrlen;
370370
ucp_ep_h ep;
371371
int ret;
372372

@@ -418,6 +418,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
418418
return OMPI_SUCCESS;
419419
}
420420

421+
__opal_attribute_always_inline__
421422
static inline ucp_ep_h mca_pml_ucx_get_ep(ompi_communicator_t *comm, int rank)
422423
{
423424
ompi_proc_t *proc_peer = ompi_comm_peer_lookup(comm, rank);
@@ -539,17 +540,28 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
539540
int src, int tag, struct ompi_communicator_t* comm,
540541
struct ompi_request_t **request)
541542
{
543+
#if HAVE_DECL_UCP_TAG_RECV_NBX
544+
pml_ucx_datatype_t *op_data = mca_pml_ucx_get_op_data(datatype);
545+
ucp_request_param_t *param = &op_data->op_param.recv;
546+
#endif
547+
542548
ucp_tag_t ucp_tag, ucp_tag_mask;
543549
ompi_request_t *req;
544550

545551
PML_UCX_TRACE_RECV("irecv request *%p", buf, count, datatype, src, tag, comm,
546552
(void*)request);
547553

548554
PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
555+
#if HAVE_DECL_UCP_TAG_RECV_NBX
556+
req = (ompi_request_t*)ucp_tag_recv_nbx(ompi_pml_ucx.ucp_worker, buf,
557+
mca_pml_ucx_get_data_size(op_data, count),
558+
ucp_tag, ucp_tag_mask, param);
559+
#else
549560
req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
550561
mca_pml_ucx_get_datatype(datatype),
551562
ucp_tag, ucp_tag_mask,
552563
mca_pml_ucx_recv_completion);
564+
#endif
553565
if (UCS_PTR_IS_ERR(req)) {
554566
PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
555567
return OMPI_ERROR;
@@ -565,20 +577,34 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
565577
int tag, struct ompi_communicator_t* comm,
566578
ompi_status_public_t* mpi_status)
567579
{
580+
/* coverity[bad_alloc_arithmetic] */
581+
void *req = PML_UCX_REQ_ALLOCA();
582+
#if HAVE_DECL_UCP_TAG_RECV_NBX
583+
pml_ucx_datatype_t *op_data = mca_pml_ucx_get_op_data(datatype);
584+
ucp_request_param_t *recv_param = &op_data->op_param.recv;
585+
ucp_request_param_t param;
586+
587+
param.op_attr_mask = UCP_OP_ATTR_FIELD_REQUEST |
588+
(recv_param->op_attr_mask & UCP_OP_ATTR_FIELD_DATATYPE);
589+
param.datatype = recv_param->datatype;
590+
param.request = req;
591+
#endif
568592
ucp_tag_t ucp_tag, ucp_tag_mask;
569593
ucp_tag_recv_info_t info;
570594
ucs_status_t status;
571-
void *req;
572595

573596
PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
574597

575-
/* coverity[bad_alloc_arithmetic] */
576598
PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
577-
req = PML_UCX_REQ_ALLOCA();
578-
status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
579-
mca_pml_ucx_get_datatype(datatype),
580-
ucp_tag, ucp_tag_mask, req);
581-
599+
#if HAVE_DECL_UCP_TAG_RECV_NBX
600+
ucp_tag_recv_nbx(ompi_pml_ucx.ucp_worker, buf,
601+
mca_pml_ucx_get_data_size(op_data, count),
602+
ucp_tag, ucp_tag_mask, &param);
603+
#else
604+
ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
605+
mca_pml_ucx_get_datatype(datatype),
606+
ucp_tag, ucp_tag_mask, req);
607+
#endif
582608
MCA_COMMON_UCX_PROGRESS_LOOP(ompi_pml_ucx.ucp_worker) {
583609
status = ucp_request_test(req, &info);
584610
if (status != UCS_INPROGRESS) {
@@ -588,6 +614,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
588614
}
589615
}
590616

617+
__opal_attribute_always_inline__
591618
static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
592619
{
593620
switch (mode) {
@@ -709,6 +736,7 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
709736
return NULL;
710737
}
711738

739+
__opal_attribute_always_inline__
712740
static inline ucs_status_ptr_t mca_pml_ucx_common_send(ucp_ep_h ep, const void *buf,
713741
size_t count,
714742
ompi_datatype_t *datatype,
@@ -726,6 +754,32 @@ static inline ucs_status_ptr_t mca_pml_ucx_common_send(ucp_ep_h ep, const void *
726754
}
727755
}
728756

757+
#if HAVE_DECL_UCP_TAG_SEND_NBX
/**
 * Start a non-blocking tagged send through the UCX "nbx" API, dispatching
 * on the MPI send mode.
 *
 * @param ep        UCX endpoint of the destination peer.
 * @param buf       Send buffer.
 * @param count     Number of elements of @a datatype in @a buf.
 * @param datatype  OMPI datatype describing the elements.
 * @param tag       Fully built UCX tag for the message.
 * @param mode      MPI send mode (buffered / synchronous / anything else).
 * @param param     Request parameters handed to the nbx send calls.
 *
 * @return The status pointer produced by the selected send routine, unchanged.
 *
 * Buffered sends are delegated to mca_pml_ucx_bsend(). Synchronous sends use
 * ucp_tag_send_sync_nbx() when configure found its declaration, so that the
 * nbx-style callback stored in @a param is invoked with its own signature
 * instead of being cast to the legacy ucp_send_callback_t type; otherwise the
 * legacy ucp_tag_send_sync_nb() call is kept as a fallback.
 */
__opal_attribute_always_inline__
static inline ucs_status_ptr_t
mca_pml_ucx_common_send_nbx(ucp_ep_h ep, const void *buf,
                            size_t count,
                            ompi_datatype_t *datatype,
                            ucp_tag_t tag,
                            mca_pml_base_send_mode_t mode,
                            ucp_request_param_t *param)
{
    pml_ucx_datatype_t *op_data = mca_pml_ucx_get_op_data(datatype);

    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
        return mca_pml_ucx_bsend(ep, buf, count, datatype, tag);
    } else if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
#if HAVE_DECL_UCP_TAG_SEND_SYNC_NBX
        /* Same size/param convention as the plain nbx send below. */
        return ucp_tag_send_sync_nbx(ep, buf,
                                     mca_pml_ucx_get_data_size(op_data, count),
                                     tag, param);
#else
        /* Legacy fallback: only the nb variant of the sync send is available. */
        return ucp_tag_send_sync_nb(ep, buf, count,
                                    mca_pml_ucx_get_datatype(datatype), tag,
                                    (ucp_send_callback_t)param->cb.send);
#endif
    } else {
        return ucp_tag_send_nbx(ep, buf,
                                mca_pml_ucx_get_data_size(op_data, count),
                                tag, param);
    }
}
#endif
782+
729783
int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
730784
int dst, int tag, mca_pml_base_send_mode_t mode,
731785
struct ompi_communicator_t* comm,
@@ -744,10 +798,16 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
744798
return OMPI_ERROR;
745799
}
746800

801+
#if HAVE_DECL_UCP_TAG_SEND_NBX
802+
req = (ompi_request_t*)mca_pml_ucx_common_send_nbx(ep, buf, count, datatype,
803+
PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
804+
&mca_pml_ucx_get_op_data(datatype)->op_param.send);
805+
#else
747806
req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
748807
mca_pml_ucx_get_datatype(datatype),
749808
PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
750809
mca_pml_ucx_send_completion);
810+
#endif
751811

752812
if (req == NULL) {
753813
PML_UCX_VERBOSE(8, "returning completed request");
@@ -789,20 +849,40 @@ mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count,
789849
#if HAVE_DECL_UCP_TAG_SEND_NBR
790850
static inline __opal_attribute_always_inline__ int
791851
mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count,
792-
ucp_datatype_t ucx_datatype, ucp_tag_t tag)
793-
852+
ompi_datatype_t *datatype, ucp_tag_t tag)
794853
{
795-
ucs_status_ptr_t req;
796-
ucs_status_t status;
797-
798854
/* coverity[bad_alloc_arithmetic] */
799-
req = PML_UCX_REQ_ALLOCA();
800-
status = ucp_tag_send_nbr(ep, buf, count, ucx_datatype, tag, req);
855+
ucs_status_ptr_t req = PML_UCX_REQ_ALLOCA();
856+
#if HAVE_DECL_UCP_TAG_SEND_NBX
857+
pml_ucx_datatype_t *op_data = mca_pml_ucx_get_op_data(datatype);
858+
ucp_request_param_t param = {
859+
.op_attr_mask = UCP_OP_ATTR_FIELD_REQUEST |
860+
(op_data->op_param.send.op_attr_mask & UCP_OP_ATTR_FIELD_DATATYPE) |
861+
UCP_OP_ATTR_FLAG_FAST_CMPL,
862+
.datatype = op_data->op_param.send.datatype,
863+
.request = req
864+
};
865+
866+
req = ucp_tag_send_nbx(ep, buf,
867+
mca_pml_ucx_get_data_size(op_data, count),
868+
tag, &param);
869+
if (OPAL_LIKELY(req == UCS_OK)) {
870+
return OMPI_SUCCESS;
871+
} else if (UCS_PTR_IS_ERR(req)) {
872+
PML_UCX_ERROR("%s failed: %d, %s", __func__, UCS_PTR_STATUS(req),
873+
ucs_status_string(UCS_PTR_STATUS(req)));
874+
return OPAL_ERROR;
875+
}
876+
#else
877+
ucs_status_t status;
878+
status = ucp_tag_send_nbr(ep, buf, count,
879+
mca_pml_ucx_get_datatype(datatype), tag, req);
801880
if (OPAL_LIKELY(status == UCS_OK)) {
802881
return OMPI_SUCCESS;
803882
}
883+
#endif
804884

805-
MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send", (void)0);
885+
MCA_COMMON_UCX_WAIT_LOOP(req, ompi_pml_ucx.ucp_worker, "ucx send nbr", (void)0);
806886
}
807887
#endif
808888

@@ -823,8 +903,7 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
823903
#if HAVE_DECL_UCP_TAG_SEND_NBR
824904
if (OPAL_LIKELY((MCA_PML_BASE_SEND_BUFFERED != mode) &&
825905
(MCA_PML_BASE_SEND_SYNCHRONOUS != mode))) {
826-
return mca_pml_ucx_send_nbr(ep, buf, count,
827-
mca_pml_ucx_get_datatype(datatype),
906+
return mca_pml_ucx_send_nbr(ep, buf, count, datatype,
828907
PML_UCX_MAKE_SEND_TAG(tag, comm));
829908
}
830909
#endif

ompi/mca/pml/ucx/pml_ucx_datatype.c

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,20 @@
1010
*/
1111

1212
#include "pml_ucx_datatype.h"
13+
#include "pml_ucx_request.h"
1314

1415
#include "ompi/runtime/mpiruntime.h"
1516
#include "ompi/attribute/attribute.h"
1617

1718
#include <inttypes.h>
19+
#include <math.h>
1820

21+
#ifdef HAVE_UCP_REQUEST_PARAM_T
/*
 * Apply the member expression _val (for example "op_attr_mask |= X" or
 * "datatype = Y") to the send, bsend and recv parameter sets of _datatype.
 *
 * Wrapped in do { } while (0) so that the three assignments behave as a
 * single statement: the macro is then safe inside an unbraced if/else and
 * the caller's trailing ';' does not produce an extra empty statement.
 */
#define PML_UCX_DATATYPE_SET_VALUE(_datatype, _val) \
    do { \
        (_datatype)->op_param.send._val; \
        (_datatype)->op_param.bsend._val; \
        (_datatype)->op_param.recv._val; \
    } while (0)
#endif
1927

2028
static void* pml_ucx_generic_datatype_start_pack(void *context, const void *buffer,
2129
size_t count)
@@ -135,30 +143,77 @@ int mca_pml_ucx_datatype_attr_del_fn(ompi_datatype_t* datatype, int keyval,
135143
{
136144
ucp_datatype_t ucp_datatype = (ucp_datatype_t)attr_val;
137145

146+
#ifdef HAVE_UCP_REQUEST_PARAM_T
147+
free((void*)datatype->pml_data);
148+
#else
138149
PML_UCX_ASSERT((uint64_t)ucp_datatype == datatype->pml_data);
139-
150+
#endif
140151
ucp_dt_destroy(ucp_datatype);
141152
datatype->pml_data = PML_UCX_DATATYPE_INVALID;
142153
return OMPI_SUCCESS;
143154
}
144155

156+
__opal_attribute_always_inline__
157+
static inline int mca_pml_ucx_datatype_is_contig(ompi_datatype_t *datatype)
158+
{
159+
ptrdiff_t lb;
160+
161+
ompi_datatype_type_lb(datatype, &lb);
162+
163+
return (datatype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) &&
164+
(datatype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS) &&
165+
(lb == 0);
166+
}
167+
168+
#ifdef HAVE_UCP_REQUEST_PARAM_T
/**
 * Allocate and fill the per-datatype descriptor used by the UCX "nbx" calls.
 *
 * @param datatype      OMPI datatype being initialised.
 * @param ucp_datatype  UCX datatype handle already created for @a datatype.
 * @param size          Element size in bytes (only meaningful for contiguous
 *                      datatypes).
 *
 * @return Newly allocated descriptor. On allocation failure an error is
 *         logged and ompi_mpi_abort() is called.
 *
 * The send, bsend and recv parameter sets are preloaded with their
 * completion callbacks. For a contiguous datatype whose size is a non-zero
 * power of two, size_shift is set to log2(size) and no explicit UCX datatype
 * is attached. In every other case size_shift is 0 and the UCX datatype is
 * attached to all three parameter sets.
 */
__opal_attribute_always_inline__ static inline
pml_ucx_datatype_t *mca_pml_ucx_init_nbx_datatype(ompi_datatype_t *datatype,
                                                  ucp_datatype_t ucp_datatype,
                                                  size_t size)
{
    pml_ucx_datatype_t *pml_datatype;
    int is_contig_pow2;
    int shift;

    /* Zero-filled so that parameter fields which are not set below (notably
     * op_param.*.datatype on the power-of-two path) hold a defined value when
     * other code in this change copies them unconditionally. */
    pml_datatype = calloc(1, sizeof(*pml_datatype));
    if (pml_datatype == NULL) {
        PML_UCX_ERROR("Failed to allocate datatype structure");
        ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
    }

    pml_datatype->datatype                    = ucp_datatype;
    pml_datatype->op_param.send.op_attr_mask  = UCP_OP_ATTR_FIELD_CALLBACK;
    pml_datatype->op_param.send.cb.send       = mca_pml_ucx_send_nbx_completion;
    pml_datatype->op_param.bsend.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK;
    pml_datatype->op_param.bsend.cb.send      = mca_pml_ucx_bsend_nbx_completion;
    pml_datatype->op_param.recv.op_attr_mask  = UCP_OP_ATTR_FIELD_CALLBACK |
                                                UCP_OP_ATTR_FLAG_NO_IMM_CMPL;
    pml_datatype->op_param.recv.cb.recv       = mca_pml_ucx_recv_nbx_completion;

    /* size == 0 would pass the bit trick below, so it is excluded explicitly:
     * zero is not a power of two and has no logarithm. */
    is_contig_pow2 = mca_pml_ucx_datatype_is_contig(datatype) &&
                     (size != 0) &&
                     !(size & (size - 1)); /* is_pow2(size) */
    if (is_contig_pow2) {
        /* Exact integer log2: position of the single set bit. Avoids the
         * rounding risk of truncating log(size) / log(2.0). */
        shift = 0;
        while ((size >> shift) > 1) {
            ++shift;
        }
        pml_datatype->size_shift = shift;
    } else {
        pml_datatype->size_shift = 0;
        PML_UCX_DATATYPE_SET_VALUE(pml_datatype, op_attr_mask |= UCP_OP_ATTR_FIELD_DATATYPE);
        PML_UCX_DATATYPE_SET_VALUE(pml_datatype, datatype = ucp_datatype);
    }

    return pml_datatype;
}
#endif
205+
145206
ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
146207
{
208+
size_t size = 0; /* init to suppress compiler warning */
147209
ucp_datatype_t ucp_datatype;
148210
ucs_status_t status;
149-
ptrdiff_t lb;
150-
size_t size;
151211
int ret;
152212

153-
ompi_datatype_type_lb(datatype, &lb);
154-
155-
if ((datatype->super.flags & OPAL_DATATYPE_FLAG_CONTIGUOUS) &&
156-
(datatype->super.flags & OPAL_DATATYPE_FLAG_NO_GAPS) &&
157-
(lb == 0))
158-
{
213+
if (mca_pml_ucx_datatype_is_contig(datatype)) {
159214
ompi_datatype_type_size(datatype, &size);
160-
datatype->pml_data = ucp_dt_make_contig(size);
161-
return datatype->pml_data;
215+
ucp_datatype = ucp_dt_make_contig(size);
216+
goto out;
162217
}
163218

164219
status = ucp_dt_create_generic(&pml_ucx_generic_datatype_ops,
@@ -168,8 +223,6 @@ ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
168223
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
169224
}
170225

171-
datatype->pml_data = ucp_datatype;
172-
173226
/* Add custom attribute, to clean up UCX resources when OMPI datatype is
174227
* released.
175228
*/
@@ -186,9 +239,18 @@ ucp_datatype_t mca_pml_ucx_init_datatype(ompi_datatype_t *datatype)
186239
ompi_mpi_abort(&ompi_mpi_comm_world.comm, 1);
187240
}
188241
}
189-
242+
out:
190243
PML_UCX_VERBOSE(7, "created generic UCX datatype 0x%"PRIx64, ucp_datatype)
191244

245+
#ifdef HAVE_UCP_REQUEST_PARAM_T
246+
UCS_STATIC_ASSERT(sizeof(datatype->pml_data) >= sizeof(pml_ucx_datatype_t*));
247+
datatype->pml_data = (uint64_t)mca_pml_ucx_init_nbx_datatype(datatype,
248+
ucp_datatype,
249+
size);
250+
#else
251+
datatype->pml_data = ucp_datatype;
252+
#endif
253+
192254
return ucp_datatype;
193255
}
194256

0 commit comments

Comments
 (0)