diff --git a/src/mpi.f90 b/src/mpi.f90 index f9aaae6..cbb4dab 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -60,7 +60,7 @@ module mpi interface MPI_Allreduce module procedure MPI_Allreduce_scalar - module procedure MPI_Allreduce_1d + module procedure MPI_Allreduce_1D_recv_proc module procedure MPI_Allreduce_array_real module procedure MPI_Allreduce_array_int end interface @@ -402,7 +402,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -420,7 +420,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier c_op = c_mpi_op_f2c(op) c_comm = c_mpi_comm_f2c(comm) - local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) if (present(ierror)) then ierror = local_ierr @@ -431,14 +431,36 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier end if end subroutine - subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) - use mpi_c_bindings, only: c_mpi_allreduce_1d - real(8), intent(in) :: sendbuf - real(8), dimension(:), intent(out) :: recvbuf + subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + real(8), intent(in), target :: sendbuf + real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - call c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) - end subroutine + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + integer(c_int) :: local_ierr + + if (sendbuf == MPI_IN_PLACE) then + sendbuf_ptr = c_mpi_in_place_f2c(sendbuf) + else + sendbuf_ptr = c_loc(sendbuf) + end if + recvbuf_ptr = c_loc(recvbuf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_op = c_mpi_op_f2c(op) + c_comm = c_mpi_comm_f2c(comm) + + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) use mpi_c_bindings, only: c_mpi_allreduce_array_real diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 0e6a8e5..1b7c0d1 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -132,14 +132,14 @@ function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, n integer(c_int) :: c_mpi_irecv end function - function c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm) & + function c_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm) & bind(C, name="MPI_Allreduce") use iso_c_binding, only: c_int, c_double, c_ptr type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: count type(c_ptr), value :: datatype, op, comm - integer(c_int) :: c_mpi_allreduce_scalar + integer(c_int) :: c_mpi_allreduce end function subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) &