@@ -20,6 +20,9 @@
 #include "opal/util/argv.h"
 #include "opal/util/printf.h"
 #include "opal/mca/common/ofi/common_ofi.h"
+#if OPAL_CUDA_SUPPORT
+#include "opal/mca/common/cuda/common_cuda.h"
+#endif /* OPAL_CUDA_SUPPORT */
 
 static int ompi_mtl_ofi_component_open(void);
 static int ompi_mtl_ofi_component_query(mca_base_module_t **module, int *priority);
@@ -297,6 +300,9 @@ ompi_mtl_ofi_component_query(mca_base_module_t **module, int *priority)
 static int
 ompi_mtl_ofi_component_close(void)
 {
+#if OPAL_CUDA_SUPPORT
+    mca_common_cuda_fini();
+#endif
     opal_common_ofi_mca_deregister();
     return OMPI_SUCCESS;
 }
@@ -591,6 +597,15 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
         exclude_list = opal_argv_split(*opal_common_ofi.prov_exclude, ',');
     }
 
+    /**
+     * Note: API version 1.5 is the first version that supports
+     * FI_LOCAL_COMM / FI_REMOTE_COMM checking (and we definitely need
+     * that checking -- e.g., the shared memory provider supports
+     * intranode communication (FI_LOCAL_COMM) but not internode
+     * communication (FI_REMOTE_COMM), which is insufficient for MTL
+     * selection).
+     */
+    fi_version = FI_VERSION(1, 5);
+
     /**
      * Hints to filter providers
      * See man fi_getinfo for a list of all filters
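
For readers outside the patch: a minimal, self-contained sketch of how the
requested API version and the hints interact with fi_getinfo(). It is not
part of this change and uses only public libfabric calls (fi_allocinfo(),
fi_getinfo(), fi_freeinfo()); providers that cannot satisfy the requested
version and capabilities are simply left out of the returned list.

    #include <stdio.h>
    #include <rdma/fabric.h>

    int main(void)
    {
        struct fi_info *hints = fi_allocinfo();
        struct fi_info *info = NULL;

        /* FI_VERSION(major, minor) packs both numbers into one uint32_t;
         * 1.5 is the first API version whose providers can be screened
         * for FI_LOCAL_COMM / FI_REMOTE_COMM. */
        uint32_t version = FI_VERSION(1, 5);

        hints->ep_attr->type = FI_EP_RDM;
        hints->caps = FI_TAGGED | FI_LOCAL_COMM | FI_REMOTE_COMM;

        if (0 == fi_getinfo(version, NULL, NULL, 0, hints, &info)) {
            for (struct fi_info *p = info; NULL != p; p = p->next) {
                printf("matching provider: %s\n", p->fabric_attr->prov_name);
            }
            fi_freeinfo(info);
        }
        fi_freeinfo(hints);
        return 0;
    }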
@@ -608,11 +623,23 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
                             __FILE__, __LINE__);
         goto error;
     }
+
+#if OPAL_CUDA_SUPPORT
+    /** If Open MPI is built with CUDA, request device transfer
+     *  capabilities */
+    hints->caps |= FI_HMEM;
+    hints->domain_attr->mr_mode |= FI_MR_HMEM;
+    /**
+     * Note: API version 1.9 is the first version that supports FI_HMEM
+     */
+    fi_version = FI_VERSION(1, 9);
+#endif /* OPAL_CUDA_SUPPORT */
+
     /* Make sure to get a RDM provider that can do the tagged matching
        interface and local communication and remote communication. */
     hints->mode = FI_CONTEXT;
     hints->ep_attr->type = FI_EP_RDM;
-    hints->caps = FI_TAGGED | FI_LOCAL_COMM | FI_REMOTE_COMM | FI_DIRECTED_RECV;
+    hints->caps |= FI_TAGGED | FI_LOCAL_COMM | FI_REMOTE_COMM | FI_DIRECTED_RECV;
     hints->tx_attr->msg_order = FI_ORDER_SAS;
     hints->rx_attr->msg_order = FI_ORDER_SAS;
     hints->rx_attr->op_flags = FI_COMPLETION;
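
The switch from "=" to "|=" on hints->caps in this hunk is load-bearing:
in CUDA builds FI_HMEM is OR'd into hints->caps a few lines earlier, and a
plain assignment would silently discard that bit. A reduced sketch of the
hazard (OPAL_CUDA_SUPPORT is the configure-time macro the patch already
uses):

    uint64_t caps = 0;
    #if OPAL_CUDA_SUPPORT
    caps |= FI_HMEM;   /* set first, only in CUDA builds */
    #endif
    /* "|=" preserves the FI_HMEM bit; "=" would clobber it. */
    caps |= FI_TAGGED | FI_LOCAL_COMM | FI_REMOTE_COMM | FI_DIRECTED_RECV;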
@@ -660,14 +687,6 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
 
     hints->domain_attr->resource_mgmt = FI_RM_ENABLED;
 
-    /**
-     * Note: API version 1.5 is the first version that supports
-     * FI_LOCAL_COMM / FI_REMOTE_COMM checking (and we definitely need
-     * that checking -- e.g., some providers are suitable for RXD or
-     * RXM, but can't provide local communication).
-     */
-    fi_version = FI_VERSION(1, 5);
-
     /**
      * The EFA provider in Libfabric versions prior to 1.10 contains a bug
      * where the FI_LOCAL_COMM and FI_REMOTE_COMM capabilities are not
@@ -758,6 +777,15 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
     opal_argv_free(exclude_list);
     exclude_list = NULL;
 
+#if OPAL_CUDA_SUPPORT
+    if (!(prov->caps & FI_HMEM)) {
+        opal_output_verbose(1, opal_common_ofi.output,
+                            "%s:%d: Libfabric provider does not support CUDA buffers\n",
+                            __FILE__, __LINE__);
+        goto error;
+    }
+#endif /* OPAL_CUDA_SUPPORT */
+
     /**
      * Select the format of the OFI tag
      */
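
Here prov is the fi_info entry selected after the include/exclude
filtering, and its caps field holds the capabilities the provider actually
offers (not merely what was requested). The same validation as a
standalone helper, for illustration only (the function name is not from
the patch):

    #include <rdma/fabric.h>

    /* Return the first provider in the list that advertises FI_HMEM,
     * i.e., one that can transfer device (e.g., CUDA) buffers directly. */
    static struct fi_info *
    first_hmem_provider(struct fi_info *list)
    {
        for (struct fi_info *p = list; NULL != p; p = p->next) {
            if (p->caps & FI_HMEM) {
                return p;
            }
        }
        return NULL;
    }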
@@ -1033,6 +1061,10 @@ ompi_mtl_ofi_component_init(bool enable_progress_threads,
      */
     ompi_mtl_ofi.any_addr = FI_ADDR_UNSPEC;
 
+#if OPAL_CUDA_SUPPORT
+    mca_common_cuda_stage_one_init();
+#endif
+
     return &ompi_mtl_ofi.base;
 
 error: