Skip to content

Commit efa293b

Browse files
committed
UCX: Propagate MPI serialized for all worker creations
Move MPI to UCX thread mode function to common source. Also use serialized mode for all oshmem initializations. Signed-off-by: Thomas Vegas <tvegas@nvidia.com>
1 parent 37b4d9d commit efa293b

File tree

4 files changed

+26
-16
lines changed

4 files changed

+26
-16
lines changed

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -283,20 +283,6 @@ int mca_pml_ucx_close(void)
283283
return OMPI_SUCCESS;
284284
}
285285

286-
static ucs_thread_mode_t mca_pml_ucx_thread_mode(int ompi_mode)
287-
{
288-
switch (ompi_mode) {
289-
case MPI_THREAD_MULTIPLE:
290-
return UCS_THREAD_MODE_MULTI;
291-
case MPI_THREAD_SERIALIZED:
292-
return UCS_THREAD_MODE_SERIALIZED;
293-
case MPI_THREAD_FUNNELED:
294-
case MPI_THREAD_SINGLE:
295-
default:
296-
return UCS_THREAD_MODE_SINGLE;
297-
}
298-
}
299-
300286
int mca_pml_ucx_init(int enable_mpi_threads)
301287
{
302288
ucp_worker_params_t params;
@@ -310,7 +296,8 @@ int mca_pml_ucx_init(int enable_mpi_threads)
310296
if (enable_mpi_threads) {
311297
params.thread_mode = UCS_THREAD_MODE_MULTI;
312298
} else {
313-
params.thread_mode = mca_pml_ucx_thread_mode(ompi_mpi_thread_provided);
299+
params.thread_mode =
300+
opal_common_ucx_thread_mode(ompi_mpi_thread_provided);
314301
}
315302

316303
#if HAVE_DECL_UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK

opal/mca/common/ucx/common_ucx.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
#include "opal/util/argv.h"
2828
#include "opal/util/printf.h"
2929

30+
#include "mpi.h"
31+
3032
#include <fnmatch.h>
3133
#include <stdio.h>
3234
#include <ucm/api/ucm.h>
@@ -50,6 +52,23 @@ static void opal_common_ucx_mem_release_cb(void *buf, size_t length, void *cbdat
5052
ucm_vm_munmap(buf, length);
5153
}
5254

55+
ucs_thread_mode_t opal_common_ucx_thread_mode(int ompi_mode)
56+
{
57+
switch (ompi_mode) {
58+
case MPI_THREAD_MULTIPLE:
59+
return UCS_THREAD_MODE_MULTI;
60+
case MPI_THREAD_SERIALIZED:
61+
return UCS_THREAD_MODE_SERIALIZED;
62+
case MPI_THREAD_FUNNELED:
63+
case MPI_THREAD_SINGLE:
64+
return UCS_THREAD_MODE_SINGLE;
65+
default:
66+
MCA_COMMON_UCX_WARN("Unknown MPI thread mode %d, using multithread",
67+
ompi_mode);
68+
return UCS_THREAD_MODE_MULTI;
69+
}
70+
}
71+
5372
OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component)
5473
{
5574
char *default_tls = "rc_verbs,ud_verbs,rc_mlx5,dc_mlx5,ud_mlx5,cuda_ipc,rocm_ipc";

opal/mca/common/ucx/common_ucx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t *
128128
size_t my_rank, size_t max_disconnect,
129129
ucp_worker_h worker);
130130
OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component);
131+
OPAL_DECLSPEC ucs_thread_mode_t opal_common_ucx_thread_mode(int ompi_mode);
131132

132133
/**
133134
* Load an integer value of \c size bytes from \c ptr and cast it to uint64_t.

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1047,8 +1047,11 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
10471047
ucx_ctx->strong_sync = mca_spml_ucx_ctx_default.strong_sync;
10481048

10491049
params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
1050-
if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE || options & SHMEM_CTX_PRIVATE || options & SHMEM_CTX_SERIALIZED) {
1050+
if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE ||
1051+
oshmem_mpi_thread_provided == SHMEM_THREAD_FUNNELED || options & SHMEM_CTX_PRIVATE) {
10511052
params.thread_mode = UCS_THREAD_MODE_SINGLE;
1053+
} else if (oshmem_mpi_thread_provided == SHMEM_THREAD_SERIALIZED || options & SHMEM_CTX_SERIALIZED) {
1054+
params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
10521055
} else {
10531056
params.thread_mode = UCS_THREAD_MODE_MULTI;
10541057
}

0 commit comments

Comments
 (0)