32
32
#include "oshmem/proc/proc.h"
33
33
#include "oshmem/mca/spml/base/base.h"
34
34
#include "oshmem/mca/spml/base/spml_base_putreq.h"
35
+ #include "oshmem/mca/atomic/atomic.h"
35
36
#include "oshmem/runtime/runtime.h"
36
37
37
38
#include "oshmem/mca/spml/ucx/spml_ucx_component.h"
@@ -67,6 +68,7 @@ mca_spml_ucx_t mca_spml_ucx = {
67
68
.spml_rmkey_free = mca_spml_ucx_rmkey_free ,
68
69
.spml_rmkey_ptr = mca_spml_ucx_rmkey_ptr ,
69
70
.spml_memuse_hook = mca_spml_ucx_memuse_hook ,
71
+ .spml_put_all_nb = mca_spml_ucx_put_all_nb ,
70
72
.self = (void * )& mca_spml_ucx
71
73
},
72
74
@@ -439,8 +441,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
439
441
ucx_mkey -> mem_h = (ucp_mem_h )mem_seg -> context ;
440
442
}
441
443
442
- status = ucp_rkey_pack (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h ,
443
- & mkeys [0 ].u .data , & len );
444
+ status = ucp_rkey_pack (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h ,
445
+ & mkeys [0 ].u .data , & len );
444
446
if (UCS_OK != status ) {
445
447
goto error_unmap ;
446
448
}
@@ -477,8 +479,6 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
477
479
{
478
480
spml_ucx_mkey_t * ucx_mkey ;
479
481
map_segment_t * mem_seg ;
480
- int segno ;
481
- int my_pe = oshmem_my_proc_id ();
482
482
483
483
MCA_SPML_CALL (quiet (oshmem_ctx_default ));
484
484
if (!mkeys )
@@ -493,7 +493,7 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
493
493
if (OPAL_UNLIKELY (NULL == mem_seg )) {
494
494
return OSHMEM_ERROR ;
495
495
}
496
-
496
+
497
497
if (MAP_SEGMENT_ALLOC_UCX != mem_seg -> type ) {
498
498
ucp_mem_unmap (mca_spml_ucx .ucp_context , ucx_mkey -> mem_h );
499
499
}
@@ -545,17 +545,15 @@ static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx
545
545
opal_atomic_wmb ();
546
546
}
547
547
548
- int mca_spml_ucx_ctx_create (long options , shmem_ctx_t * ctx )
548
+ static int mca_spml_ucx_ctx_create_common (long options , mca_spml_ucx_ctx_t * * ucx_ctx_p )
549
549
{
550
- mca_spml_ucx_ctx_t * ucx_ctx ;
551
550
ucp_worker_params_t params ;
552
551
ucp_ep_params_t ep_params ;
553
552
size_t i , j , nprocs = oshmem_num_procs ();
554
553
ucs_status_t err ;
555
- int my_pe = oshmem_my_proc_id ();
556
- size_t len ;
557
554
spml_ucx_mkey_t * ucx_mkey ;
558
555
sshmem_mkey_t * mkey ;
556
+ mca_spml_ucx_ctx_t * ucx_ctx ;
559
557
int rc = OSHMEM_ERROR ;
560
558
561
559
ucx_ctx = malloc (sizeof (mca_spml_ucx_ctx_t ));
@@ -580,10 +578,6 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
580
578
goto error ;
581
579
}
582
580
583
- if (mca_spml_ucx .active_array .ctxs_count == 0 ) {
584
- opal_progress_register (spml_ucx_ctx_progress );
585
- }
586
-
587
581
for (i = 0 ; i < nprocs ; i ++ ) {
588
582
ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
589
583
ep_params .address = (ucp_address_t * )(mca_spml_ucx .remote_addrs_tbl [i ]);
@@ -609,11 +603,8 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
609
603
}
610
604
}
611
605
612
- SHMEM_MUTEX_LOCK (mca_spml_ucx .internal_mutex );
613
- _ctx_add (& mca_spml_ucx .active_array , ucx_ctx );
614
- SHMEM_MUTEX_UNLOCK (mca_spml_ucx .internal_mutex );
606
+ * ucx_ctx_p = ucx_ctx ;
615
607
616
- (* ctx ) = (shmem_ctx_t )ucx_ctx ;
617
608
return OSHMEM_SUCCESS ;
618
609
619
610
error2 :
@@ -634,6 +625,33 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
634
625
return rc ;
635
626
}
636
627
628
+ int mca_spml_ucx_ctx_create (long options , shmem_ctx_t * ctx )
629
+ {
630
+ mca_spml_ucx_ctx_t * ucx_ctx ;
631
+ int rc ;
632
+
633
+ /* Take a lock controlling context creation. AUX context may set specific
634
+ * UCX parameters affecting worker creation, which are not needed for
635
+ * regular contexts. */
636
+ pthread_mutex_lock (& mca_spml_ucx .ctx_create_mutex );
637
+ rc = mca_spml_ucx_ctx_create_common (options , & ucx_ctx );
638
+ pthread_mutex_unlock (& mca_spml_ucx .ctx_create_mutex );
639
+ if (rc != OSHMEM_SUCCESS ) {
640
+ return rc ;
641
+ }
642
+
643
+ if (mca_spml_ucx .active_array .ctxs_count == 0 ) {
644
+ opal_progress_register (spml_ucx_ctx_progress );
645
+ }
646
+
647
+ SHMEM_MUTEX_LOCK (mca_spml_ucx .internal_mutex );
648
+ _ctx_add (& mca_spml_ucx .active_array , ucx_ctx );
649
+ SHMEM_MUTEX_UNLOCK (mca_spml_ucx .internal_mutex );
650
+
651
+ (* ctx ) = (shmem_ctx_t )ucx_ctx ;
652
+ return OSHMEM_SUCCESS ;
653
+ }
654
+
637
655
void mca_spml_ucx_ctx_destroy (shmem_ctx_t ctx )
638
656
{
639
657
MCA_SPML_CALL (quiet (ctx ));
@@ -748,6 +766,15 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx)
748
766
oshmem_shmem_abort (-1 );
749
767
return ret ;
750
768
}
769
+
770
+ /* If put_all_nb op/s is/are being executed asynchronously, need to wait its
771
+ * completion as well. */
772
+ if (ctx == oshmem_ctx_default ) {
773
+ while (mca_spml_ucx .aux_refcnt ) {
774
+ opal_progress ();
775
+ }
776
+ }
777
+
751
778
return OSHMEM_SUCCESS ;
752
779
}
753
780
@@ -785,3 +812,99 @@ int mca_spml_ucx_send(void* buf,
785
812
786
813
return rc ;
787
814
}
815
+
816
+ /* this can be called with request==NULL in case of immediate completion */
817
+ static void mca_spml_ucx_put_all_complete_cb (void * request , ucs_status_t status )
818
+ {
819
+ if (mca_spml_ucx .async_progress && (-- mca_spml_ucx .aux_refcnt == 0 )) {
820
+ opal_event_evtimer_del (mca_spml_ucx .tick_event );
821
+ opal_progress_unregister (spml_ucx_progress_aux_ctx );
822
+ }
823
+
824
+ if (request != NULL ) {
825
+ ucp_request_free (request );
826
+ }
827
+ }
828
+
829
+ /* Should be called with AUX lock taken */
830
+ static int mca_spml_ucx_create_aux_ctx (void )
831
+ {
832
+ unsigned major = 0 ;
833
+ unsigned minor = 0 ;
834
+ unsigned rel_number = 0 ;
835
+ int rc ;
836
+ bool rand_dci_supp ;
837
+
838
+ ucp_get_version (& major , & minor , & rel_number );
839
+ rand_dci_supp = UCX_VERSION (major , minor , rel_number ) >= UCX_VERSION (1 , 6 , 0 );
840
+
841
+ if (rand_dci_supp ) {
842
+ pthread_mutex_lock (& mca_spml_ucx .ctx_create_mutex );
843
+ opal_setenv ("UCX_DC_MLX5_TX_POLICY" , "rand" , 0 , & environ );
844
+ }
845
+
846
+ rc = mca_spml_ucx_ctx_create_common (SHMEM_CTX_PRIVATE , & mca_spml_ucx .aux_ctx );
847
+
848
+ if (rand_dci_supp ) {
849
+ opal_unsetenv ("UCX_DC_MLX5_TX_POLICY" , & environ );
850
+ pthread_mutex_unlock (& mca_spml_ucx .ctx_create_mutex );
851
+ }
852
+
853
+ return rc ;
854
+ }
855
+
856
+ int mca_spml_ucx_put_all_nb (void * dest , const void * source , size_t size , long * counter )
857
+ {
858
+ int my_pe = oshmem_my_proc_id ();
859
+ long val = 1 ;
860
+ int peer , dst_pe , rc ;
861
+ shmem_ctx_t ctx ;
862
+ struct timeval tv ;
863
+ void * request ;
864
+
865
+ mca_spml_ucx_aux_lock ();
866
+ if (mca_spml_ucx .async_progress ) {
867
+ if (mca_spml_ucx .aux_ctx == NULL ) {
868
+ rc = mca_spml_ucx_create_aux_ctx ();
869
+ if (rc != OMPI_SUCCESS ) {
870
+ mca_spml_ucx_aux_unlock ();
871
+ oshmem_shmem_abort (-1 );
872
+ }
873
+ }
874
+
875
+ if (mca_spml_ucx .aux_refcnt ++ == 0 ) {
876
+ tv .tv_sec = 0 ;
877
+ tv .tv_usec = mca_spml_ucx .async_tick ;
878
+ opal_event_evtimer_add (mca_spml_ucx .tick_event , & tv );
879
+ opal_progress_register (spml_ucx_progress_aux_ctx );
880
+ }
881
+ ctx = (shmem_ctx_t )mca_spml_ucx .aux_ctx ;
882
+ } else {
883
+ ctx = oshmem_ctx_default ;
884
+ }
885
+
886
+ for (peer = 0 ; peer < oshmem_num_procs (); peer ++ ) {
887
+ dst_pe = (peer + my_pe ) % oshmem_group_all -> proc_count ;
888
+ rc = mca_spml_ucx_put_nb (ctx ,
889
+ (void * )((uintptr_t )dest + my_pe * size ),
890
+ size ,
891
+ (void * )((uintptr_t )source + dst_pe * size ),
892
+ dst_pe , NULL );
893
+ RUNTIME_CHECK_RC (rc );
894
+
895
+ mca_spml_ucx_fence (ctx );
896
+
897
+ rc = MCA_ATOMIC_CALL (add (ctx , (void * )counter , val , sizeof (val ), dst_pe ));
898
+ RUNTIME_CHECK_RC (rc );
899
+ }
900
+
901
+ request = ucp_worker_flush_nb (((mca_spml_ucx_ctx_t * )ctx )-> ucp_worker , 0 ,
902
+ mca_spml_ucx_put_all_complete_cb );
903
+ if (!UCS_PTR_IS_PTR (request )) {
904
+ mca_spml_ucx_put_all_complete_cb (NULL , UCS_PTR_STATUS (request ));
905
+ }
906
+
907
+ mca_spml_ucx_aux_unlock ();
908
+
909
+ return OSHMEM_SUCCESS ;
910
+ }
0 commit comments