diff --git a/src/mpi.f90 b/src/mpi.f90 index 73bf72e..f64d0f2 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -21,6 +21,7 @@ module mpi real(8), parameter :: MPI_IN_PLACE = -1002 integer, parameter :: MPI_SUM = -2300 integer, parameter :: MPI_MAX = -2301 + integer, parameter :: MPI_LOR = -2302 integer, parameter :: MPI_INFO_NULL = -2000 integer, parameter :: MPI_STATUS_SIZE = 5 integer :: MPI_STATUS_IGNORE = 0 @@ -99,6 +100,7 @@ module mpi module procedure MPI_Allreduce_1D_recv_proc module procedure MPI_Allreduce_1D_real_proc module procedure MPI_Allreduce_1D_int_proc + module procedure MPI_Allreduce_scalar_logical_proc end interface interface MPI_Gatherv @@ -168,14 +170,16 @@ module mpi contains integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op) - use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max + use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max, c_mpi_lor integer, intent(in) :: op_f if (op_f == MPI_SUM) then c_op = c_mpi_sum else if (op_f == MPI_MAX) then c_op = c_MPI_MAX + else if (op_f == MPI_LOR) then + c_op = c_mpi_lor else - c_op = c_mpi_op_f2c(op_f) + c_op = c_mpi_op_f2c(op_f) ! For other operations, use the C binding end if end function @@ -795,6 +799,35 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm end if end subroutine MPI_Allreduce_1D_int_proc + subroutine MPI_Allreduce_scalar_logical_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_comm_f2c + logical, intent(in), target :: sendbuf + logical, intent(out), target :: recvbuf + integer, intent(in) :: count, datatype, op, comm + integer, intent(out), optional :: ierror + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm + integer(c_int) :: local_ierr + + sendbuf_ptr = c_loc(sendbuf) + recvbuf_ptr = c_loc(recvbuf) + c_datatype = handle_mpi_datatype_f2c(datatype) + c_op = handle_mpi_op_f2c(op) + + c_comm = handle_mpi_comm_f2c(comm) + + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Allreduce_scalar_logical_proc + function MPI_Wtime_proc() result(time) use mpi_c_bindings, only: c_mpi_wtime real(8) :: time diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 05b4f8e..fb90a02 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -12,17 +12,21 @@ module mpi_c_bindings type(c_ptr), bind(C, name="c_MPI_STATUSES_IGNORE") :: c_mpi_statuses_ignore type(c_ptr), bind(C, name="c_MPI_IN_PLACE") :: c_mpi_in_place integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INFO_NULL") :: c_mpi_info_null + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_DOUBLE") :: c_mpi_double integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_FLOAT") :: c_mpi_float integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_REAL") :: c_mpi_real integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int - integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world - integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null - integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum - integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_CHARACTER") :: c_mpi_character + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOR") :: c_mpi_lor + + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null + interface function c_mpi_comm_f2c(comm_f) bind(C, name="MPI_Comm_f2c") diff --git a/src/mpi_constants.c b/src/mpi_constants.c index fde0534..bb39969 100644 --- a/src/mpi_constants.c +++ b/src/mpi_constants.c @@ -26,6 +26,8 @@ MPI_Op c_MPI_SUM = MPI_SUM; MPI_Op c_MPI_MAX = MPI_MAX; +MPI_Op c_MPI_LOR = MPI_LOR; + // Communicators Declarations MPI_Comm c_MPI_COMM_NULL = MPI_COMM_NULL; diff --git a/tests/allreduce_lor.f90 b/tests/allreduce_lor.f90 new file mode 100644 index 0000000..e3923ac --- /dev/null +++ b/tests/allreduce_lor.f90 @@ -0,0 +1,26 @@ +program mre_mpi_lor_allreduce + use mpi + implicit none + + integer :: ierr, rank, size + logical :: local_flag, global_flag + + call MPI_INIT(ierr) + if (ierr /= MPI_SUCCESS) error stop "MPI_INIT failed" + + call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr) + call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr) + + ! Initialize the local flag: True if this is the 0th rank, False otherwise + local_flag = (rank == 0) + + ! Perform logical OR reduction across all processes + call MPI_ALLREDUCE(local_flag, global_flag, 1, MPI_LOGICAL, MPI_LOR, MPI_COMM_WORLD, ierr) + if (global_flag .neqv. .true.) error stop "MPI_ALLREDUCE failed" + + print *, 'Rank', rank, ': global_flag =', global_flag + + call MPI_FINALIZE(ierr) + if (ierr /= MPI_SUCCESS) error stop "MPI_FINALIZE failed" + +end program mre_mpi_lor_allreduce \ No newline at end of file