Skip to content

Commit bc31f5f

Browse files
Mamzi Bayatpour  mbayatpour@nvidia.com ()janjust
authored andcommitted
OSC/UCX: Allow nonblocking get_accumulate to be called with results_addr
equal to NULL (can happpen with noncontigous dt) Signed-off-by: Mamzi Bayatpour <mbayatpour@nvidia.com> Co-authored-by: Tomislav Janjusic <tomislavj@nvidia.com>
1 parent df4235a commit bc31f5f

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -929,8 +929,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
929929
struct ompi_datatype_t *origin_dt, void *result_addr, int result_count,
930930
struct ompi_datatype_t *result_dt, int target, ptrdiff_t target_disp,
931931
int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t
932-
*op, struct ompi_win_t *win) {
933-
932+
*op, struct ompi_win_t *win, int acc_type) {
934933
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
935934
int ret = OMPI_SUCCESS;
936935
uint64_t remote_addr = (module->addrs[target]) + target_disp *
@@ -943,7 +942,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
943942
return ret;
944943
}
945944

946-
if (result_addr == NULL && op == &ompi_mpi_op_no_op.op) {
945+
if (ACCUMULATE == acc_type && op == &ompi_mpi_op_no_op.op) {
947946
/* This is an accumulate (not get-accumulate) operation, so return */
948947
return ret;
949948
}
@@ -971,7 +970,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
971970

972971
CHECK_DYNAMIC_WIN(remote_addr, module, target, ret);
973972

974-
if (result_addr != NULL) {
973+
if (GET_ACCUMULATE == acc_type) {
975974
/* This is a get-accumulate operation, so read the target data into result addr */
976975
ret = ompi_osc_ucx_acc_rputget(result_addr, (int)result_count, result_dt, target,
977976
target_disp, target_count, target_dt, op, win, lock_acquired,
@@ -985,7 +984,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
985984
}
986985

987986
if (op == &ompi_mpi_op_replace.op) {
988-
assert(result_addr == NULL);
987+
assert(ACCUMULATE == acc_type);
989988
/* No need for get, just use put and realize when to release the lock */
990989
ret = ompi_osc_ucx_acc_rputget(NULL, 0, NULL, target, target_disp,
991990
target_count, target_dt, op, win, lock_acquired, origin_addr,
@@ -1018,7 +1017,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
10181017
ret = ompi_osc_ucx_acc_rputget(temp_addr, (int)temp_count, temp_dt, target,
10191018
target_disp, target_count, target_dt, op, win, lock_acquired,
10201019
origin_addr, origin_count, origin_dt, false, ACC_GET_STAGE_DATA,
1021-
(result_addr == NULL) ? ACCUMULATE : GET_ACCUMULATE);
1020+
acc_type);
10221021
if (ret != OMPI_SUCCESS) {
10231022
return ret;
10241023
}
@@ -1035,7 +1034,7 @@ int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count,
10351034

10361035
return ompi_osc_ucx_get_accumulate_nonblocking(origin_addr, origin_count,
10371036
origin_dt, (void *)NULL, 0, NULL, target, target_disp,
1038-
target_count, target_dt, op, win);
1037+
target_count, target_dt, op, win, ACCUMULATE);
10391038
}
10401039

10411040
static int
@@ -1372,7 +1371,7 @@ int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count,
13721371

13731372
return ompi_osc_ucx_get_accumulate_nonblocking(origin_addr, origin_count, origin_dt,
13741373
result_addr, result_count, result_dt, target, target_disp,
1375-
target_count, target_dt, op, win);
1374+
target_count, target_dt, op, win, GET_ACCUMULATE);
13761375
}
13771376

13781377
int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,

0 commit comments

Comments
 (0)