Skip to content

Commit 7c3aeb3

Browse files
authored
Merge pull request #6686 from alex-anenkov/coll-iallreduce-recursivedoubling
coll/libnbc: add recursive doubling algorithm for MPI_Iallreduce
2 parents 80e0ac7 + 77d466e commit 7c3aeb3

File tree

2 files changed

+168
-5
lines changed

2 files changed

+168
-5
lines changed

ompi/mca/coll/libnbc/coll_libnbc_component.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ static mca_base_var_enum_value_t iallreduce_algorithms[] = {
6060
{1, "ring"},
6161
{2, "binomial"},
6262
{3, "rabenseifner"},
63+
{4, "recursive_doubling"},
6364
{0, NULL}
6465
};
6566

@@ -225,7 +226,7 @@ libnbc_register(void)
225226
(void) mca_base_var_enum_create("coll_libnbc_iallreduce_algorithms", iallreduce_algorithms, &new_enum);
226227
mca_base_component_var_register(&mca_coll_libnbc_component.super.collm_version,
227228
"iallreduce_algorithm",
228-
"Which iallreduce algorithm is used: 0 ignore, 1 ring, 2 binomial, 3 rabenseifner",
229+
"Which iallreduce algorithm is used: 0 ignore, 1 ring, 2 binomial, 3 rabenseifner, 4 recursive_doubling",
229230
MCA_BASE_VAR_TYPE_INT, new_enum, 0, MCA_BASE_VAR_FLAG_SETTABLE,
230231
OPAL_INFO_LVL_5, MCA_BASE_VAR_SCOPE_ALL,
231232
&libnbc_iallreduce_algorithm);

ompi/mca/coll/libnbc/nbc_iallreduce.c

Lines changed: 166 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
static inline int allred_sched_diss(int rank, int p, int count, MPI_Datatype datatype, ptrdiff_t gap, const void *sendbuf,
3030
void *recvbuf, MPI_Op op, char inplace, NBC_Schedule *schedule, void *tmpbuf);
31+
static inline int allred_sched_recursivedoubling(int rank, int p, const void *sendbuf, void *recvbuf,
32+
int count, MPI_Datatype datatype, ptrdiff_t gap, MPI_Op op,
33+
char inplace, NBC_Schedule *schedule, void *tmpbuf);
3134
static inline int allred_sched_ring(int rank, int p, int count, MPI_Datatype datatype, const void *sendbuf,
3235
void *recvbuf, MPI_Op op, int size, int ext, NBC_Schedule *schedule,
3336
void *tmpbuf);
@@ -69,7 +72,7 @@ static int nbc_allreduce_init(const void* sendbuf, void* recvbuf, int count, MPI
6972
#ifdef NBC_CACHE_SCHEDULE
7073
NBC_Allreduce_args *args, *found, search;
7174
#endif
72-
enum { NBC_ARED_BINOMIAL, NBC_ARED_RING, NBC_ARED_REDSCAT_ALLGATHER } alg;
75+
enum { NBC_ARED_BINOMIAL, NBC_ARED_RING, NBC_ARED_REDSCAT_ALLGATHER, NBC_ARED_RDBL } alg;
7376
char inplace;
7477
void *tmpbuf = NULL;
7578
ompi_coll_libnbc_module_t *libnbc_module = (ompi_coll_libnbc_module_t*) module;
@@ -124,9 +127,11 @@ static int nbc_allreduce_init(const void* sendbuf, void* recvbuf, int count, MPI
124127
alg = NBC_ARED_RING;
125128
else if (libnbc_iallreduce_algorithm == 2)
126129
alg = NBC_ARED_BINOMIAL;
127-
else if (libnbc_iallreduce_algorithm == 3 && count >= nprocs_pof2 && ompi_op_is_commute(op)) {
130+
else if (libnbc_iallreduce_algorithm == 3 && count >= nprocs_pof2 && ompi_op_is_commute(op))
128131
alg = NBC_ARED_REDSCAT_ALLGATHER;
129-
} else
132+
else if (libnbc_iallreduce_algorithm == 4)
133+
alg = NBC_ARED_RDBL;
134+
else
130135
alg = NBC_ARED_RING;
131136
}
132137
#ifdef NBC_CACHE_SCHEDULE
@@ -159,6 +164,9 @@ static int nbc_allreduce_init(const void* sendbuf, void* recvbuf, int count, MPI
159164
case NBC_ARED_RING:
160165
res = allred_sched_ring(rank, p, count, datatype, sendbuf, recvbuf, op, size, ext, schedule, tmpbuf);
161166
break;
167+
case NBC_ARED_RDBL:
168+
res = allred_sched_recursivedoubling(rank, p, sendbuf, recvbuf, count, datatype, gap, op, inplace, schedule, tmpbuf);
169+
break;
162170
}
163171
}
164172

