Skip to content

Commit c3b6852

Browse files
authored
Merge pull request #12223 from jiaxiyan/disjoint
communicator: fix max_local_peers value in disjoint function
2 parents 223ed58 + 23df181 commit c3b6852

File tree

1 file changed

+33
-29
lines changed

1 file changed

+33
-29
lines changed

ompi/communicator/comm_cid.c

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ opal_atomic_int64_t ompi_comm_next_base_cid = 1;
6565

6666
struct ompi_comm_cid_context_t;
6767

68-
typedef int (*ompi_comm_allreduce_impl_fn_t) (int *inbuf, int *outbuf, int count, struct ompi_op_t *op,
69-
struct ompi_comm_cid_context_t *cid_context,
70-
ompi_request_t **req);
68+
typedef int (*ompi_comm_iallreduce_impl_fn_t) (int *inbuf, int *outbuf, int count, struct ompi_op_t *op,
69+
struct ompi_comm_cid_context_t *cid_context,
70+
ompi_request_t **req);
7171

7272

7373
struct ompi_comm_cid_context_t {
@@ -78,7 +78,7 @@ struct ompi_comm_cid_context_t {
7878
ompi_communicator_t *comm;
7979
ompi_communicator_t *bridgecomm;
8080

81-
ompi_comm_allreduce_impl_fn_t allreduce_fn;
81+
ompi_comm_iallreduce_impl_fn_t iallreduce_fn;
8282

8383
int nextcid;
8484
int nextlocal_cid;
@@ -225,38 +225,38 @@ static ompi_comm_cid_context_t *mca_comm_cid_context_alloc (ompi_communicator_t
225225
* for the current mode. */
226226
switch (mode) {
227227
case OMPI_COMM_CID_INTRA:
228-
context->allreduce_fn = ompi_comm_allreduce_intra_nb;
228+
context->iallreduce_fn = ompi_comm_allreduce_intra_nb;
229229
break;
230230
case OMPI_COMM_CID_INTER:
231-
context->allreduce_fn = ompi_comm_allreduce_inter_nb;
231+
context->iallreduce_fn = ompi_comm_allreduce_inter_nb;
232232
break;
233233
case OMPI_COMM_CID_GROUP:
234234
case OMPI_COMM_CID_GROUP_NEW:
235-
context->allreduce_fn = ompi_comm_allreduce_group_nb;
235+
context->iallreduce_fn = ompi_comm_allreduce_group_nb;
236236
context->pml_tag = ((int *) arg0)[0];
237237
break;
238238
case OMPI_COMM_CID_INTRA_PMIX:
239-
context->allreduce_fn = ompi_comm_allreduce_intra_pmix_nb;
239+
context->iallreduce_fn = ompi_comm_allreduce_intra_pmix_nb;
240240
context->local_leader = ((int *) arg0)[0];
241241
if (arg1) {
242242
context->port_string = strdup ((char *) arg1);
243243
}
244244
context->pmix_tag = strdup ((char *) pmix_tag);
245245
break;
246246
case OMPI_COMM_CID_INTRA_BRIDGE:
247-
context->allreduce_fn = ompi_comm_allreduce_intra_bridge_nb;
247+
context->iallreduce_fn = ompi_comm_allreduce_intra_bridge_nb;
248248
context->local_leader = ((int *) arg0)[0];
249249
context->remote_leader = ((int *) arg1)[0];
250250
break;
251251
#if OPAL_ENABLE_FT_MPI
252252
case OMPI_COMM_CID_INTRA_FT:
253-
context->allreduce_fn = ompi_comm_ft_allreduce_intra_nb;
253+
context->iallreduce_fn = ompi_comm_ft_allreduce_intra_nb;
254254
break;
255255
case OMPI_COMM_CID_INTER_FT:
256-
context->allreduce_fn = ompi_comm_ft_allreduce_inter_nb;
256+
context->iallreduce_fn = ompi_comm_ft_allreduce_inter_nb;
257257
break;
258258
case OMPI_COMM_CID_INTRA_PMIX_FT:
259-
context->allreduce_fn = ompi_comm_ft_allreduce_intra_pmix_nb;
259+
context->iallreduce_fn = ompi_comm_ft_allreduce_intra_pmix_nb;
260260
break;
261261
#endif /* OPAL_ENABLE_FT_MPI */
262262
default:
@@ -600,8 +600,8 @@ static int ompi_comm_allreduce_getnextcid (ompi_comm_request_t *request)
600600
#endif /* OPAL_ENABLE_FT_MPI */
601601
}
602602

603-
ret = context->allreduce_fn (&context->nextlocal_cid, &context->nextcid, 1, MPI_MAX,
604-
context, &subreq);
603+
ret = context->iallreduce_fn (&context->nextlocal_cid, &context->nextcid, 1, MPI_MAX,
604+
context, &subreq);
605605
/* there was a failure during non-blocking collective
606606
* all we can do is abort
607607
*/
@@ -666,7 +666,7 @@ static int ompi_comm_checkcid (ompi_comm_request_t *request)
666666

667667
++context->iter;
668668

669-
ret = context->allreduce_fn (&context->flag, &context->rflag, 1, MPI_MIN, context, &subreq);
669+
ret = context->iallreduce_fn (&context->flag, &context->rflag, 1, MPI_MIN, context, &subreq);
670670
if (OMPI_SUCCESS == ret) {
671671
ompi_comm_request_schedule_append (request, ompi_comm_nextcid_check_flag, &subreq, 1);
672672
} else {
@@ -774,6 +774,11 @@ static int ompi_comm_activate_nb_complete (ompi_comm_request_t *request);
774774
/* Callback function to set communicator disjointness flags */
775775
static inline void ompi_comm_set_disjointness_nb_complete(ompi_comm_cid_context_t *context)
776776
{
777+
/* Only set the disjointness flags when the new communicator is an intra-communicator */
778+
if (OMPI_COMM_IS_INTER(*context->newcommp)) {
779+
return;
780+
}
781+
777782
if (OMPI_COMM_IS_DISJOINT_SET(*context->newcommp)) {
778783
opal_show_help("help-comm.txt", "disjointness-set-again", true);
779784
return;
@@ -870,7 +875,7 @@ int ompi_comm_activate_nb (ompi_communicator_t **newcomm, ompi_communicator_t *c
870875
ompi_comm_cid_context_t *context;
871876
ompi_comm_request_t *request;
872877
ompi_request_t *subreq;
873-
int ret = 0, local_peers = -1;
878+
int ret = 0;
874879

875880
/* the caller should not pass NULL for comm (it may be the same as *newcomm) */
876881
assert (NULL != comm);
@@ -902,20 +907,19 @@ int ompi_comm_activate_nb (ompi_communicator_t **newcomm, ompi_communicator_t *c
902907
OMPI_COMM_SET_PML_ADDED(*newcomm);
903908
}
904909

905-
/**
906-
* Dual-purpose barrier:
907-
* 1. The communicator's disjointness is inferred from max_local_peers.
908-
* 2. After the operation it is allowed to send messages over the new communicator.
909-
*/
910-
local_peers = context->max_local_peers;
911-
ret = context->allreduce_fn (&local_peers, &context->max_local_peers, 1, MPI_MAX, context,
912-
&subreq);
913-
if (OMPI_SUCCESS != ret) {
914-
ompi_comm_request_return (request);
915-
return ret;
910+
if (OMPI_COMM_IS_INTRA(*newcomm)) {
911+
/* The communicator's disjointness is inferred from max_local_peers. */
912+
ret = context->iallreduce_fn (MPI_IN_PLACE, &context->max_local_peers, 1, MPI_MAX, context,
913+
&subreq);
914+
if (OMPI_SUCCESS != ret) {
915+
ompi_comm_request_return (request);
916+
return ret;
917+
}
918+
ompi_comm_request_schedule_append (request, ompi_comm_activate_nb_complete, &subreq, 1);
919+
} else {
920+
ompi_comm_request_schedule_append (request, ompi_comm_activate_nb_complete, NULL, 0);
916921
}
917-
918-
ompi_comm_request_schedule_append (request, ompi_comm_activate_nb_complete, &subreq, 1);
922+
919923
ompi_comm_request_start (request);
920924

921925
*req = &request->super;

0 commit comments

Comments
 (0)