Skip to content

Commit 5d2200a

Browse files
authored
Merge pull request #6605 from brminich/topic/shmem_all2all_put
SPML/UCX: Add shmemx_alltoall_global_nb routine to shmemx.h
2 parents 399b713 + d4843b1 commit 5d2200a

File tree

11 files changed

+350
-36
lines changed

11 files changed

+350
-36
lines changed

ompi/mca/osc/ucx/osc_ucx_component.c

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
#include "osc_ucx.h"
2121
#include "osc_ucx_request.h"
2222

23-
#define UCX_VERSION(_major, _minor, _build) (((_major) * 100) + (_minor))
24-
2523
/*
 * Copy _len bytes from _src into _dst at byte offset _off, then advance _off
 * by _len.
 *
 * _off must be a modifiable lvalue. _off and _len are each evaluated twice, so
 * pass expressions without side effects.
 *
 * The body is wrapped in do { } while (0) so that the macro expands to exactly
 * one statement: without it, only the memcpy() would be governed by an
 * unbraced if/else and the offset would be advanced unconditionally. Callers
 * terminate the invocation with ';' like an ordinary function call.
 */
#define memcpy_off(_dst, _src, _len, _off)                  \
    do {                                                    \
        memcpy(((char*)(_dst)) + (_off), (_src), (_len));   \
        (_off) += (_len);                                   \
    } while (0)

opal/mca/common/ucx/common_ucx.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ BEGIN_C_DECLS
3939
#define MCA_COMMON_UCX_PER_TARGET_OPS_THRESHOLD 1000
4040
#define MCA_COMMON_UCX_GLOBAL_OPS_THRESHOLD 1000
4141

42+
/* Encode a UCX version as (_major * 100 + _minor) so that two versions can be
 * compared numerically. _build is accepted for call-site symmetry but does not
 * contribute to the value, so versions differing only in build compare equal. */
#define UCX_VERSION(_major, _minor, _build) (((_major) * 100) + (_minor))
43+
44+
4245
#define _MCA_COMMON_UCX_QUOTE(_x) \
4346
# _x
4447
#define MCA_COMMON_UCX_QUOTE(_x) \

