Skip to content

Commit 2ef5bd8

Browse files
committed
SPML/UCX: Add shmemx_alltoall_global_nb routine to shmemx.h
The new routine transfers the data asynchronously from the source PE to all PEs in the OpenSHMEM job. The routine returns immediately. The source and target buffers are reusable only after the completion of the routine. After the data is transferred to the target buffers, the counter object is updated atomically. The counter object can be read either with atomic operations such as shmem_atomic_fetch or with point-to-point synchronization routines such as shmem_wait_until and shmem_test.

Signed-off-by: Mikhail Brinskii <mikhailb@mellanox.com>
1 parent 61d6770 commit 2ef5bd8

File tree

11 files changed

+330
-36
lines changed

11 files changed

+330
-36
lines changed

ompi/mca/osc/ucx/osc_ucx_component.c

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
#include "osc_ucx.h"
2121
#include "osc_ucx_request.h"
2222

23-
#define UCX_VERSION(_major, _minor, _build) (((_major) * 100) + (_minor))
24-
2523
#define memcpy_off(_dst, _src, _len, _off) \
2624
memcpy(((char*)(_dst)) + (_off), _src, _len); \
2725
(_off) += (_len);

opal/mca/common/ucx/common_ucx.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ BEGIN_C_DECLS
3939
#define MCA_COMMON_UCX_PER_TARGET_OPS_THRESHOLD 1000
4040
#define MCA_COMMON_UCX_GLOBAL_OPS_THRESHOLD 1000
4141

42+
#define UCX_VERSION(_major, _minor, _build) (((_major) * 100) + (_minor))
43+
44+
4245
#define _MCA_COMMON_UCX_QUOTE(_x) \
4346
# _x
4447
#define MCA_COMMON_UCX_QUOTE(_x) \

