@@ -339,13 +339,16 @@ static inline int get_dynamic_win_info(uint64_t remote_addr,
339
339
340
340
if (mem_rec -> rkeys [target ] != NULL ) {
341
341
ucp_rkey_destroy (mem_rec -> rkeys [target ]);
342
+ OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD (opal_common_ucx_unpacked_rkey_counts , -1 );
342
343
}
343
344
344
345
void * rkey_buffer = & temp_dynamic_wins [contain ].mem_addr ;
345
346
346
347
ret = ucp_ep_rkey_unpack (mem_rec -> winfo -> endpoints [target ], rkey_buffer ,
347
348
& mem_rec -> rkeys [target ]);
348
349
350
+ OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD (opal_common_ucx_unpacked_rkey_counts , 1 );
351
+
349
352
opal_mutex_unlock (& mem_rec -> winfo -> mutex );
350
353
351
354
if (ret != UCS_OK ) {
@@ -929,8 +932,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
929
932
struct ompi_datatype_t * origin_dt , void * result_addr , int result_count ,
930
933
struct ompi_datatype_t * result_dt , int target , ptrdiff_t target_disp ,
931
934
int target_count , struct ompi_datatype_t * target_dt , struct ompi_op_t
932
- * op , struct ompi_win_t * win ) {
933
-
935
+ * op , struct ompi_win_t * win , int acc_type ) {
934
936
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
935
937
int ret = OMPI_SUCCESS ;
936
938
uint64_t remote_addr = (module -> addrs [target ]) + target_disp *
@@ -943,7 +945,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
943
945
return ret ;
944
946
}
945
947
946
- if (result_addr == NULL && op == & ompi_mpi_op_no_op .op ) {
948
+ if (ACCUMULATE == acc_type && op == & ompi_mpi_op_no_op .op ) {
947
949
/* This is an accumulate (not get-accumulate) operation, so return */
948
950
return ret ;
949
951
}
@@ -971,7 +973,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
971
973
972
974
CHECK_DYNAMIC_WIN (remote_addr , module , target , ret );
973
975
974
- if (result_addr != NULL ) {
976
+ if (GET_ACCUMULATE == acc_type ) {
975
977
/* This is a get-accumulate operation, so read the target data into result addr */
976
978
ret = ompi_osc_ucx_acc_rputget (result_addr , (int )result_count , result_dt , target ,
977
979
target_disp , target_count , target_dt , op , win , lock_acquired ,
@@ -985,7 +987,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
985
987
}
986
988
987
989
if (op == & ompi_mpi_op_replace .op ) {
988
- assert (result_addr == NULL );
990
+ assert (ACCUMULATE == acc_type );
989
991
/* No need for get, just use put and realize when to release the lock */
990
992
ret = ompi_osc_ucx_acc_rputget (NULL , 0 , NULL , target , target_disp ,
991
993
target_count , target_dt , op , win , lock_acquired , origin_addr ,
@@ -1018,7 +1020,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
1018
1020
ret = ompi_osc_ucx_acc_rputget (temp_addr , (int )temp_count , temp_dt , target ,
1019
1021
target_disp , target_count , target_dt , op , win , lock_acquired ,
1020
1022
origin_addr , origin_count , origin_dt , false, ACC_GET_STAGE_DATA ,
1021
- ( result_addr == NULL ) ? ACCUMULATE : GET_ACCUMULATE );
1023
+ acc_type );
1022
1024
if (ret != OMPI_SUCCESS ) {
1023
1025
return ret ;
1024
1026
}
@@ -1035,7 +1037,7 @@ int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count,
1035
1037
1036
1038
return ompi_osc_ucx_get_accumulate_nonblocking (origin_addr , origin_count ,
1037
1039
origin_dt , (void * )NULL , 0 , NULL , target , target_disp ,
1038
- target_count , target_dt , op , win );
1040
+ target_count , target_dt , op , win , ACCUMULATE );
1039
1041
}
1040
1042
1041
1043
static int
@@ -1372,7 +1374,7 @@ int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count,
1372
1374
1373
1375
return ompi_osc_ucx_get_accumulate_nonblocking (origin_addr , origin_count , origin_dt ,
1374
1376
result_addr , result_count , result_dt , target , target_disp ,
1375
- target_count , target_dt , op , win );
1377
+ target_count , target_dt , op , win , GET_ACCUMULATE );
1376
1378
}
1377
1379
1378
1380
int ompi_osc_ucx_rput (const void * origin_addr , int origin_count ,
0 commit comments