oshmem/include/shmemx.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ OSHMEM_DECLSPEC void shmemx_int16_prod_to_all(int16_t *target, const int16_t *so
168168
OSHMEM_DECLSPEC void shmemx_int32_prod_to_all(int32_t *target, const int32_t *source, int nreduce, int PE_start, int logPE_stride, int PE_size, int32_t *pWrk, long *pSync);
169169
OSHMEM_DECLSPEC void shmemx_int64_prod_to_all(int64_t *target, const int64_t *source, int nreduce, int PE_start, int logPE_stride, int PE_size, int64_t *pWrk, long *pSync);
170170

171+
/* shmemx_alltoall_global_nb is a nonblocking collective routine, where each PE
172+
* exchanges “size” bytes of data with all other PEs in the OpenSHMEM job.
173+
174+
* @param dest A symmetric data object that is large enough to receive
175+
* “size” bytes of data from each PE in the OpenSHMEM job.
176+
* @param source A symmetric data object that contains “size” bytes of data
177+
* for each PE in the OpenSHMEM job.
178+
* @param size The number of bytes to be sent to each PE in the job.
179+
* @param counter A symmetric data object to be atomically incremented after
180+
* the target buffer is updated.
181+
*
182+
* @return None; the routine is declared void.
183+
*/
184+
OSHMEM_DECLSPEC void shmemx_alltoall_global_nb(void *dest, const void *source, size_t size, long *counter);
185+
171186
/*
172187
* Backward compatibility section
173188
*/

oshmem/mca/spml/base/base.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ OSHMEM_DECLSPEC int mca_spml_base_get_nb(void *dst_addr,
9393
void **handle);
9494

9595
OSHMEM_DECLSPEC void mca_spml_base_memuse_hook(void *addr, size_t length);
96+
97+
OSHMEM_DECLSPEC int mca_spml_base_put_all_nb(void *target, const void *source,
98+
size_t size, long *counter);
99+
96100
/*
97101
* MCA framework
98102
*/

oshmem/mca/spml/base/spml_base.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,9 @@ int mca_spml_base_get_nb(void *dst_addr, size_t size,
280280
/* Base implementation of the memuse hook: accepts an address and a length and
 * does nothing with them (the body is intentionally empty). */
void mca_spml_base_memuse_hook(void *addr, size_t length)
281281
{
282282
}
283+
284+
int mca_spml_base_put_all_nb(void *target, const void *source,
285+
size_t size, long *counter)
286+
{
287+
return OSHMEM_ERR_NOT_IMPLEMENTED;
288+
}

oshmem/mca/spml/ikrit/spml_ikrit.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ mca_spml_ikrit_t mca_spml_ikrit = {
179179
mca_spml_base_rmkey_free,
180180
mca_spml_base_rmkey_ptr,
181181
mca_spml_base_memuse_hook,
182+
mca_spml_base_put_all_nb,
182183

183184
(void*)&mca_spml_ikrit
184185
},

oshmem/mca/spml/spml.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,35 @@ typedef int (*mca_spml_base_module_send_fn_t)(void *buf,
314314
int dst,
315315
mca_spml_base_put_mode_t mode);
316316

317+
/**
318+
* The routine transfers the data asynchronously from the source PE to all
319+
* PEs in the OpenSHMEM job. The routine returns immediately. The source and
320+
* target buffers are reusable only after the completion of the routine.
321+
* After the data is transferred to the target buffers, the counter object
322+
* is updated atomically. The counter object can be read either using atomic
323+
* operations such as shmem_atomic_fetch or using point-to-point synchronization
324+
* routines such as shmem_wait_until and shmem_test.
325+
*
326+
* shmem_quiet may be used to complete the operation, but it is not required for
327+
* progress or completion. In a multithreaded OpenSHMEM program, the user
328+
* (the OpenSHMEM program) should ensure the correct ordering of
329+
* shmemx_alltoall_global_nb calls.
330+
*
331+
* @param dest A symmetric data object that is large enough to receive
332+
* “size” bytes of data from each PE in the OpenSHMEM job.
333+
* @param source A symmetric data object that contains “size” bytes of data
334+
* for each PE in the OpenSHMEM job.
335+
* @param size The number of bytes to be sent to each PE in the job.
336+
* @param counter A symmetric data object to be atomically incremented after
337+
* the target buffer is updated.
338+
*
339+
* @return OSHMEM_SUCCESS or failure status.
340+
*/
341+
typedef int (*mca_spml_base_module_put_all_nb_fn_t)(void *dest,
342+
const void *source,
343+
size_t size,
344+
long *counter);
345+
317346
/**
318347
* Assures ordering of delivery of put() requests
319348
*
@@ -381,6 +410,7 @@ struct mca_spml_base_module_1_0_0_t {
381410
mca_spml_base_module_mkey_ptr_fn_t spml_rmkey_ptr;
382411

383412
mca_spml_base_module_memuse_hook_fn_t spml_memuse_hook;
413+
mca_spml_base_module_put_all_nb_fn_t spml_put_all_nb;
384414
void *self;
385415
};
386416

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "oshmem/proc/proc.h"
3333
#include "oshmem/mca/spml/base/base.h"
3434
#include "oshmem/mca/spml/base/spml_base_putreq.h"
35+
#include "oshmem/mca/atomic/atomic.h"
3536
#include "oshmem/runtime/runtime.h"
3637

3738
#include "oshmem/mca/spml/ucx/spml_ucx_component.h"
@@ -67,6 +68,7 @@ mca_spml_ucx_t mca_spml_ucx = {
6768
.spml_rmkey_free = mca_spml_ucx_rmkey_free,
6869
.spml_rmkey_ptr = mca_spml_ucx_rmkey_ptr,
6970
.spml_memuse_hook = mca_spml_ucx_memuse_hook,
71+
.spml_put_all_nb = mca_spml_ucx_put_all_nb,
7072
.self = (void*)&mca_spml_ucx
7173
},
7274

@@ -439,8 +441,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
439441
ucx_mkey->mem_h = (ucp_mem_h)mem_seg->context;
440442
}
441443

442-
status = ucp_rkey_pack(mca_spml_ucx.ucp_context, ucx_mkey->mem_h,
443-
&mkeys[0].u.data, &len);
444+
status = ucp_rkey_pack(mca_spml_ucx.ucp_context, ucx_mkey->mem_h,
445+
&mkeys[0].u.data, &len);
444446
if (UCS_OK != status) {
445447
goto error_unmap;
446448
}
@@ -477,8 +479,6 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
477479
{
478480
spml_ucx_mkey_t *ucx_mkey;
479481
map_segment_t *mem_seg;
480-
int segno;
481-
int my_pe = oshmem_my_proc_id();
482482

483483
MCA_SPML_CALL(quiet(oshmem_ctx_default));
484484
if (!mkeys)
@@ -493,7 +493,7 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
493493
if (OPAL_UNLIKELY(NULL == mem_seg)) {
494494
return OSHMEM_ERROR;
495495
}
496-
496+
497497
if (MAP_SEGMENT_ALLOC_UCX != mem_seg->type) {
498498
ucp_mem_unmap(mca_spml_ucx.ucp_context, ucx_mkey->mem_h);
499499
}
@@ -545,17 +545,15 @@ static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx
545545
opal_atomic_wmb ();
546546
}
547547

548-
int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
548+
static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx_ctx_p)
549549
{
550-
mca_spml_ucx_ctx_t *ucx_ctx;
551550
ucp_worker_params_t params;
552551
ucp_ep_params_t ep_params;
553552
size_t i, j, nprocs = oshmem_num_procs();
554553
ucs_status_t err;
555-
int my_pe = oshmem_my_proc_id();
556-
size_t len;
557554
spml_ucx_mkey_t *ucx_mkey;
558555
sshmem_mkey_t *mkey;
556+
mca_spml_ucx_ctx_t *ucx_ctx;
559557
int rc = OSHMEM_ERROR;
560558

561559
ucx_ctx = malloc(sizeof(mca_spml_ucx_ctx_t));
@@ -580,10 +578,6 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
580578
goto error;
581579
}
582580

583-
if (mca_spml_ucx.active_array.ctxs_count == 0) {
584-
opal_progress_register(spml_ucx_ctx_progress);
585-
}
586-
587581
for (i = 0; i < nprocs; i++) {
588582
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
589583
ep_params.address = (ucp_address_t *)(mca_spml_ucx.remote_addrs_tbl[i]);
@@ -609,11 +603,8 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
609603
}
610604
}
611605

612-
SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
613-
_ctx_add(&mca_spml_ucx.active_array, ucx_ctx);
614-
SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
606+
*ucx_ctx_p = ucx_ctx;
615607

616-
(*ctx) = (shmem_ctx_t)ucx_ctx;
617608
return OSHMEM_SUCCESS;
618609

619610
error2:
@@ -634,6 +625,33 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
634625
return rc;
635626
}
636627

628+
/**
 * Create a regular UCX-backed context and hand it back through *ctx.
 *
 * The actual construction is delegated to mca_spml_ucx_ctx_create_common(),
 * which is called with ctx_create_mutex held: AUX context creation may set
 * UCX parameters that affect worker creation and that regular contexts do
 * not need, so the two must not overlap.
 *
 * On success the context progress callback is registered when the active
 * array was empty, and the new context is appended to the active array under
 * internal_mutex.
 *
 * @param options  Context options, forwarded unchanged to the common helper.
 * @param ctx      Out: receives the new context; written only on success.
 *
 * @return OSHMEM_SUCCESS, or the code returned by the common helper.
 */
int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
{
    mca_spml_ucx_ctx_t *new_ctx;
    int status;

    pthread_mutex_lock(&mca_spml_ucx.ctx_create_mutex);
    status = mca_spml_ucx_ctx_create_common(options, &new_ctx);
    pthread_mutex_unlock(&mca_spml_ucx.ctx_create_mutex);

    if (OSHMEM_SUCCESS != status) {
        return status;
    }

    /* The first active context switches on the context progress callback. */
    if (0 == mca_spml_ucx.active_array.ctxs_count) {
        opal_progress_register(spml_ucx_ctx_progress);
    }

    SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
    _ctx_add(&mca_spml_ucx.active_array, new_ctx);
    SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);

    *ctx = (shmem_ctx_t)new_ctx;
    return OSHMEM_SUCCESS;
}
654+
637655
void mca_spml_ucx_ctx_destroy(shmem_ctx_t ctx)
638656
{
639657
MCA_SPML_CALL(quiet(ctx));
@@ -748,6 +766,15 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx)
748766
oshmem_shmem_abort(-1);
749767
return ret;
750768
}
769+
770+
/* If any put_all_nb operations are still executing asynchronously, wait for their
771+
* completion as well. */
772+
if (ctx == oshmem_ctx_default) {
773+
while (mca_spml_ucx.aux_refcnt) {
774+
opal_progress();
775+
}
776+
}
777+
751778
return OSHMEM_SUCCESS;
752779
}
753780