@@ -470,6 +478,161 @@ static inline int allred_sched_diss(int rank, int p, int count, MPI_Datatype dat
470478
return OMPI_SUCCESS;
471479
}
472480

481+
/*
482+
* allred_sched_recursivedoubling
483+
*
484+
* Function: Recursive doubling algorithm for iallreduce operation
485+
*
486+
* Description: Implements recursive doubling algorithm for iallreduce.
487+
* The algorithm preserves order of operations so it can
488+
* be used both by commutative and non-commutative operations.
489+
* Schedule length: O(\log(p))
490+
* Memory requirements:
491+
* Each process requires a temporary buffer: count * typesize = O(count)
492+
*
493+
* Example on 7 nodes:
494+
* Initial state
495+
* # 0 1 2 3 4 5 6
496+
* [0] [1] [2] [3] [4] [5] [6]
497+
* Initial adjustment step for non-power of two nodes.
498+
* old rank 1 3 5 6
499+
* new rank 0 1 2 3
500+
* [0+1] [2+3] [4+5] [6]
501+
* Step 1
502+
* old rank 1 3 5 6
503+
* new rank 0 1 2 3
504+
* [0+1+] [0+1+] [4+5+] [4+5+]
505+
* [2+3+] [2+3+] [6 ] [6 ]
506+
* Step 2
507+
* old rank 1 3 5 6
508+
* new rank 0 1 2 3
509+
* [0+1+] [0+1+] [0+1+] [0+1+]
510+
* [2+3+] [2+3+] [2+3+] [2+3+]
511+
* [4+5+] [4+5+] [4+5+] [4+5+]
512+
* [6 ] [6 ] [6 ] [6 ]
513+
* Final adjustment step for non-power of two nodes
514+
* # 0 1 2 3 4 5 6
515+
* [0+1+] [0+1+] [0+1+] [0+1+] [0+1+] [0+1+] [0+1+]
516+
* [2+3+] [2+3+] [2+3+] [2+3+] [2+3+] [2+3+] [2+3+]
517+
* [4+5+] [4+5+] [4+5+] [4+5+] [4+5+] [4+5+] [4+5+]
518+
* [6 ] [6 ] [6 ] [6 ] [6 ] [6 ] [6 ]
519+
*
520+
*/
521+
static inline int allred_sched_recursivedoubling(int rank, int p, const void *sendbuf, void *recvbuf,
522+
int count, MPI_Datatype datatype, ptrdiff_t gap, MPI_Op op,
523+
char inplace, NBC_Schedule *schedule, void *tmpbuf)
524+
{
525+
int res, pof2, nprocs_rem, vrank;
526+
char *tmpsend = NULL, *tmprecv = NULL, *tmpswap = NULL;
527+
528+
tmpsend = (char*) tmpbuf - gap;
529+
tmprecv = (char*) recvbuf;
530+
531+
if (inplace) {
532+
res = NBC_Sched_copy(recvbuf, false, count, datatype,
533+
tmpsend, false, count, datatype, schedule, true);
534+
} else {
535+
res = NBC_Sched_copy((void *)sendbuf, false, count, datatype,
536+
tmpsend, false, count, datatype, schedule, true);
537+
}
538+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
539+
540+
/* Get nearest power of two less than or equal to comm size */
541+
pof2 = opal_next_poweroftwo(p) >> 1;
542+
543+
/* Handle non-power-of-two case:
544+
- Even ranks less than 2 * nprocs_rem send their data to (rank + 1), and
545+
sets new rank to -1.
546+
- Odd ranks less than 2 * nprocs_rem receive data from (rank - 1),
547+
apply appropriate operation, and set new rank to rank/2
548+
- Everyone else sets rank to rank - nprocs_rem
549+
*/
550+
nprocs_rem = p - pof2;
551+
if (rank < 2 * nprocs_rem) {
552+
if (0 == rank % 2) { /* Even */
553+
res = NBC_Sched_send(tmpsend, false, count, datatype, rank + 1, schedule, true);
554+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
555+
vrank = -1;
556+
} else { /* Odd */
557+
res = NBC_Sched_recv(tmprecv, false, count, datatype, rank - 1, schedule, true);
558+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
559+
560+
/* tmpsend = tmprecv (op) tmpsend */
561+
res = NBC_Sched_op(tmprecv, false, tmpsend, false, count, datatype, op, schedule, true);
562+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
563+
564+
vrank = rank >> 1;
565+
}
566+
} else {
567+
vrank = rank - nprocs_rem;
568+
}
569+
570+
/* Communication/Computation loop
571+
- Exchange message with remote node.
572+
- Perform appropriate operation taking in account order of operations:
573+
result = value (op) result
574+
*/
575+
if (0 <= vrank) {
576+
for (int distance = 1; distance < pof2; distance <<= 1) {
577+
int remote = vrank ^ distance;
578+
579+
/* Find real rank of remote node */
580+
if (remote < nprocs_rem) {
581+
remote = remote * 2 + 1;
582+
} else {
583+
remote += nprocs_rem;
584+
}
585+
586+
/* Exchange the data */
587+
res = NBC_Sched_send(tmpsend, false, count, datatype, remote, schedule, false);
588+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
589+
590+
res = NBC_Sched_recv(tmprecv, false, count, datatype, remote, schedule, true);
591+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
592+
593+
/* Apply operation */
594+
if (rank < remote) {
595+
/* tmprecv = tmpsend (op) tmprecv */
596+
res = NBC_Sched_op(tmpsend, false, tmprecv, false,
597+
count, datatype, op, schedule, true);
598+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
599+
600+
/* Swap tmpsend and tmprecv buffers */
601+
tmpswap = tmprecv; tmprecv = tmpsend; tmpsend = tmpswap;
602+
} else {
603+
/* tmpsend = tmprecv (op) tmpsend */
604+
res = NBC_Sched_op(tmprecv, false, tmpsend, false,
605+
count, datatype, op, schedule, true);
606+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
607+
}
608+
}
609+
}
610+
611+
/* Handle non-power-of-two case:
612+
- Even ranks less than 2 * nprocs_rem receive result from (rank + 1)
613+
- Odd ranks less than 2 * nprocs_rem send result from tmpsend to (rank - 1)
614+
*/
615+
if (rank < 2 * nprocs_rem) {
616+
if (0 == rank % 2) { /* Even */
617+
res = NBC_Sched_recv(recvbuf, false, count, datatype, rank + 1, schedule, false);
618+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
619+
tmpsend = (char *)recvbuf;
620+
} else { /* Odd */
621+
res = NBC_Sched_send(tmpsend, false, count, datatype, rank - 1, schedule, false);
622+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
623+
}
624+
}
625+
626+
/* Copy result back into recvbuf */
627+
if (tmpsend != recvbuf) {
628+
res = NBC_Sched_copy(tmpsend, false, count, datatype,
629+
recvbuf, false, count, datatype, schedule, false);
630+
if (OPAL_UNLIKELY(OMPI_SUCCESS != res)) { return res; }
631+
}
632+
633+
return OMPI_SUCCESS;
634+
}
635+
473636
static inline int allred_sched_ring (int r, int p, int count, MPI_Datatype datatype, const void *sendbuf, void *recvbuf, MPI_Op op,
474637
int size, int ext, NBC_Schedule *schedule, void *tmpbuf) {
475638
int segsize, *segsizes, *segoffsets; /* segment sizes and offsets per segment (number of segments == number of nodes */
@@ -1044,4 +1207,3 @@ int ompi_coll_libnbc_allreduce_inter_init(const void* sendbuf, void* recvbuf, in
10441207

10451208
return OMPI_SUCCESS;
10461209
}
1047-

0 commit comments

Comments
 (0)