@@ -26,10 +26,11 @@ BEGIN_C_DECLS
26
26
27
27
#define COLL_UCC_CTS (UCC_COLL_TYPE_BARRIER | UCC_COLL_TYPE_BCAST | \
28
28
UCC_COLL_TYPE_ALLREDUCE | UCC_COLL_TYPE_ALLTOALL | \
29
- UCC_COLL_TYPE_ALLTOALLV)
29
+ UCC_COLL_TYPE_ALLTOALLV | UCC_COLL_TYPE_ALLGATHER | \
30
+ UCC_COLL_TYPE_REDUCE | UCC_COLL_TYPE_ALLGATHERV)
30
31
31
- #define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv," \
32
- "ibarrier,ibcast,iallreduce,ialltoall,ialltoallv"
32
+ #define COLL_UCC_CTS_STR "barrier,bcast,allreduce,alltoall,alltoallv,allgather,allgatherv,reduce, " \
33
+ "ibarrier,ibcast,iallreduce,ialltoall,ialltoallv,iallgather,iallgatherv,ireduce "
33
34
34
35
typedef struct mca_coll_ucc_req {
35
36
ompi_request_t super ;
@@ -63,30 +64,42 @@ OMPI_MODULE_DECLSPEC extern mca_coll_ucc_component_t mca_coll_ucc_component;
63
64
* UCC enabled communicator
64
65
*/
65
66
struct mca_coll_ucc_module_t {
66
- mca_coll_base_module_t super ;
67
- ompi_communicator_t * comm ;
68
- int rank ;
69
- ucc_team_h ucc_team ;
70
- mca_coll_base_module_allreduce_fn_t previous_allreduce ;
71
- mca_coll_base_module_t * previous_allreduce_module ;
72
- mca_coll_base_module_iallreduce_fn_t previous_iallreduce ;
73
- mca_coll_base_module_t * previous_iallreduce_module ;
74
- mca_coll_base_module_barrier_fn_t previous_barrier ;
75
- mca_coll_base_module_t * previous_barrier_module ;
76
- mca_coll_base_module_ibarrier_fn_t previous_ibarrier ;
77
- mca_coll_base_module_t * previous_ibarrier_module ;
78
- mca_coll_base_module_bcast_fn_t previous_bcast ;
79
- mca_coll_base_module_t * previous_bcast_module ;
80
- mca_coll_base_module_ibcast_fn_t previous_ibcast ;
81
- mca_coll_base_module_t * previous_ibcast_module ;
82
- mca_coll_base_module_alltoall_fn_t previous_alltoall ;
83
- mca_coll_base_module_t * previous_alltoall_module ;
84
- mca_coll_base_module_ialltoall_fn_t previous_ialltoall ;
85
- mca_coll_base_module_t * previous_ialltoall_module ;
86
- mca_coll_base_module_alltoallv_fn_t previous_alltoallv ;
87
- mca_coll_base_module_t * previous_alltoallv_module ;
88
- mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv ;
89
- mca_coll_base_module_t * previous_ialltoallv_module ;
67
+ mca_coll_base_module_t super ;
68
+ ompi_communicator_t * comm ;
69
+ int rank ;
70
+ ucc_team_h ucc_team ;
71
+ mca_coll_base_module_allreduce_fn_t previous_allreduce ;
72
+ mca_coll_base_module_t * previous_allreduce_module ;
73
+ mca_coll_base_module_iallreduce_fn_t previous_iallreduce ;
74
+ mca_coll_base_module_t * previous_iallreduce_module ;
75
+ mca_coll_base_module_reduce_fn_t previous_reduce ;
76
+ mca_coll_base_module_t * previous_reduce_module ;
77
+ mca_coll_base_module_ireduce_fn_t previous_ireduce ;
78
+ mca_coll_base_module_t * previous_ireduce_module ;
79
+ mca_coll_base_module_barrier_fn_t previous_barrier ;
80
+ mca_coll_base_module_t * previous_barrier_module ;
81
+ mca_coll_base_module_ibarrier_fn_t previous_ibarrier ;
82
+ mca_coll_base_module_t * previous_ibarrier_module ;
83
+ mca_coll_base_module_bcast_fn_t previous_bcast ;
84
+ mca_coll_base_module_t * previous_bcast_module ;
85
+ mca_coll_base_module_ibcast_fn_t previous_ibcast ;
86
+ mca_coll_base_module_t * previous_ibcast_module ;
87
+ mca_coll_base_module_alltoall_fn_t previous_alltoall ;
88
+ mca_coll_base_module_t * previous_alltoall_module ;
89
+ mca_coll_base_module_ialltoall_fn_t previous_ialltoall ;
90
+ mca_coll_base_module_t * previous_ialltoall_module ;
91
+ mca_coll_base_module_alltoallv_fn_t previous_alltoallv ;
92
+ mca_coll_base_module_t * previous_alltoallv_module ;
93
+ mca_coll_base_module_ialltoallv_fn_t previous_ialltoallv ;
94
+ mca_coll_base_module_t * previous_ialltoallv_module ;
95
+ mca_coll_base_module_allgather_fn_t previous_allgather ;
96
+ mca_coll_base_module_t * previous_allgather_module ;
97
+ mca_coll_base_module_iallgather_fn_t previous_iallgather ;
98
+ mca_coll_base_module_t * previous_iallgather_module ;
99
+ mca_coll_base_module_allgatherv_fn_t previous_allgatherv ;
100
+ mca_coll_base_module_t * previous_allgatherv_module ;
101
+ mca_coll_base_module_iallgatherv_fn_t previous_iallgatherv ;
102
+ mca_coll_base_module_t * previous_iallgatherv_module ;
90
103
};
91
104
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t ;
92
105
OBJ_CLASS_DECLARATION (mca_coll_ucc_module_t );
@@ -105,6 +118,17 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, int count,
105
118
ompi_request_t * * request ,
106
119
mca_coll_base_module_t * module );
107
120
121
+ int mca_coll_ucc_reduce (const void * sbuf , void * rbuf , int count ,
122
+ struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
123
+ int root , struct ompi_communicator_t * comm ,
124
+ struct mca_coll_base_module_2_4_0_t * module );
125
+
126
+ int mca_coll_ucc_ireduce (const void * sbuf , void * rbuf , int count ,
127
+ struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
128
+ int root , struct ompi_communicator_t * comm ,
129
+ ompi_request_t * * request ,
130
+ struct mca_coll_base_module_2_4_0_t * module );
131
+
108
132
int mca_coll_ucc_barrier (struct ompi_communicator_t * comm ,
109
133
mca_coll_base_module_t * module );
110
134
@@ -146,5 +170,30 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, const int *scounts, const int *sdi
146
170
struct ompi_communicator_t * comm ,
147
171
ompi_request_t * * request ,
148
172
mca_coll_base_module_t * module );
173
+
174
+ int mca_coll_ucc_allgather (const void * sbuf , int scount , struct ompi_datatype_t * sdtype ,
175
+ void * rbuf , int rcount , struct ompi_datatype_t * rdtype ,
176
+ struct ompi_communicator_t * comm ,
177
+ mca_coll_base_module_t * module );
178
+
179
+ int mca_coll_ucc_iallgather (const void * sbuf , int scount , struct ompi_datatype_t * sdtype ,
180
+ void * rbuf , int rcount , struct ompi_datatype_t * rdtype ,
181
+ struct ompi_communicator_t * comm ,
182
+ ompi_request_t * * request ,
183
+ mca_coll_base_module_t * module );
184
+
185
+ int mca_coll_ucc_allgatherv (const void * sbuf , int scount , struct ompi_datatype_t * sdtype ,
186
+ void * rbuf , const int * rcounts , const int * rdisps ,
187
+ struct ompi_datatype_t * rdtype ,
188
+ struct ompi_communicator_t * comm ,
189
+ mca_coll_base_module_t * module );
190
+
191
+ int mca_coll_ucc_iallgatherv (const void * sbuf , int scount , struct ompi_datatype_t * sdtype ,
192
+ void * rbuf , const int * rcounts , const int * rdisps ,
193
+ struct ompi_datatype_t * rdtype ,
194
+ struct ompi_communicator_t * comm ,
195
+ ompi_request_t * * request ,
196
+ mca_coll_base_module_t * module );
197
+
149
198
END_C_DECLS
150
199
#endif
0 commit comments