42
42
#include "mpi.h"
43
43
#include "ompi/mca/mca.h"
44
44
#include "opal/util/output.h"
45
+ #include "opal/mca/smsc/smsc.h"
45
46
#include "ompi/mca/coll/base/coll_base_functions.h"
46
47
#include "coll_han_trigger.h"
47
48
#include "ompi/mca/coll/han/coll_han_dynamic.h"
@@ -197,6 +198,7 @@ typedef struct mca_coll_han_op_module_name_t {
197
198
mca_coll_han_op_up_low_module_name_t gatherv ;
198
199
mca_coll_han_op_up_low_module_name_t scatter ;
199
200
mca_coll_han_op_up_low_module_name_t scatterv ;
201
+ mca_coll_han_op_up_low_module_name_t alltoall ;
200
202
} mca_coll_han_op_module_name_t ;
201
203
202
204
/**
@@ -252,6 +254,13 @@ typedef struct mca_coll_han_component_t {
252
254
uint32_t han_scatterv_up_module ;
253
255
/* low level module for scatterv */
254
256
uint32_t han_scatterv_low_module ;
257
+
258
+ /* low level module for alltoall */
259
+ uint32_t han_alltoall_low_module ;
260
+ /* alltoall: parallel stages */
261
+ int32_t han_alltoall_pstages ;
262
+
263
+
255
264
/* name of the modules */
256
265
mca_coll_han_op_module_name_t han_op_module_name ;
257
266
/* whether we need reproducible results
@@ -287,6 +296,7 @@ typedef struct mca_coll_han_single_collective_fallback_s
287
296
{
288
297
union
289
298
{
299
+ mca_coll_base_module_alltoall_fn_t alltoall ;
290
300
mca_coll_base_module_allgather_fn_t allgather ;
291
301
mca_coll_base_module_allgatherv_fn_t allgatherv ;
292
302
mca_coll_base_module_allreduce_fn_t allreduce ;
@@ -308,6 +318,7 @@ typedef struct mca_coll_han_single_collective_fallback_s
308
318
*/
309
319
typedef struct mca_coll_han_collectives_fallback_s
310
320
{
321
+ mca_coll_han_single_collective_fallback_t alltoall ;
311
322
mca_coll_han_single_collective_fallback_t allgather ;
312
323
mca_coll_han_single_collective_fallback_t allgatherv ;
313
324
mca_coll_han_single_collective_fallback_t allreduce ;
@@ -370,6 +381,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
370
381
* Some defines to stick to the naming used in the other components in terms of
371
382
* fallback routines
372
383
*/
384
+ #define previous_alltoall fallback.alltoall.alltoall
385
+ #define previous_alltoall_module fallback.alltoall.module
386
+
373
387
#define previous_allgather fallback.allgather.allgather
374
388
#define previous_allgather_module fallback.allgather.module
375
389
@@ -425,6 +439,7 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
425
439
HAN_UNINSTALL_COLL_API(COMM, HANM, allreduce); \
426
440
HAN_UNINSTALL_COLL_API(COMM, HANM, allgather); \
427
441
HAN_UNINSTALL_COLL_API(COMM, HANM, allgatherv); \
442
+ HAN_UNINSTALL_COLL_API(COMM, HANM, alltoall); \
428
443
han_module->enabled = false; /* entire module set to pass-through from now on */ \
429
444
} while (0 )
430
445
@@ -485,6 +500,9 @@ mca_coll_han_get_all_coll_modules(struct ompi_communicator_t *comm,
485
500
mca_coll_han_module_t * han_module );
486
501
487
502
int
503
+ mca_coll_han_alltoall_intra_dynamic (ALLTOALL_BASE_ARGS ,
504
+ mca_coll_base_module_t * module );
505
+ int
488
506
mca_coll_han_allgather_intra_dynamic (ALLGATHER_BASE_ARGS ,
489
507
mca_coll_base_module_t * module );
490
508
int
@@ -532,4 +550,20 @@ coll_han_utils_gcd(const uint64_t *numerators, const size_t size);
532
550
int
533
551
coll_han_utils_create_contiguous_datatype (size_t count , const ompi_datatype_t * oldType ,
534
552
ompi_datatype_t * * newType );
553
+
554
+ static inline struct mca_smsc_endpoint_t * mca_coll_han_get_smsc_endpoint (struct ompi_proc_t * proc ) {
555
+ extern opal_mutex_t mca_coll_han_lock ;
556
+ if (NULL == proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ]) {
557
+ if (NULL == proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ]) {
558
+ OPAL_THREAD_LOCK (& mca_coll_han_lock );
559
+ if (NULL == proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ]) {
560
+ proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ] = mca_smsc -> get_endpoint (& proc -> super );
561
+ }
562
+ OPAL_THREAD_UNLOCK (& mca_coll_han_lock );
563
+ }
564
+ }
565
+
566
+ return (struct mca_smsc_endpoint_t * ) proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ];
567
+ }
568
+
535
569
#endif /* MCA_COLL_HAN_EXPORT_H */
0 commit comments