Skip to content

Commit ba83cc9

Browse files
mkurnosovhjelmn
authored andcommitted
coll/base: add MPI_Bcast based on a binomial tree scatter followed by a ring allgather
Implements MPI_Bcast using a binomial tree scatter followed by a ring allgather. Signed-off-by: Mikhail Kurnosov <mkurnosov@gmail.com>
1 parent 27b91d7 commit ba83cc9

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

ompi/mca/coll/base/coll_base_bcast.c

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,3 +891,134 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
891891
cleanup_and_return:
892892
return err;
893893
}
894+
895+
/*
896+
* ompi_coll_base_bcast_intra_scatter_allgather_ring
897+
*
898+
* Function: Bcast using a binomial tree scatter followed by a ring allgather.
899+
* Accepts: Same arguments as MPI_Bcast
900+
* Returns: MPI_SUCCESS or error code
901+
*
902+
* Limitations: count >= comm_size
903+
* Time complexity: O(\alpha(\log(p) + p) + \beta*m((p-1)/p))
904+
* Binomial tree scatter: \alpha\log(p) + \beta*m((p-1)/p)
905+
* Ring allgather: 2(p-1)(\alpha + m/p\beta)
906+
*
907+
* Example, p=8, count=8, root=0
908+
* Binomial tree scatter Ring allgather: p - 1 steps
909+
* 0: --+ --+ --+ [0*******] [0******7] [0*****67] [0****567] ... [01234567]
910+
* 1: | 2| <-+ [*1******] [01******] [01*****7] [01****67] ... [01234567]
911+
* 2: 4| <-+ --+ [**2*****] [*12*****] [012*****] [012****7] ... [01234567]
912+
* 3: | <-+ [***3****] [**23****] [*123****] [0123****] ... [01234567]
913+
* 4: <-+ --+ --+ [****4***] [***34***] [**234***] [*1234***] ... [01234567]
914+
* 5: 2| <-+ [*****5**] [****45**] [***345**] [**2345**] ... [01234567]
915+
* 6: <-+ --+ [******6*] [*****56*] [****456*] [***3456*] ... [01234567]
916+
* 7: <-+ [*******7] [******67] [*****567] [****4567] ... [01234567]
917+
*/
918+
int ompi_coll_base_bcast_intra_scatter_allgather_ring(
919+
void *buf, int count, struct ompi_datatype_t *datatype, int root,
920+
struct ompi_communicator_t *comm, mca_coll_base_module_t *module,
921+
uint32_t segsize)
922+
{
923+
int err = MPI_SUCCESS;
924+
ptrdiff_t lb, extent;
925+
size_t datatype_size;
926+
MPI_Status status;
927+
ompi_datatype_get_extent(datatype, &lb, &extent);
928+
ompi_datatype_type_size(datatype, &datatype_size);
929+
int comm_size = ompi_comm_size(comm);
930+
int rank = ompi_comm_rank(comm);
931+
932+
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
933+
"coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d",
934+
rank, comm_size));
935+
if (comm_size < 2 || datatype_size == 0)
936+
return MPI_SUCCESS;
937+
938+
if (count < comm_size) {
939+
OPAL_OUTPUT((ompi_coll_base_framework.framework_output,
940+
"coll:base:bcast_intra_scatter_allgather_ring: rank %d/%d "
941+
"count %d switching to basic linear bcast",
942+
rank, comm_size, count));
943+
return ompi_coll_base_bcast_intra_basic_linear(buf, count, datatype,
944+
root, comm, module);
945+
}
946+
947+
int vrank = (rank - root + comm_size) % comm_size;
948+
int recv_count = 0, send_count = 0;
949+
int scatter_count = (count + comm_size - 1) / comm_size; /* ceil(count / comm_size) */
950+
int curr_count = (rank == root) ? count : 0;
951+
952+
/* Scatter by binomial tree: receive data from parent */
953+
int mask = 1;
954+
while (mask < comm_size) {
955+
if (vrank & mask) {
956+
int parent = (rank - mask + comm_size) % comm_size;
957+
/* Compute an upper bound on recv block size */
958+
recv_count = count - vrank * scatter_count;
959+
if (recv_count <= 0) {
960+
curr_count = 0;
961+
} else {
962+
/* Recv data from parent */
963+
err = MCA_PML_CALL(recv((char *)buf + (ptrdiff_t)vrank * scatter_count * extent,
964+
recv_count, datatype, parent,
965+
MCA_COLL_BASE_TAG_BCAST, comm, &status));
966+
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
967+
/* Get received count */
968+
curr_count = (int)(status._ucount / datatype_size);
969+
}
970+
break;
971+
}
972+
mask <<= 1;
973+
}
974+
975+
/* Scatter by binomial tree: send data to child processes */
976+
mask >>= 1;
977+
while (mask > 0) {
978+
if (vrank + mask < comm_size) {
979+
send_count = curr_count - scatter_count * mask;
980+
if (send_count > 0) {
981+
int child = (rank + mask) % comm_size;
982+
err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
983+
send_count, datatype, child,
984+
MCA_COLL_BASE_TAG_BCAST,
985+
MCA_PML_BASE_SEND_STANDARD, comm));
986+
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
987+
curr_count -= send_count;
988+
}
989+
}
990+
mask >>= 1;
991+
}
992+
993+
/* Allgather by a ring algorithm */
994+
int left = (rank - 1 + comm_size) % comm_size;
995+
int right = (rank + 1) % comm_size;
996+
int send_block = vrank;
997+
int recv_block = (vrank - 1 + comm_size) % comm_size;
998+
999+
for (int i = 1; i < comm_size; i++) {
1000+
recv_count = (scatter_count < count - recv_block * scatter_count) ?
1001+
scatter_count : count - recv_block * scatter_count;
1002+
if (recv_count < 0)
1003+
recv_count = 0;
1004+
ptrdiff_t recv_offset = recv_block * scatter_count * extent;
1005+
1006+
send_count = (scatter_count < count - send_block * scatter_count) ?
1007+
scatter_count : count - send_block * scatter_count;
1008+
if (send_count < 0)
1009+
send_count = 0;
1010+
ptrdiff_t send_offset = send_block * scatter_count * extent;
1011+
1012+
err = ompi_coll_base_sendrecv((char *)buf + send_offset, send_count,
1013+
datatype, right, MCA_COLL_BASE_TAG_BCAST,
1014+
(char *)buf + recv_offset, recv_count,
1015+
datatype, left, MCA_COLL_BASE_TAG_BCAST,
1016+
comm, MPI_STATUS_IGNORE, rank);
1017+
if (MPI_SUCCESS != err) { goto cleanup_and_return; }
1018+
send_block = recv_block;
1019+
recv_block = (recv_block - 1 + comm_size) % comm_size;
1020+
}
1021+
1022+
cleanup_and_return:
1023+
return err;
1024+
}

ompi/mca/coll/base/coll_base_functions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ int ompi_coll_base_bcast_intra_bintree(BCAST_ARGS, uint32_t segsize);
247247
int ompi_coll_base_bcast_intra_split_bintree(BCAST_ARGS, uint32_t segsize);
248248
int ompi_coll_base_bcast_intra_knomial(BCAST_ARGS, uint32_t segsize, int radix);
249249
int ompi_coll_base_bcast_intra_scatter_allgather(BCAST_ARGS, uint32_t segsize);
250+
int ompi_coll_base_bcast_intra_scatter_allgather_ring(BCAST_ARGS, uint32_t segsize);
250251

251252
/* Exscan */
252253
int ompi_coll_base_exscan_intra_recursivedoubling(EXSCAN_ARGS);

ompi/mca/coll/tuned/coll_tuned_bcast_decision.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ static mca_base_var_enum_value_t bcast_algorithms[] = {
4444
{6, "binomial"},
4545
{7, "knomial"},
4646
{8, "scatter_allgather"},
47+
{9, "scatter_allgather_ring"},
4748
{0, NULL}
4849
};
4950

@@ -79,7 +80,7 @@ int ompi_coll_tuned_bcast_intra_check_forced_init (coll_tuned_force_algorithm_mc
7980
mca_param_indices->algorithm_param_index =
8081
mca_base_component_var_register(&mca_coll_tuned_component.super.collm_version,
8182
"bcast_algorithm",
82-
"Which bcast algorithm is used. Can be locked down to choice of: 0 ignore, 1 basic linear, 2 chain, 3: pipeline, 4: split binary tree, 5: binary tree, 6: binomial tree, 7: knomial tree, 8: scatter_allgather.",
83+
"Which bcast algorithm is used. Can be locked down to choice of: 0 ignore, 1 basic linear, 2 chain, 3: pipeline, 4: split binary tree, 5: binary tree, 6: binomial tree, 7: knomial tree, 8: scatter_allgather, 9: scatter_allgather_ring.",
8384
MCA_BASE_VAR_TYPE_INT, new_enum, 0, MCA_BASE_VAR_FLAG_SETTABLE,
8485
OPAL_INFO_LVL_5,
8586
MCA_BASE_VAR_SCOPE_ALL,
@@ -160,6 +161,8 @@ int ompi_coll_tuned_bcast_intra_do_this(void *buf, int count,
160161
segsize, coll_tuned_bcast_knomial_radix);
161162
case (8):
162163
return ompi_coll_base_bcast_intra_scatter_allgather(buf, count, dtype, root, comm, module, segsize);
164+
case (9):
165+
return ompi_coll_base_bcast_intra_scatter_allgather_ring(buf, count, dtype, root, comm, module, segsize);
163166
} /* switch */
164167
OPAL_OUTPUT((ompi_coll_tuned_stream,"coll:tuned:bcast_intra_do_this attempt to select algorithm %d when only 0-%d is valid?",
165168
algorithm, ompi_coll_tuned_forced_max_algorithms[BCAST]));

0 commit comments

Comments
 (0)