diff --git a/src/mpi.f90 b/src/mpi.f90 index cbb4dab..0bae87d 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -61,7 +61,7 @@ module mpi interface MPI_Allreduce module procedure MPI_Allreduce_scalar module procedure MPI_Allreduce_1D_recv_proc - module procedure MPI_Allreduce_array_real + module procedure MPI_Allreduce_1D_real_proc module procedure MPI_Allreduce_array_int end interface @@ -462,15 +462,32 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com end if end subroutine MPI_Allreduce_1D_recv_proc - subroutine MPI_Allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) - use mpi_c_bindings, only: c_mpi_allreduce_array_real - ! Declare both send and recv as arrays: - real(8), dimension(:), intent(in) :: sendbuf - real(8), dimension(:), intent(out) :: recvbuf + subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + real(8), dimension(:), intent(in), target :: sendbuf + real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - call c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) - end subroutine MPI_Allreduce_array_real + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + integer(c_int) :: local_ierr + + sendbuf_ptr = c_loc(sendbuf) + recvbuf_ptr = c_loc(recvbuf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_op = c_mpi_op_f2c(op) + c_comm = c_mpi_comm_f2c(comm) + + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Allreduce_1D_real_proc subroutine MPI_Allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) use mpi_c_bindings, only: c_mpi_allreduce_array_int diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 1b7c0d1..b7e3ad9 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -141,15 +141,7 @@ function c_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm) & type(c_ptr), value :: datatype, op, comm integer(c_int) :: c_mpi_allreduce end function - - subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) & - bind(C, name="mpi_allreduce_wrapper_real") - use iso_c_binding, only: c_int, c_double - real(c_double), intent(in) :: sendbuf - real(c_double), dimension(*), intent(out) :: recvbuf - integer(c_int), intent(in) :: count, datatype, op, comm - integer(c_int), intent(out), optional :: ierror - end subroutine + subroutine c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) & bind(C, name="mpi_allreduce_wrapper_real") diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 11d6d08..f16cc95 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -61,28 +61,6 @@ void* get_c_mpi_inplace_from_fortran(double sendbuf) { return MPI_IN_PLACE; } -void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, - int *datatype_f, int *op_f, int *comm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - - // I'm a little doubtful: as how would it identify that this part - // is supposed to be for MPI_SUM? - MPI_Op op = MPI_Op_f2c(*op_f); - /* - hard-code values here: - 1. We've hard-coded op as "MPI_SUM" for now, as in POT3D codebase, it's always - used as MPI_SUM - 2. the first argument (i.e. sendbuf) as "MPI_IN_PLACE" for now as it's always - used as such in POT3D codebase - */ - if (*sendbuf == FORTRAN_MPI_IN_PLACE) { - *ierror = MPI_Allreduce(MPI_IN_PLACE , recvbuf, *count, datatype, MPI_SUM, comm); - } else { - *ierror = MPI_Allreduce(sendbuf , recvbuf, *count, datatype, MPI_SUM, comm); - } -} - void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f);