@@ -378,6 +378,35 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module
378
378
return ret ;
379
379
}
380
380
381
+ static inline
382
+ bool osc_is_atomic_dt_op_supported (
383
+ struct ompi_datatype_t * dt ,
384
+ struct ompi_op_t * op ,
385
+ size_t dt_bytes ,
386
+ uint64_t remote_addr )
387
+ {
388
+ /* UCX atomics are only supported on 32 and 64 bit values */
389
+ if (!ompi_datatype_is_predefined (dt ) ||
390
+ !ompi_osc_base_is_atomic_size_supported (remote_addr , dt_bytes )) {
391
+ return false;
392
+ }
393
+ /* Hardware-based atomic add for floating point is not supported */
394
+ else if ((
395
+ op == & ompi_mpi_op_no_op .op
396
+ || op == & ompi_mpi_op_replace .op
397
+ || op == & ompi_mpi_op_sum .op
398
+ )
399
+ && !(
400
+ op == & ompi_mpi_op_sum .op
401
+ && (dt == MPI_FLOAT || dt == MPI_DOUBLE
402
+ || dt == MPI_LONG_DOUBLE || dt == MPI_FLOAT_INT )
403
+ )) {
404
+ return true;
405
+ }
406
+
407
+ return false;
408
+ }
409
+
381
410
static inline
382
411
bool use_atomic_op (
383
412
ompi_osc_ucx_module_t * module ,
@@ -388,25 +417,16 @@ bool use_atomic_op(
388
417
int origin_count ,
389
418
int target_count )
390
419
{
420
+ size_t origin_dt_bytes ;
391
421
392
- if (module -> acc_single_intrinsic &&
393
- ompi_datatype_is_predefined (origin_dt ) &&
394
- origin_count == 1 &&
395
- (op == & ompi_mpi_op_replace .op ||
396
- op == & ompi_mpi_op_sum .op ||
397
- op == & ompi_mpi_op_no_op .op )) {
398
- size_t origin_dt_bytes ;
399
- size_t target_dt_bytes ;
422
+ if (!module -> acc_single_intrinsic || origin_count != 1 || target_count != 1
423
+ || origin_dt != target_dt ) {
424
+ return false;
425
+ } else {
400
426
ompi_datatype_type_size (origin_dt , & origin_dt_bytes );
401
- ompi_datatype_type_size (target_dt , & target_dt_bytes );
402
- /* UCX only supports 32 and 64-bit operands atm */
403
- if (ompi_osc_base_is_atomic_size_supported (remote_addr , origin_dt_bytes ) &&
404
- origin_dt_bytes == target_dt_bytes && origin_count == target_count ) {
405
- return true;
406
- }
427
+ return osc_is_atomic_dt_op_supported (origin_dt , op , origin_dt_bytes ,
428
+ remote_addr );
407
429
}
408
-
409
- return false;
410
430
}
411
431
412
432
static int do_atomic_op_intrinsic (
@@ -859,10 +879,7 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr,
859
879
uint64_t remote_addr = (module -> addrs [target ]) + target_disp * OSC_UCX_GET_DISP (module , target );
860
880
ompi_datatype_type_size (dt , & dt_bytes );
861
881
862
- /* UCX atomics are only supported on 32 and 64 bit values */
863
- if (ompi_osc_base_is_atomic_size_supported (remote_addr , dt_bytes ) &&
864
- (op == & ompi_mpi_op_no_op .op || op == & ompi_mpi_op_replace .op ||
865
- op == & ompi_mpi_op_sum .op )) {
882
+ if (osc_is_atomic_dt_op_supported (dt , op , dt_bytes , remote_addr )) {
866
883
uint64_t value ;
867
884
ucp_atomic_fetch_op_t opcode ;
868
885
bool lock_acquired = false;
@@ -973,6 +990,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
973
990
if (ret != OMPI_SUCCESS ) {
974
991
return ret ;
975
992
}
993
+ temp_count *= target_count ;
976
994
}
977
995
ompi_datatype_get_true_extent (temp_dt , & temp_lb , & temp_extent );
978
996
temp_addr = free_addr = malloc (temp_extent * temp_count );
0 commit comments