Skip to content

Commit 0a8942b

Browse files
authored
Merge pull request #9488 from wfaderhold21/topic/scoll_ucc_a2a
scoll/ucc: adjust count/dtype for alltoall
2 parents ff6b034 + 3d96daf commit 0a8942b

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

oshmem/mca/scoll/ucc/scoll_ucc_alltoall.c

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -12,24 +12,34 @@
1212

1313
#include <ucc/api/ucc.h>
1414

15-
static inline ucc_status_t mca_scoll_ucc_alltoall_init(const void *sbuf, void *rbuf,
16-
int count,
17-
mca_scoll_ucc_module_t * ucc_module,
18-
ucc_coll_req_h * req)
15+
static inline ucc_status_t mca_scoll_ucc_alltoall_init(const void *sbuf, void *rbuf,
16+
int count, size_t elem_size,
17+
mca_scoll_ucc_module_t *ucc_module,
18+
ucc_coll_req_h *req)
1919
{
20+
ucc_datatype_t dt;
21+
22+
if (elem_size == 8) {
23+
dt = UCC_DT_INT64;
24+
} else if (elem_size == 4) {
25+
dt = UCC_DT_INT32;
26+
} else {
27+
dt = UCC_DT_INT8;
28+
}
29+
2030
ucc_coll_args_t coll = {
2131
.mask = 0,
2232
.coll_type = UCC_COLL_TYPE_ALLTOALL,
2333
.src.info = {
2434
.buffer = (void *)sbuf,
25-
.count = count,
26-
.datatype = UCC_DT_UINT8,
35+
.count = count * ucc_module->group->proc_count,
36+
.datatype = dt,
2737
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
2838
},
2939
.dst.info = {
3040
.buffer = rbuf,
31-
.count = count,
32-
.datatype = UCC_DT_UINT8,
41+
.count = count * ucc_module->group->proc_count,
42+
.datatype = dt,
3343
.mem_type = UCC_MEMORY_TYPE_UNKNOWN
3444
},
3545
};
@@ -56,14 +66,15 @@ int mca_scoll_ucc_alltoall(struct oshmem_group_t *group,
5666

5767
UCC_VERBOSE(3, "running ucc alltoall");
5868
ucc_module = (mca_scoll_ucc_module_t *) group->g_scoll.scoll_alltoall_module;
59-
count = nelems * element_size;
69+
count = nelems;
6070

6171
/* Do nothing on zero-length request */
6272
if (OPAL_UNLIKELY(!nelems)) {
6373
return OSHMEM_SUCCESS;
6474
}
6575

66-
SCOLL_UCC_CHECK(mca_scoll_ucc_alltoall_init(source, target, count, ucc_module, &req));
76+
SCOLL_UCC_CHECK(mca_scoll_ucc_alltoall_init(source, target, count,
77+
element_size, ucc_module, &req));
6778
SCOLL_UCC_CHECK(ucc_collective_post(req));
6879
SCOLL_UCC_CHECK(scoll_ucc_req_wait(req));
6980
return OSHMEM_SUCCESS;

0 commit comments

Comments
 (0)