From bb4b03afa9070c76ffc3093ec1351500ac7fc83f Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Wed, 2 Apr 2025 19:20:28 +0530 Subject: [PATCH 1/2] Call MPI_AllReduce instead of C-wrapper for the 1D-recvbuf --- src/mpi.f90 | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index f9aaae6..e597975 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -60,7 +60,7 @@ module mpi interface MPI_Allreduce module procedure MPI_Allreduce_scalar - module procedure MPI_Allreduce_1d + module procedure MPI_Allreduce_1D_recv_proc module procedure MPI_Allreduce_array_real module procedure MPI_Allreduce_array_int end interface @@ -431,14 +431,36 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier end if end subroutine - subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) - use mpi_c_bindings, only: c_mpi_allreduce_1d - real(8), intent(in) :: sendbuf - real(8), dimension(:), intent(out) :: recvbuf + subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + real(8), intent(in), target :: sendbuf + real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - call c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) - end subroutine + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + integer(c_int) :: local_ierr + + if (sendbuf == MPI_IN_PLACE) then + sendbuf_ptr = c_mpi_in_place_f2c(sendbuf) + else + sendbuf_ptr = c_loc(sendbuf) + end if + recvbuf_ptr = c_loc(recvbuf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_op = c_mpi_op_f2c(op) + c_comm = c_mpi_comm_f2c(comm) + + local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Allreduce_scalar failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) use mpi_c_bindings, only: c_mpi_allreduce_array_real From 6cfbc3df9081b9faaea8d0a52f29cb97a2e808ff Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Wed, 2 Apr 2025 19:40:57 +0530 Subject: [PATCH 2/2] Use proper name to avoid ambiguity --- src/mpi.f90 | 10 +++++----- src/mpi_c_bindings.f90 | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index e597975..cbb4dab 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -402,7 +402,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -420,7 +420,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier c_op = c_mpi_op_f2c(op) c_comm = c_mpi_comm_f2c(comm) - local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) if (present(ierror)) then ierror = local_ierr @@ -433,7 +433,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -451,13 +451,13 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com c_op = c_mpi_op_f2c(op) c_comm = c_mpi_comm_f2c(comm) - local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) if (present(ierror)) then ierror = local_ierr else if (local_ierr /= MPI_SUCCESS) then - print *, "MPI_Allreduce_scalar failed with error code: ", local_ierr + print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr end if end if end subroutine MPI_Allreduce_1D_recv_proc diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 0e6a8e5..1b7c0d1 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -132,14 +132,14 @@ function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, n integer(c_int) :: c_mpi_irecv end function - function c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm) & + function c_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm) & bind(C, name="MPI_Allreduce") use iso_c_binding, only: c_int, c_double, c_ptr type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: count type(c_ptr), value :: datatype, op, comm - integer(c_int) :: c_mpi_allreduce_scalar + integer(c_int) :: c_mpi_allreduce end function subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) &