@@ -363,10 +363,10 @@ int mca_pml_ucx_cleanup(void)
363
363
364
364
static ucp_ep_h mca_pml_ucx_add_proc_common (ompi_proc_t * proc )
365
365
{
366
+ size_t addrlen = 0 ;
366
367
ucp_ep_params_t ep_params ;
367
368
ucp_address_t * address ;
368
369
ucs_status_t status ;
369
- size_t addrlen ;
370
370
ucp_ep_h ep ;
371
371
int ret ;
372
372
@@ -418,6 +418,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
418
418
return OMPI_SUCCESS ;
419
419
}
420
420
421
+ __opal_attribute_always_inline__
421
422
static inline ucp_ep_h mca_pml_ucx_get_ep (ompi_communicator_t * comm , int rank )
422
423
{
423
424
ompi_proc_t * proc_peer = ompi_comm_peer_lookup (comm , rank );
@@ -539,17 +540,28 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
539
540
int src , int tag , struct ompi_communicator_t * comm ,
540
541
struct ompi_request_t * * request )
541
542
{
543
+ #if HAVE_DECL_UCP_TAG_RECV_NBX
544
+ pml_ucx_datatype_t * op_data = mca_pml_ucx_get_op_data (datatype );
545
+ ucp_request_param_t * param = & op_data -> op_param .recv ;
546
+ #endif
547
+
542
548
ucp_tag_t ucp_tag , ucp_tag_mask ;
543
549
ompi_request_t * req ;
544
550
545
551
PML_UCX_TRACE_RECV ("irecv request *%p" , buf , count , datatype , src , tag , comm ,
546
552
(void * )request );
547
553
548
554
PML_UCX_MAKE_RECV_TAG (ucp_tag , ucp_tag_mask , tag , src , comm );
555
+ #if HAVE_DECL_UCP_TAG_RECV_NBX
556
+ req = (ompi_request_t * )ucp_tag_recv_nbx (ompi_pml_ucx .ucp_worker , buf ,
557
+ mca_pml_ucx_get_data_size (op_data , count ),
558
+ ucp_tag , ucp_tag_mask , param );
559
+ #else
549
560
req = (ompi_request_t * )ucp_tag_recv_nb (ompi_pml_ucx .ucp_worker , buf , count ,
550
561
mca_pml_ucx_get_datatype (datatype ),
551
562
ucp_tag , ucp_tag_mask ,
552
563
mca_pml_ucx_recv_completion );
564
+ #endif
553
565
if (UCS_PTR_IS_ERR (req )) {
554
566
PML_UCX_ERROR ("ucx recv failed: %s" , ucs_status_string (UCS_PTR_STATUS (req )));
555
567
return OMPI_ERROR ;
@@ -565,20 +577,34 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
565
577
int tag , struct ompi_communicator_t * comm ,
566
578
ompi_status_public_t * mpi_status )
567
579
{
580
+ /* coverity[bad_alloc_arithmetic] */
581
+ void * req = PML_UCX_REQ_ALLOCA ();
582
+ #if HAVE_DECL_UCP_TAG_RECV_NBX
583
+ pml_ucx_datatype_t * op_data = mca_pml_ucx_get_op_data (datatype );
584
+ ucp_request_param_t * recv_param = & op_data -> op_param .recv ;
585
+ ucp_request_param_t param ;
586
+
587
+ param .op_attr_mask = UCP_OP_ATTR_FIELD_REQUEST |
588
+ (recv_param -> op_attr_mask & UCP_OP_ATTR_FIELD_DATATYPE );
589
+ param .datatype = recv_param -> datatype ;
590
+ param .request = req ;
591
+ #endif
568
592
ucp_tag_t ucp_tag , ucp_tag_mask ;
569
593
ucp_tag_recv_info_t info ;
570
594
ucs_status_t status ;
571
- void * req ;
572
595
573
596
PML_UCX_TRACE_RECV ("%s" , buf , count , datatype , src , tag , comm , "recv" );
574
597
575
- /* coverity[bad_alloc_arithmetic] */
576
598
PML_UCX_MAKE_RECV_TAG (ucp_tag , ucp_tag_mask , tag , src , comm );
577
- req = PML_UCX_REQ_ALLOCA ();
578
- status = ucp_tag_recv_nbr (ompi_pml_ucx .ucp_worker , buf , count ,
579
- mca_pml_ucx_get_datatype (datatype ),
580
- ucp_tag , ucp_tag_mask , req );
581
-
599
+ #if HAVE_DECL_UCP_TAG_RECV_NBX
600
+ ucp_tag_recv_nbx (ompi_pml_ucx .ucp_worker , buf ,
601
+ mca_pml_ucx_get_data_size (op_data , count ),
602
+ ucp_tag , ucp_tag_mask , & param );
603
+ #else
604
+ ucp_tag_recv_nbr (ompi_pml_ucx .ucp_worker , buf , count ,
605
+ mca_pml_ucx_get_datatype (datatype ),
606
+ ucp_tag , ucp_tag_mask , req );
607
+ #endif
582
608
MCA_COMMON_UCX_PROGRESS_LOOP (ompi_pml_ucx .ucp_worker ) {
583
609
status = ucp_request_test (req , & info );
584
610
if (status != UCS_INPROGRESS ) {
@@ -588,6 +614,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
588
614
}
589
615
}
590
616
617
+ __opal_attribute_always_inline__
591
618
static inline const char * mca_pml_ucx_send_mode_name (mca_pml_base_send_mode_t mode )
592
619
{
593
620
switch (mode ) {
@@ -709,6 +736,7 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
709
736
return NULL ;
710
737
}
711
738
739
+ __opal_attribute_always_inline__
712
740
static inline ucs_status_ptr_t mca_pml_ucx_common_send (ucp_ep_h ep , const void * buf ,
713
741
size_t count ,
714
742
ompi_datatype_t * datatype ,
@@ -726,6 +754,32 @@ static inline ucs_status_ptr_t mca_pml_ucx_common_send(ucp_ep_h ep, const void *
726
754
}
727
755
}
728
756
757
#if HAVE_DECL_UCP_TAG_SEND_NBX
/**
 * Common send path using the UCP "nbx" (non-blocking extended) API.
 *
 * Dispatches on the MPI send mode:
 *  - MCA_PML_BASE_SEND_BUFFERED    -> mca_pml_ucx_bsend() (buffered-send helper)
 *  - MCA_PML_BASE_SEND_SYNCHRONOUS -> ucp_tag_send_sync_nb() (legacy nb API;
 *                                     there is no nbx sync variant used here)
 *  - otherwise                     -> ucp_tag_send_nbx() with the cached
 *                                     per-datatype request parameters
 *
 * @param ep        UCP endpoint of the destination rank.
 * @param buf       User send buffer.
 * @param count     Element count (converted to a byte/ucp size via
 *                  mca_pml_ucx_get_data_size() for the nbx call).
 * @param datatype  MPI datatype; its cached pml_ucx_datatype_t carries the
 *                  pre-built ucp_request_param_t.
 * @param tag       Pre-encoded UCP tag (PML_UCX_MAKE_SEND_TAG at the caller).
 * @param mode      MPI send mode selecting the branch above.
 * @param param     Cached send request parameters (op_attr_mask, datatype,
 *                  completion callback) for this datatype.
 *
 * @return UCS status pointer: NULL on immediate completion, a request
 *         pointer while in progress, or an UCS_PTR_IS_ERR() value on error
 *         (same contract as the underlying ucp_tag_send_* calls).
 */
__opal_attribute_always_inline__
static inline ucs_status_ptr_t
mca_pml_ucx_common_send_nbx(ucp_ep_h ep, const void *buf,
                            size_t count,
                            ompi_datatype_t *datatype,
                            ucp_tag_t tag,
                            mca_pml_base_send_mode_t mode,
                            ucp_request_param_t *param)
{
    pml_ucx_datatype_t *op_data = mca_pml_ucx_get_op_data(datatype);

    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
        return mca_pml_ucx_bsend(ep, buf, count, datatype, tag);
    } else if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
        /* NOTE(review): sync path uses the legacy nb API and reinterprets the
         * nbx callback (param->cb.send) as ucp_send_callback_t — this relies
         * on the two callback signatures being call-compatible; confirm
         * against the UCX version being built. */
        return ucp_tag_send_sync_nb(ep, buf, count,
                                    mca_pml_ucx_get_datatype(datatype), tag,
                                    (ucp_send_callback_t)param->cb.send);
    } else {
        return ucp_tag_send_nbx(ep, buf,
                                mca_pml_ucx_get_data_size(op_data, count),
                                tag, param);
    }
}
#endif
782
+
729
783
int mca_pml_ucx_isend (const void * buf , size_t count , ompi_datatype_t * datatype ,
730
784
int dst , int tag , mca_pml_base_send_mode_t mode ,
731
785
struct ompi_communicator_t * comm ,
@@ -744,10 +798,16 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
744
798
return OMPI_ERROR ;
745
799
}
746
800
801
+ #if HAVE_DECL_UCP_TAG_SEND_NBX
802
+ req = (ompi_request_t * )mca_pml_ucx_common_send_nbx (ep , buf , count , datatype ,
803
+ PML_UCX_MAKE_SEND_TAG (tag , comm ), mode ,
804
+ & mca_pml_ucx_get_op_data (datatype )-> op_param .send );
805
+ #else
747
806
req = (ompi_request_t * )mca_pml_ucx_common_send (ep , buf , count , datatype ,
748
807
mca_pml_ucx_get_datatype (datatype ),
749
808
PML_UCX_MAKE_SEND_TAG (tag , comm ), mode ,
750
809
mca_pml_ucx_send_completion );
810
+ #endif
751
811
752
812
if (req == NULL ) {
753
813
PML_UCX_VERBOSE (8 , "returning completed request" );
@@ -789,20 +849,40 @@ mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count,
789
849
#if HAVE_DECL_UCP_TAG_SEND_NBR
790
850
static inline __opal_attribute_always_inline__ int
791
851
mca_pml_ucx_send_nbr (ucp_ep_h ep , const void * buf , size_t count ,
792
- ucp_datatype_t ucx_datatype , ucp_tag_t tag )
793
-
852
+ ompi_datatype_t * datatype , ucp_tag_t tag )
794
853
{
795
- ucs_status_ptr_t req ;
796
- ucs_status_t status ;
797
-
798
854
/* coverity[bad_alloc_arithmetic] */
799
- req = PML_UCX_REQ_ALLOCA ();
800
- status = ucp_tag_send_nbr (ep , buf , count , ucx_datatype , tag , req );
855
+ ucs_status_ptr_t req = PML_UCX_REQ_ALLOCA ();
856
+ #if HAVE_DECL_UCP_TAG_SEND_NBX
857
+ pml_ucx_datatype_t * op_data = mca_pml_ucx_get_op_data (datatype );
858
+ ucp_request_param_t param = {
859
+ .op_attr_mask = UCP_OP_ATTR_FIELD_REQUEST |
860
+ (op_data -> op_param .send .op_attr_mask & UCP_OP_ATTR_FIELD_DATATYPE ) |
861
+ UCP_OP_ATTR_FLAG_FAST_CMPL ,
862
+ .datatype = op_data -> op_param .send .datatype ,
863
+ .request = req
864
+ };
865
+
866
+ req = ucp_tag_send_nbx (ep , buf ,
867
+ mca_pml_ucx_get_data_size (op_data , count ),
868
+ tag , & param );
869
+ if (OPAL_LIKELY (req == UCS_OK )) {
870
+ return OMPI_SUCCESS ;
871
+ } else if (UCS_PTR_IS_ERR (req )) {
872
+ PML_UCX_ERROR ("%s failed: %d, %s" , __func__ , UCS_PTR_STATUS (req ),
873
+ ucs_status_string (UCS_PTR_STATUS (req )));
874
+ return OPAL_ERROR ;
875
+ }
876
+ #else
877
+ ucs_status_t status ;
878
+ status = ucp_tag_send_nbr (ep , buf , count ,
879
+ mca_pml_ucx_get_datatype (datatype ), tag , req );
801
880
if (OPAL_LIKELY (status == UCS_OK )) {
802
881
return OMPI_SUCCESS ;
803
882
}
883
+ #endif
804
884
805
- MCA_COMMON_UCX_WAIT_LOOP (req , ompi_pml_ucx .ucp_worker , "ucx send" , (void )0 );
885
+ MCA_COMMON_UCX_WAIT_LOOP (req , ompi_pml_ucx .ucp_worker , "ucx send nbr " , (void )0 );
806
886
}
807
887
#endif
808
888
@@ -823,8 +903,7 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
823
903
#if HAVE_DECL_UCP_TAG_SEND_NBR
824
904
if (OPAL_LIKELY ((MCA_PML_BASE_SEND_BUFFERED != mode ) &&
825
905
(MCA_PML_BASE_SEND_SYNCHRONOUS != mode ))) {
826
- return mca_pml_ucx_send_nbr (ep , buf , count ,
827
- mca_pml_ucx_get_datatype (datatype ),
906
+ return mca_pml_ucx_send_nbr (ep , buf , count , datatype ,
828
907
PML_UCX_MAKE_SEND_TAG (tag , comm ));
829
908
}
830
909
#endif
0 commit comments