@@ -785,3 +812,99 @@ int mca_spml_ucx_send(void* buf,
785812

786813
return rc;
787814
}
815+
816+
/* this can be called with request==NULL in case of immediate completion */
817+
static void mca_spml_ucx_put_all_complete_cb(void *request, ucs_status_t status)
818+
{
819+
if (mca_spml_ucx.async_progress && (--mca_spml_ucx.aux_refcnt == 0)) {
820+
opal_event_evtimer_del(mca_spml_ucx.tick_event);
821+
opal_progress_unregister(spml_ucx_progress_aux_ctx);
822+
}
823+
824+
if (request != NULL) {
825+
ucp_request_free(request);
826+
}
827+
}
828+
829+
/* Should be called with AUX lock taken */
830+
static int mca_spml_ucx_create_aux_ctx(void)
831+
{
832+
unsigned major = 0;
833+
unsigned minor = 0;
834+
unsigned rel_number = 0;
835+
int rc;
836+
bool rand_dci_supp;
837+
838+
ucp_get_version(&major, &minor, &rel_number);
839+
rand_dci_supp = UCX_VERSION(major, minor, rel_number) >= UCX_VERSION(1, 6, 0);
840+
841+
if (rand_dci_supp) {
842+
pthread_mutex_lock(&mca_spml_ucx.ctx_create_mutex);
843+
opal_setenv("UCX_DC_MLX5_TX_POLICY", "rand", 0, &environ);
844+
}
845+
846+
rc = mca_spml_ucx_ctx_create_common(SHMEM_CTX_PRIVATE, &mca_spml_ucx.aux_ctx);
847+
848+
if (rand_dci_supp) {
849+
opal_unsetenv("UCX_DC_MLX5_TX_POLICY", &environ);
850+
pthread_mutex_unlock(&mca_spml_ucx.ctx_create_mutex);
851+
}
852+
853+
return rc;
854+
}
855+
856+
int mca_spml_ucx_put_all_nb(void *dest, const void *source, size_t size, long *counter)
857+
{
858+
int my_pe = oshmem_my_proc_id();
859+
long val = 1;
860+
int peer, dst_pe, rc;
861+
shmem_ctx_t ctx;
862+
struct timeval tv;
863+
void *request;
864+
865+
mca_spml_ucx_aux_lock();
866+
if (mca_spml_ucx.async_progress) {
867+
if (mca_spml_ucx.aux_ctx == NULL) {
868+
rc = mca_spml_ucx_create_aux_ctx();
869+
if (rc != OMPI_SUCCESS) {
870+
mca_spml_ucx_aux_unlock();
871+
oshmem_shmem_abort(-1);
872+
}
873+
}
874+
875+
if (mca_spml_ucx.aux_refcnt++ == 0) {
876+
tv.tv_sec = 0;
877+
tv.tv_usec = mca_spml_ucx.async_tick;
878+
opal_event_evtimer_add(mca_spml_ucx.tick_event, &tv);
879+
opal_progress_register(spml_ucx_progress_aux_ctx);
880+
}
881+
ctx = (shmem_ctx_t)mca_spml_ucx.aux_ctx;
882+
} else {
883+
ctx = oshmem_ctx_default;
884+
}
885+
886+
for (peer = 0; peer < oshmem_num_procs(); peer++) {
887+
dst_pe = (peer + my_pe) % oshmem_group_all->proc_count;
888+
rc = mca_spml_ucx_put_nb(ctx,
889+
(void*)((uintptr_t)dest + my_pe * size),
890+
size,
891+
(void*)((uintptr_t)source + dst_pe * size),
892+
dst_pe, NULL);
893+
RUNTIME_CHECK_RC(rc);
894+
895+
mca_spml_ucx_fence(ctx);
896+
897+
rc = MCA_ATOMIC_CALL(add(ctx, (void*)counter, val, sizeof(val), dst_pe));
898+
RUNTIME_CHECK_RC(rc);
899+
}
900+
901+
request = ucp_worker_flush_nb(((mca_spml_ucx_ctx_t*)ctx)->ucp_worker, 0,
902+
mca_spml_ucx_put_all_complete_cb);
903+
if (!UCS_PTR_IS_PTR(request)) {
904+
mca_spml_ucx_put_all_complete_cb(NULL, UCS_PTR_STATUS(request));
905+
}
906+
907+
mca_spml_ucx_aux_unlock();
908+
909+
return OSHMEM_SUCCESS;
910+
}

0 commit comments

Comments
 (0)