12
12
13
13
#include <ucc/api/ucc.h>
14
14
15
- static inline ucc_status_t mca_scoll_ucc_alltoall_init (const void * sbuf , void * rbuf ,
16
- int count ,
17
- mca_scoll_ucc_module_t * ucc_module ,
18
- ucc_coll_req_h * req )
15
+ static inline ucc_status_t mca_scoll_ucc_alltoall_init (const void * sbuf , void * rbuf ,
16
+ int count , size_t elem_size ,
17
+ mca_scoll_ucc_module_t * ucc_module ,
18
+ ucc_coll_req_h * req )
19
19
{
20
+ ucc_datatype_t dt ;
21
+
22
+ if (elem_size == 8 ) {
23
+ dt = UCC_DT_INT64 ;
24
+ } else if (elem_size == 4 ) {
25
+ dt = UCC_DT_INT32 ;
26
+ } else {
27
+ dt = UCC_DT_INT8 ;
28
+ }
29
+
20
30
ucc_coll_args_t coll = {
21
31
.mask = 0 ,
22
32
.coll_type = UCC_COLL_TYPE_ALLTOALL ,
23
33
.src .info = {
24
34
.buffer = (void * )sbuf ,
25
- .count = count ,
26
- .datatype = UCC_DT_UINT8 ,
35
+ .count = count * ucc_module -> group -> proc_count ,
36
+ .datatype = dt ,
27
37
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
28
38
},
29
39
.dst .info = {
30
40
.buffer = rbuf ,
31
- .count = count ,
32
- .datatype = UCC_DT_UINT8 ,
41
+ .count = count * ucc_module -> group -> proc_count ,
42
+ .datatype = dt ,
33
43
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
34
44
},
35
45
};
@@ -56,14 +66,15 @@ int mca_scoll_ucc_alltoall(struct oshmem_group_t *group,
56
66
57
67
UCC_VERBOSE (3 , "running ucc alltoall" );
58
68
ucc_module = (mca_scoll_ucc_module_t * ) group -> g_scoll .scoll_alltoall_module ;
59
- count = nelems * element_size ;
69
+ count = nelems ;
60
70
61
71
/* Do nothing on zero-length request */
62
72
if (OPAL_UNLIKELY (!nelems )) {
63
73
return OSHMEM_SUCCESS ;
64
74
}
65
75
66
- SCOLL_UCC_CHECK (mca_scoll_ucc_alltoall_init (source , target , count , ucc_module , & req ));
76
+ SCOLL_UCC_CHECK (mca_scoll_ucc_alltoall_init (source , target , count ,
77
+ element_size , ucc_module , & req ));
67
78
SCOLL_UCC_CHECK (ucc_collective_post (req ));
68
79
SCOLL_UCC_CHECK (scoll_ucc_req_wait (req ));
69
80
return OSHMEM_SUCCESS ;
0 commit comments