Skip to content

Commit c05c23b

Browse files
authored
Merge pull request #10412 from MamziB/mamzi/get-accum
OSC/UCX: Fix data validation issue in get accumulate and intrinsic atomic ops
2 parents bab0bd7 + 0031008 commit c05c23b

File tree

1 file changed

+38
-20
lines changed

1 file changed

+38
-20
lines changed

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 38 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -378,6 +378,35 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
378378
return ret;
379379
}
380380

381+
/*
 * Decide whether the (datatype, op) pair may be handled by the atomic path.
 *
 * dt           datatype of the operand
 * op           reduction operation requested
 * dt_bytes     size of dt in bytes, as computed by the caller
 * remote_addr  target address, forwarded to the atomic-size check
 *
 * Returns true only when all of the following hold:
 *   - dt is a predefined datatype;
 *   - ompi_osc_base_is_atomic_size_supported(remote_addr, dt_bytes) accepts
 *     the address/size combination;
 *   - op is one of no_op, replace or sum;
 *   - the combination is not "sum on MPI_FLOAT, MPI_DOUBLE, MPI_LONG_DOUBLE
 *     or MPI_FLOAT_INT".
 * Returns false in every other case.
 */
static inline
382+
bool osc_is_atomic_dt_op_supported(
383+
struct ompi_datatype_t *dt,
384+
struct ompi_op_t *op,
385+
size_t dt_bytes,
386+
uint64_t remote_addr)
387+
{
388+
/* UCX atomics are only supported on 32 and 64 bit values; in addition,
 * only predefined datatypes are accepted here. */
389+
if (!ompi_datatype_is_predefined(dt) ||
390+
!ompi_osc_base_is_atomic_size_supported(remote_addr, dt_bytes)) {
391+
return false;
392+
}
393+
/* Hardware-based atomic add for floating point is not supported, so
 * no_op/replace/sum are accepted except for sum on the floating-point
 * datatypes listed below. */
394+
else if ((
395+
op == &ompi_mpi_op_no_op.op
396+
|| op == &ompi_mpi_op_replace.op
397+
|| op == &ompi_mpi_op_sum.op
398+
)
399+
&& !(
400+
op == &ompi_mpi_op_sum.op
401+
&& (dt == MPI_FLOAT || dt == MPI_DOUBLE
402+
|| dt == MPI_LONG_DOUBLE || dt == MPI_FLOAT_INT)
403+
)) {
404+
return true;
405+
}
406+
407+
/* Any other op, or sum on one of the excluded datatypes. */
return false;
408+
}
409+
381410
static inline
382411
bool use_atomic_op(
383412
ompi_osc_ucx_module_t *module,
@@ -388,25 +417,16 @@ bool use_atomic_op(
388417
int origin_count,
389418
int target_count)
390419
{
420+
size_t origin_dt_bytes;
391421

392-
if (module->acc_single_intrinsic &&
393-
ompi_datatype_is_predefined(origin_dt) &&
394-
origin_count == 1 &&
395-
(op == &ompi_mpi_op_replace.op ||
396-
op == &ompi_mpi_op_sum.op ||
397-
op == &ompi_mpi_op_no_op.op)) {
398-
size_t origin_dt_bytes;
399-
size_t target_dt_bytes;
422+
if (!module->acc_single_intrinsic || origin_count != 1 || target_count != 1
423+
|| origin_dt != target_dt) {
424+
return false;
425+
} else {
400426
ompi_datatype_type_size(origin_dt, &origin_dt_bytes);
401-
ompi_datatype_type_size(target_dt, &target_dt_bytes);
402-
/* UCX only supports 32 and 64-bit operands atm */
403-
if (ompi_osc_base_is_atomic_size_supported(remote_addr, origin_dt_bytes) &&
404-
origin_dt_bytes == target_dt_bytes && origin_count == target_count) {
405-
return true;
406-
}
427+
return osc_is_atomic_dt_op_supported(origin_dt, op, origin_dt_bytes,
428+
remote_addr);
407429
}
408-
409-
return false;
410430
}
411431

412432
static int do_atomic_op_intrinsic(
@@ -859,10 +879,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
859879
uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target);
860880
ompi_datatype_type_size(dt, &dt_bytes);
861881

862-
/* UCX atomics are only supported on 32 and 64 bit values */
863-
if (ompi_osc_base_is_atomic_size_supported(remote_addr, dt_bytes) &&
864-
(op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op ||
865-
op == &ompi_mpi_op_sum.op)) {
882+
if (osc_is_atomic_dt_op_supported(dt, op, dt_bytes, remote_addr)) {
866883
uint64_t value;
867884
ucp_atomic_fetch_op_t opcode;
868885
bool lock_acquired = false;
@@ -973,6 +990,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
973990
if (ret != OMPI_SUCCESS) {
974991
return ret;
975992
}
993+
temp_count *= target_count;
976994
}
977995
ompi_datatype_get_true_extent(temp_dt, &temp_lb, &temp_extent);
978996
temp_addr = free_addr = malloc(temp_extent * temp_count);

0 commit comments

Comments
 (0)