oshmem/include/shmemx.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ OSHMEM_DECLSPEC void shmemx_int16_prod_to_all(int16_t *target, const int16_t *so
168168
OSHMEM_DECLSPEC void shmemx_int32_prod_to_all(int32_t *target, const int32_t *source, int nreduce, int PE_start, int logPE_stride, int PE_size, int32_t *pWrk, long *pSync);
169169
OSHMEM_DECLSPEC void shmemx_int64_prod_to_all(int64_t *target, const int64_t *source, int nreduce, int PE_start, int logPE_stride, int PE_size, int64_t *pWrk, long *pSync);
170170

171+
/* Alltoall put with atomic counter increase */
172+
OSHMEM_DECLSPEC void shmemx_put_with_long_inc_all(void *target, const void *source, size_t size, long *counter);
173+
171174
/*
172175
* Backward compatibility section
173176
*/

oshmem/mca/spml/base/base.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ OSHMEM_DECLSPEC int mca_spml_base_get_nb(void *dst_addr,
9393
void **handle);
9494

9595
OSHMEM_DECLSPEC void mca_spml_base_memuse_hook(void *addr, size_t length);
96+
97+
OSHMEM_DECLSPEC int mca_spml_base_put_all_nb(void *target, const void *source,
98+
size_t size, long *counter);
99+
96100
/*
97101
* MCA framework
98102
*/

oshmem/mca/spml/base/spml_base.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,9 @@ int mca_spml_base_get_nb(void *dst_addr, size_t size,
280280
void mca_spml_base_memuse_hook(void *addr, size_t length)
281281
{
282282
}
283+
284+
int mca_spml_base_put_all_nb(void *target, const void *source,
285+
size_t size, long *counter)
286+
{
287+
return OSHMEM_ERR_NOT_IMPLEMENTED;
288+
}

oshmem/mca/spml/ikrit/spml_ikrit.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ mca_spml_ikrit_t mca_spml_ikrit = {
179179
mca_spml_base_rmkey_free,
180180
mca_spml_base_rmkey_ptr,
181181
mca_spml_base_memuse_hook,
182+
mca_spml_base_put_all_nb,
182183

183184
(void*)&mca_spml_ikrit
184185
},

oshmem/mca/spml/spml.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,35 @@ typedef int (*mca_spml_base_module_send_fn_t)(void *buf,
314314
int dst,
315315
mca_spml_base_put_mode_t mode);
316316

317+
/**
318+
* The routine transfers the data asynchronously from the source PE to all
319+
* PEs in the OpenSHMEM job. The routine returns immediately. The source and
320+
* target buffers are reusable only after the completion of the routine.
321+
* After the data is transferred to the target buffers, the counter object
322+
* is updated atomically. The counter object can be read either with atomic
323+
* operations such as shmem_atomic_fetch or with point-to-point synchronization
324+
* routines such as shmem_wait_until and shmem_test.
325+
*
326+
* shmem_quiet may be used to complete the operation, but it is not required for
327+
* progress or completion. In a multithreaded OpenSHMEM program, the user
328+
* (the OpenSHMEM program) should ensure the correct ordering of
329+
* shmemx_alltoall_global_nb calls.
330+
*
331+
* @param dest A symmetric data object that is large enough to receive
332+
* “size” bytes of data from each PE in the OpenSHMEM job.
333+
* @param source A symmetric data object that contains “size” bytes of data
334+
* for each PE in the OpenSHMEM job.
335+
* @param size The number of bytes to be sent to each PE in the job.
336+
* @param counter A symmetric data object to be atomically incremented after
337+
* the target buffer is updated.
338+
*
339+
* @return OSHMEM_SUCCESS or failure status.
340+
*/
341+
typedef int (*mca_spml_base_module_put_all_nb_fn_t)(void *dest,
342+
const void *source,
343+
size_t size,
344+
long *counter);
345+
317346
/**
318347
* Assures ordering of delivery of put() requests
319348
*
@@ -381,6 +410,7 @@ struct mca_spml_base_module_1_0_0_t {
381410
mca_spml_base_module_mkey_ptr_fn_t spml_rmkey_ptr;
382411

383412
mca_spml_base_module_memuse_hook_fn_t spml_memuse_hook;
413+
mca_spml_base_module_put_all_nb_fn_t spml_put_all_nb;
384414
void *self;
385415
};
386416

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "oshmem/proc/proc.h"
3333
#include "oshmem/mca/spml/base/base.h"
3434
#include "oshmem/mca/spml/base/spml_base_putreq.h"
35+
#include "oshmem/mca/atomic/atomic.h"
3536
#include "oshmem/runtime/runtime.h"
3637

3738
#include "oshmem/mca/spml/ucx/spml_ucx_component.h"
@@ -67,6 +68,7 @@ mca_spml_ucx_t mca_spml_ucx = {
6768
.spml_rmkey_free = mca_spml_ucx_rmkey_free,
6869
.spml_rmkey_ptr = mca_spml_ucx_rmkey_ptr,
6970
.spml_memuse_hook = mca_spml_ucx_memuse_hook,
71+
.spml_put_all_nb = mca_spml_ucx_put_all_nb,
7072
.self = (void*)&mca_spml_ucx
7173
},
7274

@@ -439,8 +441,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
439441
ucx_mkey->mem_h = (ucp_mem_h)mem_seg->context;
440442
}
441443

442-
status = ucp_rkey_pack(mca_spml_ucx.ucp_context, ucx_mkey->mem_h,
443-
&mkeys[0].u.data, &len);
444+
status = ucp_rkey_pack(mca_spml_ucx.ucp_context, ucx_mkey->mem_h,
445+
&mkeys[0].u.data, &len);
444446
if (UCS_OK != status) {
445447
goto error_unmap;
446448
}
@@ -477,8 +479,6 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
477479
{
478480
spml_ucx_mkey_t *ucx_mkey;
479481
map_segment_t *mem_seg;
480-
int segno;
481-
int my_pe = oshmem_my_proc_id();
482482

483483
MCA_SPML_CALL(quiet(oshmem_ctx_default));
484484
if (!mkeys)
@@ -493,7 +493,7 @@ int mca_spml_ucx_deregister(sshmem_mkey_t *mkeys)
493493
if (OPAL_UNLIKELY(NULL == mem_seg)) {
494494
return OSHMEM_ERROR;
495495
}
496-
496+
497497
if (MAP_SEGMENT_ALLOC_UCX != mem_seg->type) {
498498
ucp_mem_unmap(mca_spml_ucx.ucp_context, ucx_mkey->mem_h);
499499
}
@@ -545,17 +545,15 @@ static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx
545545
opal_atomic_wmb ();
546546
}
547547

548-
int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
548+
static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx_ctx_p)
549549
{
550-
mca_spml_ucx_ctx_t *ucx_ctx;
551550
ucp_worker_params_t params;
552551
ucp_ep_params_t ep_params;
553552
size_t i, j, nprocs = oshmem_num_procs();
554553
ucs_status_t err;
555-
int my_pe = oshmem_my_proc_id();
556-
size_t len;
557554
spml_ucx_mkey_t *ucx_mkey;
558555
sshmem_mkey_t *mkey;
556+
mca_spml_ucx_ctx_t *ucx_ctx;
559557
int rc = OSHMEM_ERROR;
560558

561559
ucx_ctx = malloc(sizeof(mca_spml_ucx_ctx_t));
@@ -580,10 +578,6 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
580578
goto error;
581579
}
582580

583-
if (mca_spml_ucx.active_array.ctxs_count == 0) {
584-
opal_progress_register(spml_ucx_ctx_progress);
585-
}
586-
587581
for (i = 0; i < nprocs; i++) {
588582
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
589583
ep_params.address = (ucp_address_t *)(mca_spml_ucx.remote_addrs_tbl[i]);
@@ -609,11 +603,8 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
609603
}
610604
}
611605

612-
SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
613-
_ctx_add(&mca_spml_ucx.active_array, ucx_ctx);
614-
SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);
606+
*ucx_ctx_p = ucx_ctx;
615607

616-
(*ctx) = (shmem_ctx_t)ucx_ctx;
617608
return OSHMEM_SUCCESS;
618609

619610
error2:
@@ -634,6 +625,33 @@ int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
634625
return rc;
635626
}
636627

628+
/**
 * Create a regular SHMEM context and add it to the active context array.
 *
 * @param options  context options, forwarded unchanged to
 *                 mca_spml_ucx_ctx_create_common()
 * @param ctx      [out] receives the new context handle on success; it is
 *                 not written on failure
 *
 * @return OSHMEM_SUCCESS, or the failure code reported by
 *         mca_spml_ucx_ctx_create_common().
 */
int mca_spml_ucx_ctx_create(long options, shmem_ctx_t *ctx)
{
    mca_spml_ucx_ctx_t *new_ctx;
    int status;

    /* Creation is serialized with the lock controlling the aux context:
     * the AUX context may set specific UCX parameters affecting worker
     * creation, which are not needed for regular contexts. */
    mca_spml_ucx_aux_lock();
    status = mca_spml_ucx_ctx_create_common(options, &new_ctx);
    mca_spml_ucx_aux_unlock();

    if (OSHMEM_SUCCESS == status) {
        /* The active array holds no contexts yet: register the progress
         * callback before the first one is added. */
        if (0 == mca_spml_ucx.active_array.ctxs_count) {
            opal_progress_register(spml_ucx_ctx_progress);
        }

        SHMEM_MUTEX_LOCK(mca_spml_ucx.internal_mutex);
        _ctx_add(&mca_spml_ucx.active_array, new_ctx);
        SHMEM_MUTEX_UNLOCK(mca_spml_ucx.internal_mutex);

        *ctx = (shmem_ctx_t)new_ctx;
    }

    return status;
}
654+
637655
void mca_spml_ucx_ctx_destroy(shmem_ctx_t ctx)
638656
{
639657
MCA_SPML_CALL(quiet(ctx));
@@ -748,6 +766,15 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx)
748766
oshmem_shmem_abort(-1);
749767
return ret;
750768
}
769+
770+
/* If put_all_nb operations are still being executed asynchronously, wait
771+
* for their completion as well. */
772+
if (ctx == oshmem_ctx_default) {
773+
while (mca_spml_ucx.aux_refcnt) {
774+
opal_progress();
775+
}
776+
}
777+
751778
return OSHMEM_SUCCESS;
752779
}
753780

