Skip to content

Commit c74d333

Browse files
authored
Merge pull request #11118 from MamziB/mamzi/nonblocking-acc-dt
OSC/UCX: NB Accumulate with DT and fix for issue/11114
2 parents fa7face + 9dfa35f commit c74d333

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -339,13 +339,16 @@ static inline int get_dynamic_win_info(uint64_t remote_addr,
339339

340340
if (mem_rec->rkeys[target] != NULL) {
341341
ucp_rkey_destroy(mem_rec->rkeys[target]);
342+
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_unpacked_rkey_counts, -1);
342343
}
343344

344345
void *rkey_buffer = &temp_dynamic_wins[contain].mem_addr;
345346

346347
ret = ucp_ep_rkey_unpack(mem_rec->winfo->endpoints[target], rkey_buffer,
347348
&mem_rec->rkeys[target]);
348349

350+
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_unpacked_rkey_counts, 1);
351+
349352
opal_mutex_unlock(&mem_rec->winfo->mutex);
350353

351354
if (ret != UCS_OK) {
@@ -929,8 +932,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
929932
struct ompi_datatype_t *origin_dt, void *result_addr, int result_count,
930933
struct ompi_datatype_t *result_dt, int target, ptrdiff_t target_disp,
931934
int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t
932-
*op, struct ompi_win_t *win) {
933-
935+
*op, struct ompi_win_t *win, int acc_type) {
934936
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
935937
int ret = OMPI_SUCCESS;
936938
uint64_t remote_addr = (module->addrs[target]) + target_disp *
@@ -943,7 +945,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
943945
return ret;
944946
}
945947

946-
if (result_addr == NULL && op == &ompi_mpi_op_no_op.op) {
948+
if (ACCUMULATE == acc_type && op == &ompi_mpi_op_no_op.op) {
947949
/* This is an accumulate (not get-accumulate) operation, so return */
948950
return ret;
949951
}
@@ -971,7 +973,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
971973

972974
CHECK_DYNAMIC_WIN(remote_addr, module, target, ret);
973975

974-
if (result_addr != NULL) {
976+
if (GET_ACCUMULATE == acc_type) {
975977
/* This is a get-accumulate operation, so read the target data into result addr */
976978
ret = ompi_osc_ucx_acc_rputget(result_addr, (int)result_count, result_dt, target,
977979
target_disp, target_count, target_dt, op, win, lock_acquired,
@@ -985,7 +987,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
985987
}
986988

987989
if (op == &ompi_mpi_op_replace.op) {
988-
assert(result_addr == NULL);
990+
assert(ACCUMULATE == acc_type);
989991
/* No need for get, just use put and realize when to release the lock */
990992
ret = ompi_osc_ucx_acc_rputget(NULL, 0, NULL, target, target_disp,
991993
target_count, target_dt, op, win, lock_acquired, origin_addr,
@@ -1018,7 +1020,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
10181020
ret = ompi_osc_ucx_acc_rputget(temp_addr, (int)temp_count, temp_dt, target,
10191021
target_disp, target_count, target_dt, op, win, lock_acquired,
10201022
origin_addr, origin_count, origin_dt, false, ACC_GET_STAGE_DATA,
1021-
(result_addr == NULL) ? ACCUMULATE : GET_ACCUMULATE);
1023+
acc_type);
10221024
if (ret != OMPI_SUCCESS) {
10231025
return ret;
10241026
}
@@ -1035,7 +1037,7 @@ int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count,
10351037

10361038
return ompi_osc_ucx_get_accumulate_nonblocking(origin_addr, origin_count,
10371039
origin_dt, (void *)NULL, 0, NULL, target, target_disp,
1038-
target_count, target_dt, op, win);
1040+
target_count, target_dt, op, win, ACCUMULATE);
10391041
}
10401042

10411043
static int
@@ -1372,7 +1374,7 @@ int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count,
13721374

13731375
return ompi_osc_ucx_get_accumulate_nonblocking(origin_addr, origin_count, origin_dt,
13741376
result_addr, result_count, result_dt, target, target_disp,
1375-
target_count, target_dt, op, win);
1377+
target_count, target_dt, op, win, GET_ACCUMULATE);
13761378
}
13771379

13781380
int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,

opal/mca/common/ucx/common_ucx_wpool.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ extern opal_atomic_int64_t opal_common_ucx_unpacked_rkey_counts;
6868
opal_atomic_add_fetch_64(&(_var), (_val)); \
6969
} while(0);
7070
#else
71-
#define OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(&(_var), (_val));
71+
#define OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(_var, _val);
7272
#endif
7373

7474
/* Worker Pool Context (wpctx) is an object that is comprised of a set of UCP

0 commit comments

Comments
 (0)