Skip to content

Commit f69f03a

Browse files
committed
mtl/ofi: Update initialization to add FI_HMEM, FI_MR_HMEM, and FI_MR_LOCAL support
Add a check to see if Libfabric has at least one provider with FI_HMEM support, and use this info to set whether or not Libfabric has CUDA support. Add provider hints for FI_MR_LOCAL, and if Libfabric has CUDA support, also add hints for FI_HMEM and FI_MR_HMEM. In the case where Open MPI is built with CUDA support but Libfabric is not, the MTL/OFI is not picked. Signed-off-by: William Zhang <wilzhang@amazon.com>
1 parent 9a6c0f4 commit f69f03a

File tree

1 file changed

+41
-9
lines changed

1 file changed

+41
-9
lines changed

ompi/mca/mtl/ofi/mtl_ofi_component.c

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
#include "opal/util/argv.h"
2121
#include "opal/util/printf.h"
2222
#include "opal/mca/common/ofi/common_ofi.h"
23+
#if OPAL_CUDA_SUPPORT
24+
#include "opal/mca/common/cuda/common_cuda.h"
25+
#endif /* OPAL_CUDA_SUPPORT */
2326

2427
static int ompi_mtl_ofi_component_open(void);
2528
static int ompi_mtl_ofi_component_query(mca_base_module_t **module, int *priority);
@@ -297,6 +300,9 @@ ompi_mtl_ofi_component_query(mca_base_module_t **module, int *priority)
297300
static int
298301
ompi_mtl_ofi_component_close(void)
299302
{
303+
#if OPAL_CUDA_SUPPORT
304+
mca_common_cuda_fini();
305+
#endif
300306
opal_common_ofi_mca_deregister();
301307
return OMPI_SUCCESS;
302308
}
@@ -591,6 +597,15 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
591597
exclude_list = opal_argv_split(*opal_common_ofi.prov_exclude, ',');
592598
}
593599

600+
/**
601+
* Note: API version 1.5 is the first version that supports
602+
* FI_LOCAL_COMM / FI_REMOTE_COMM checking (and we definitely need
603+
* that checking -- e.g., the shared memory provider supports
604+
* intranode communication (FI_LOCAL_COMM), but not internode
605+
* (FI_REMOTE_COMM), which is insufficient for MTL selection).
606+
*/
607+
fi_version = FI_VERSION(1, 5);
608+
594609
/**
595610
* Hints to filter providers
596611
* See man fi_getinfo for a list of all filters
@@ -608,11 +623,23 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
608623
__FILE__, __LINE__);
609624
goto error;
610625
}
626+
627+
#if OPAL_CUDA_SUPPORT
628+
/** If Open MPI is built with CUDA, request device transfer
629+
* capabilities */
630+
hints->caps |= FI_HMEM;
631+
hints->domain_attr->mr_mode |= FI_MR_HMEM;
632+
/**
633+
* Note: API version 1.9 is the first version that supports FI_HMEM
634+
*/
635+
fi_version = FI_VERSION(1, 9);
636+
#endif /* OPAL_CUDA_SUPPORT */
637+
611638
/* Make sure to get a RDM provider that can do the tagged matching
612639
interface and local communication and remote communication. */
613640
hints->mode = FI_CONTEXT;
614641
hints->ep_attr->type = FI_EP_RDM;
615-
hints->caps = FI_TAGGED | FI_LOCAL_COMM | FI_REMOTE_COMM | FI_DIRECTED_RECV;
642+
hints->caps |= FI_TAGGED | FI_LOCAL_COMM | FI_REMOTE_COMM | FI_DIRECTED_RECV;
616643
hints->tx_attr->msg_order = FI_ORDER_SAS;
617644
hints->rx_attr->msg_order = FI_ORDER_SAS;
618645
hints->rx_attr->op_flags = FI_COMPLETION;
@@ -660,14 +687,6 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
660687

661688
hints->domain_attr->resource_mgmt = FI_RM_ENABLED;
662689

663-
/**
664-
* Note: API version 1.5 is the first version that supports
665-
* FI_LOCAL_COMM / FI_REMOTE_COMM checking (and we definitely need
666-
* that checking -- e.g., some providers are suitable for RXD or
667-
* RXM, but can't provide local communication).
668-
*/
669-
fi_version = FI_VERSION(1, 5);
670-
671690
/**
672691
* The EFA provider in Libfabric versions prior to 1.10 contains a bug
673692
* where the FI_LOCAL_COMM and FI_REMOTE_COMM capabilities are not
@@ -758,6 +777,15 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
758777
opal_argv_free(exclude_list);
759778
exclude_list = NULL;
760779

780+
#if OPAL_CUDA_SUPPORT
781+
if (!(prov->caps & FI_HMEM)) {
782+
opal_output_verbose(1, opal_common_ofi.output,
783+
"%s:%d: Libfabric provider does not support CUDA buffers\n",
784+
__FILE__, __LINE__);
785+
goto error;
786+
}
787+
#endif /* OPAL_CUDA_SUPPORT */
788+
761789
/**
762790
* Select the format of the OFI tag
763791
*/
@@ -1033,6 +1061,10 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
10331061
*/
10341062
ompi_mtl_ofi.any_addr = FI_ADDR_UNSPEC;
10351063

1064+
#if OPAL_CUDA_SUPPORT
1065+
mca_common_cuda_stage_one_init();
1066+
#endif
1067+
10361068
return &ompi_mtl_ofi.base;
10371069

10381070
error:

0 commit comments

Comments
 (0)