@@ -929,8 +929,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
929
929
struct ompi_datatype_t * origin_dt , void * result_addr , int result_count ,
930
930
struct ompi_datatype_t * result_dt , int target , ptrdiff_t target_disp ,
931
931
int target_count , struct ompi_datatype_t * target_dt , struct ompi_op_t
932
- * op , struct ompi_win_t * win ) {
933
-
932
+ * op , struct ompi_win_t * win , int acc_type ) {
934
933
ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
935
934
int ret = OMPI_SUCCESS ;
936
935
uint64_t remote_addr = (module -> addrs [target ]) + target_disp *
@@ -943,7 +942,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
943
942
return ret ;
944
943
}
945
944
946
- if (result_addr == NULL && op == & ompi_mpi_op_no_op .op ) {
945
+ if (ACCUMULATE == acc_type && op == & ompi_mpi_op_no_op .op ) {
947
946
/* This is an accumulate (not get-accumulate) operation, so return */
948
947
return ret ;
949
948
}
@@ -971,7 +970,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
971
970
972
971
CHECK_DYNAMIC_WIN (remote_addr , module , target , ret );
973
972
974
- if (result_addr != NULL ) {
973
+ if (GET_ACCUMULATE == acc_type ) {
975
974
/* This is a get-accumulate operation, so read the target data into result addr */
976
975
ret = ompi_osc_ucx_acc_rputget (result_addr , (int )result_count , result_dt , target ,
977
976
target_disp , target_count , target_dt , op , win , lock_acquired ,
@@ -985,7 +984,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
985
984
}
986
985
987
986
if (op == & ompi_mpi_op_replace .op ) {
988
- assert (result_addr == NULL );
987
+ assert (ACCUMULATE == acc_type );
989
988
/* No need for get, just use put and realize when to release the lock */
990
989
ret = ompi_osc_ucx_acc_rputget (NULL , 0 , NULL , target , target_disp ,
991
990
target_count , target_dt , op , win , lock_acquired , origin_addr ,
@@ -1018,7 +1017,7 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
1018
1017
ret = ompi_osc_ucx_acc_rputget (temp_addr , (int )temp_count , temp_dt , target ,
1019
1018
target_disp , target_count , target_dt , op , win , lock_acquired ,
1020
1019
origin_addr , origin_count , origin_dt , false, ACC_GET_STAGE_DATA ,
1021
- ( result_addr == NULL ) ? ACCUMULATE : GET_ACCUMULATE );
1020
+ acc_type );
1022
1021
if (ret != OMPI_SUCCESS ) {
1023
1022
return ret ;
1024
1023
}
@@ -1035,7 +1034,7 @@ int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count,
1035
1034
1036
1035
return ompi_osc_ucx_get_accumulate_nonblocking (origin_addr , origin_count ,
1037
1036
origin_dt , (void * )NULL , 0 , NULL , target , target_disp ,
1038
- target_count , target_dt , op , win );
1037
+ target_count , target_dt , op , win , ACCUMULATE );
1039
1038
}
1040
1039
1041
1040
static int
@@ -1372,7 +1371,7 @@ int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count,
1372
1371
1373
1372
return ompi_osc_ucx_get_accumulate_nonblocking (origin_addr , origin_count , origin_dt ,
1374
1373
result_addr , result_count , result_dt , target , target_disp ,
1375
- target_count , target_dt , op , win );
1374
+ target_count , target_dt , op , win , GET_ACCUMULATE );
1376
1375
}
1377
1376
1378
1377
int ompi_osc_ucx_rput (const void * origin_addr , int origin_count ,
0 commit comments