|
1 | 1 | /**
|
2 | 2 | Copyright (c) 2021 Mellanox Technologies. All rights reserved.
|
| 3 | + Copyright (c) 2022 NVIDIA Corporation. All rights reserved. |
3 | 4 | $COPYRIGHT$
|
4 | 5 |
|
5 | 6 | Additional copyrights may follow
|
|
24 | 25 |
|
25 | 26 | BEGIN_C_DECLS
|
26 | 27 |
|
27 |
| -#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \ |
28 |
| - UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \ |
29 |
| - UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \ |
30 |
| - UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV) |
| 28 | +#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \ |
| 29 | + UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \ |
| 30 | + UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \ |
| 31 | + UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV | \ |
| 32 | + UCC_COLL_TYPE_GATHER | UCC_COLL_TYPE_GATHERV | \ |
| 33 | + UCC_COLL_TYPE_REDUCE_SCATTER | UCC_COLL_TYPE_REDUCE_SCATTERV | \ |
| 34 | + UCC_COLL_TYPE_SCATTERV) |
31 | 35 |
|
32 |
| -#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce," \ |
33 |
| - "ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce" |
| 36 | +#define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather," \ |
| 37 | + "allgatherv,reduce,gather,gatherv,reduce_scatter_block,"\ |
| 38 | + "reduce_scatter,scatterv," \ |
| 39 | + "ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,"\ |
| 40 | + "iallgatherv,ireduce,igather,igatherv,ireduce_scatter_block,"\ |
| 41 | + "ireduce_scatter,iscatterv" |
34 | 42 |
|
35 | 43 | typedef struct mca_coll_ucc_req {
|
36 | 44 | ompi_request_t super;
|
@@ -64,42 +72,62 @@ OMPI_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
|
64 | 72 | * UCC enabled communicator
|
65 | 73 | */
|
66 | 74 | struct mca_coll_ucc_module_t {
|
67 |
| - mca_coll_base_module_t super; |
68 |
| - ompi_communicator_t* comm; |
69 |
| - int rank; |
70 |
| - ucc_team_h ucc_team; |
71 |
| - mca_coll_base_module_allreduce_fn_t previous_allreduce; |
72 |
| - mca_coll_base_module_t* previous_allreduce_module; |
73 |
| - mca_coll_base_module_iallreduce_fn_t previous_iallreduce; |
74 |
| - mca_coll_base_module_t* previous_iallreduce_module; |
75 |
| - mca_coll_base_module_reduce_fn_t previous_reduce; |
76 |
| - mca_coll_base_module_t* previous_reduce_module; |
77 |
| - mca_coll_base_module_ireduce_fn_t previous_ireduce; |
78 |
| - mca_coll_base_module_t* previous_ireduce_module; |
79 |
| - mca_coll_base_module_barrier_fn_t previous_barrier; |
80 |
| - mca_coll_base_module_t* previous_barrier_module; |
81 |
| - mca_coll_base_module_ibarrier_fn_t previous_ibarrier; |
82 |
| - mca_coll_base_module_t* previous_ibarrier_module; |
83 |
| - mca_coll_base_module_bcast_fn_t previous_bcast; |
84 |
| - mca_coll_base_module_t* previous_bcast_module; |
85 |
| - mca_coll_base_module_ibcast_fn_t previous_ibcast; |
86 |
| - mca_coll_base_module_t* previous_ibcast_module; |
87 |
| - mca_coll_base_module_alltoall_fn_t previous_alltoall; |
88 |
| - mca_coll_base_module_t* previous_alltoall_module; |
89 |
| - mca_coll_base_module_ialltoall_fn_t previous_ialltoall; |
90 |
| - mca_coll_base_module_t* previous_ialltoall_module; |
91 |
| - mca_coll_base_module_alltoallv_fn_t previous_alltoallv; |
92 |
| - mca_coll_base_module_t* previous_alltoallv_module; |
93 |
| - mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv; |
94 |
| - mca_coll_base_module_t* previous_ialltoallv_module; |
95 |
| - mca_coll_base_module_allgather_fn_t previous_allgather; |
96 |
| - mca_coll_base_module_t* previous_allgather_module; |
97 |
| - mca_coll_base_module_iallgather_fn_t previous_iallgather; |
98 |
| - mca_coll_base_module_t* previous_iallgather_module; |
99 |
| - mca_coll_base_module_allgatherv_fn_t previous_allgatherv; |
100 |
| - mca_coll_base_module_t* previous_allgatherv_module; |
101 |
| - mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv; |
102 |
| - mca_coll_base_module_t* previous_iallgatherv_module; |
| 75 | + mca_coll_base_module_t super; |
| 76 | + ompi_communicator_t* comm; |
| 77 | + int rank; |
| 78 | + ucc_team_h ucc_team; |
| 79 | + mca_coll_base_module_allreduce_fn_t previous_allreduce; |
| 80 | + mca_coll_base_module_t* previous_allreduce_module; |
| 81 | + mca_coll_base_module_iallreduce_fn_t previous_iallreduce; |
| 82 | + mca_coll_base_module_t* previous_iallreduce_module; |
| 83 | + mca_coll_base_module_reduce_fn_t previous_reduce; |
| 84 | + mca_coll_base_module_t* previous_reduce_module; |
| 85 | + mca_coll_base_module_ireduce_fn_t previous_ireduce; |
| 86 | + mca_coll_base_module_t* previous_ireduce_module; |
| 87 | + mca_coll_base_module_barrier_fn_t previous_barrier; |
| 88 | + mca_coll_base_module_t* previous_barrier_module; |
| 89 | + mca_coll_base_module_ibarrier_fn_t previous_ibarrier; |
| 90 | + mca_coll_base_module_t* previous_ibarrier_module; |
| 91 | + mca_coll_base_module_bcast_fn_t previous_bcast; |
| 92 | + mca_coll_base_module_t* previous_bcast_module; |
| 93 | + mca_coll_base_module_ibcast_fn_t previous_ibcast; |
| 94 | + mca_coll_base_module_t* previous_ibcast_module; |
| 95 | + mca_coll_base_module_alltoall_fn_t previous_alltoall; |
| 96 | + mca_coll_base_module_t* previous_alltoall_module; |
| 97 | + mca_coll_base_module_ialltoall_fn_t previous_ialltoall; |
| 98 | + mca_coll_base_module_t* previous_ialltoall_module; |
| 99 | + mca_coll_base_module_alltoallv_fn_t previous_alltoallv; |
| 100 | + mca_coll_base_module_t* previous_alltoallv_module; |
| 101 | + mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv; |
| 102 | + mca_coll_base_module_t* previous_ialltoallv_module; |
| 103 | + mca_coll_base_module_allgather_fn_t previous_allgather; |
| 104 | + mca_coll_base_module_t* previous_allgather_module; |
| 105 | + mca_coll_base_module_iallgather_fn_t previous_iallgather; |
| 106 | + mca_coll_base_module_t* previous_iallgather_module; |
| 107 | + mca_coll_base_module_allgatherv_fn_t previous_allgatherv; |
| 108 | + mca_coll_base_module_t* previous_allgatherv_module; |
| 109 | + mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv; |
| 110 | + mca_coll_base_module_t* previous_iallgatherv_module; |
| 111 | + mca_coll_base_module_gather_fn_t previous_gather; |
| 112 | + mca_coll_base_module_t* previous_gather_module; |
| 113 | + mca_coll_base_module_igather_fn_t previous_igather; |
| 114 | + mca_coll_base_module_t* previous_igather_module; |
| 115 | + mca_coll_base_module_gatherv_fn_t previous_gatherv; |
| 116 | + mca_coll_base_module_t* previous_gatherv_module; |
| 117 | + mca_coll_base_module_igatherv_fn_t previous_igatherv; |
| 118 | + mca_coll_base_module_t* previous_igatherv_module; |
| 119 | + mca_coll_base_module_reduce_scatter_block_fn_t previous_reduce_scatter_block; |
| 120 | + mca_coll_base_module_t* previous_reduce_scatter_block_module; |
| 121 | + mca_coll_base_module_ireduce_scatter_block_fn_t previous_ireduce_scatter_block; |
| 122 | + mca_coll_base_module_t* previous_ireduce_scatter_block_module; |
| 123 | + mca_coll_base_module_reduce_scatter_fn_t previous_reduce_scatter; |
| 124 | + mca_coll_base_module_t* previous_reduce_scatter_module; |
| 125 | + mca_coll_base_module_ireduce_scatter_fn_t previous_ireduce_scatter; |
| 126 | + mca_coll_base_module_t* previous_ireduce_scatter_module; |
| 127 | + mca_coll_base_module_scatterv_fn_t previous_scatterv; |
| 128 | + mca_coll_base_module_t* previous_scatterv_module; |
| 129 | + mca_coll_base_module_iscatterv_fn_t previous_iscatterv; |
| 130 | + mca_coll_base_module_t* previous_iscatterv_module; |
103 | 131 | };
|
104 | 132 | typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
|
105 | 133 | OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
|
@@ -195,5 +223,70 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, int scount, struct ompi_datatype_
|
195 | 223 | ompi_request_t** request,
|
196 | 224 | mca_coll_base_module_t *module);
|
197 | 225 |
|
| 226 | +int mca_coll_ucc_gather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, |
| 227 | + void *rbuf, int rcount, struct ompi_datatype_t *rdtype, |
| 228 | + int root, struct ompi_communicator_t *comm, |
| 229 | + mca_coll_base_module_t *module); |
| 230 | + |
| 231 | +int mca_coll_ucc_igather(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, |
| 232 | + void *rbuf, int rcount, struct ompi_datatype_t *rdtype, |
| 233 | + int root, struct ompi_communicator_t *comm, |
| 234 | + ompi_request_t** request, |
| 235 | + mca_coll_base_module_t *module); |
| 236 | + |
| 237 | +int mca_coll_ucc_gatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, |
| 238 | + void *rbuf, const int *rcounts, const int *disps, |
| 239 | + struct ompi_datatype_t *rdtype, int root, |
| 240 | + struct ompi_communicator_t *comm, |
| 241 | + mca_coll_base_module_t *module); |
| 242 | + |
| 243 | +int mca_coll_ucc_igatherv(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, |
| 244 | + void *rbuf, const int *rcounts, const int *disps, |
| 245 | + struct ompi_datatype_t *rdtype, int root, |
| 246 | + struct ompi_communicator_t *comm, |
| 247 | + ompi_request_t** request, |
| 248 | + mca_coll_base_module_t *module); |
| 249 | + |
| 250 | +int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, int rcount, |
| 251 | + struct ompi_datatype_t *dtype, |
| 252 | + struct ompi_op_t *op, |
| 253 | + struct ompi_communicator_t *comm, |
| 254 | + mca_coll_base_module_t *module); |
| 255 | + |
| 256 | +int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, int rcount, |
| 257 | + struct ompi_datatype_t *dtype, |
| 258 | + struct ompi_op_t *op, |
| 259 | + struct ompi_communicator_t *comm, |
| 260 | + ompi_request_t** request, |
| 261 | + mca_coll_base_module_t *module); |
| 262 | + |
| 263 | +int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts, |
| 264 | + struct ompi_datatype_t *dtype, |
| 265 | + struct ompi_op_t *op, |
| 266 | + struct ompi_communicator_t *comm, |
| 267 | + mca_coll_base_module_t *module); |
| 268 | + |
| 269 | +int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, const int *rcounts, |
| 270 | + struct ompi_datatype_t *dtype, |
| 271 | + struct ompi_op_t *op, |
| 272 | + struct ompi_communicator_t *comm, |
| 273 | + ompi_request_t** request, |
| 274 | + mca_coll_base_module_t *module); |
| 275 | + |
| 276 | +int mca_coll_ucc_scatterv(const void *sbuf, const int *scounts, |
| 277 | + const int *disps, struct ompi_datatype_t *sdtype, |
| 278 | + void *rbuf, int rcount, |
| 279 | + struct ompi_datatype_t *rdtype, int root, |
| 280 | + struct ompi_communicator_t *comm, |
| 281 | + mca_coll_base_module_t *module); |
| 282 | + |
| 283 | +int mca_coll_ucc_iscatterv(const void *sbuf, const int *scounts, |
| 284 | + const int *disps, struct ompi_datatype_t *sdtype, |
| 285 | + void *rbuf, int rcount, |
| 286 | + struct ompi_datatype_t *rdtype, int root, |
| 287 | + struct ompi_communicator_t *comm, |
| 288 | + ompi_request_t** request, |
| 289 | + mca_coll_base_module_t *module); |
| 290 | + |
198 | 291 | END_C_DECLS
|
199 | 292 | #endif
|
0 commit comments