@@ -785,3 +812,99 @@ int mca_spml_ucx_send(void* buf,
785812

786813
return rc;
787814
}
815+
816+
/**
 * Completion handler for the worker flush issued by
 * mca_spml_ucx_put_all_nb().
 *
 * In async-progress mode it drops one reference from aux_refcnt; when the
 * count reaches zero the tick timer is removed and the aux-context progress
 * callback is unregistered.
 *
 * @param request  flush request to release, or NULL when the handler is
 *                 invoked directly (no request object exists)
 * @param status   completion status; not examined here
 */
static void mca_spml_ucx_put_all_complete_cb(void *request, ucs_status_t status)
{
    if (mca_spml_ucx.async_progress) {
        if (0 == --mca_spml_ucx.aux_refcnt) {
            /* Last outstanding put_all_nb: stop the asynchronous machinery */
            opal_event_evtimer_del(mca_spml_ucx.tick_event);
            opal_progress_unregister(spml_ucx_progress_aux_ctx);
        }
    }

    if (NULL != request) {
        ucp_request_free(request);
    }
}
827+
828+
/* Should be called with AUX lock taken */
829+
static int mca_spml_ucx_create_aux_ctx(void)
830+
{
831+
unsigned major = 0;
832+
unsigned minor = 0;
833+
unsigned rel_number = 0;
834+
int rc;
835+
bool rand_dci_supp;
836+
837+
ucp_get_version(&major, &minor, &rel_number);
838+
rand_dci_supp = UCX_VERSION(major, minor, rel_number) >= UCX_VERSION(1, 6, 0);
839+
840+
if (rand_dci_supp) {
841+
opal_setenv("UCX_DC_TX_POLICY", "rand", 1, &environ);
842+
opal_setenv("UCX_DC_MLX5_TX_POLICY", "rand", 1, &environ);
843+
}
844+
845+
rc = mca_spml_ucx_ctx_create_common(SHMEM_CTX_PRIVATE, &mca_spml_ucx.aux_ctx);
846+
847+
if (rand_dci_supp) {
848+
opal_unsetenv("UCX_DC_TX_POLICY", &environ);
849+
opal_unsetenv("UCX_DC_MLX5_TX_POLICY", &environ);
850+
}
851+
852+
return rc;
853+
}
854+
855+
int mca_spml_ucx_put_all_nb(void *dest, const void *source, size_t size, long *counter)
856+
{
857+
int my_pe = oshmem_my_proc_id();
858+
long val = 1;
859+
int peer, dst_pe, rc;
860+
shmem_ctx_t ctx;
861+
struct timeval tv;
862+
void *request;
863+
864+
mca_spml_ucx_aux_lock();
865+
if (mca_spml_ucx.async_progress) {
866+
if (mca_spml_ucx.aux_ctx == NULL) {
867+
rc = mca_spml_ucx_create_aux_ctx();
868+
if (rc != OMPI_SUCCESS) {
869+
mca_spml_ucx_aux_unlock();
870+
oshmem_shmem_abort(-1);
871+
}
872+
}
873+
874+
if (!mca_spml_ucx.aux_refcnt) {
875+
tv.tv_sec = 0;
876+
tv.tv_usec = mca_spml_ucx.async_tick;
877+
opal_event_evtimer_add(mca_spml_ucx.tick_event, &tv);
878+
opal_progress_register(spml_ucx_progress_aux_ctx);
879+
}
880+
ctx = (shmem_ctx_t)mca_spml_ucx.aux_ctx;
881+
++mca_spml_ucx.aux_refcnt;
882+
} else {
883+
ctx = oshmem_ctx_default;
884+
}
885+
886+
for (peer = 0; peer < oshmem_num_procs(); peer++) {
887+
dst_pe = (peer + my_pe) % oshmem_group_all->proc_count;
888+
rc = mca_spml_ucx_put_nb(ctx,
889+
(void*)((uintptr_t)dest + my_pe * size),
890+
size,
891+
(void*)((uintptr_t)source + dst_pe * size),
892+
dst_pe, NULL);
893+
RUNTIME_CHECK_RC(rc);
894+
895+
mca_spml_ucx_fence(ctx);
896+
897+
rc = MCA_ATOMIC_CALL(add(ctx, (void*)counter, val, sizeof(val), dst_pe));
898+
RUNTIME_CHECK_RC(rc);
899+
}
900+
901+
request = ucp_worker_flush_nb(((mca_spml_ucx_ctx_t*)ctx)->ucp_worker, 0,
902+
mca_spml_ucx_put_all_complete_cb);
903+
if (!UCS_PTR_IS_PTR(request)) {
904+
mca_spml_ucx_put_all_complete_cb(NULL, UCS_PTR_STATUS(request));
905+
}
906+
907+
mca_spml_ucx_aux_unlock();
908+
909+
return OSHMEM_SUCCESS;
910+
}

0 commit comments

Comments
 (0)