From 3873a4ef8ad32885ae8aabdf4f41792a20983c9d Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Wed, 2 Apr 2025 15:50:17 +0530 Subject: [PATCH 1/2] remove C wrapper of MPI_Allreduce_scalar --- src/mpi.f90 | 26 ++++++++++++++++++++++---- src/mpi_c_bindings.f90 | 17 +++++++++-------- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 0f02606..918a98c 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -399,12 +399,30 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr end subroutine subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) - use mpi_c_bindings, only: c_mpi_allreduce_scalar - real(8), intent(in) :: sendbuf - real(8), intent(out) :: recvbuf + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c + real(8), intent(in), target :: sendbuf + real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - call c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + integer(c_int) :: local_ierr + + sendbuf_ptr = c_loc(sendbuf) + recvbuf_ptr = c_loc(recvbuf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_op = c_mpi_op_f2c(op) + c_comm = c_mpi_comm_f2c(comm) + + local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Allreduce_scalar failed with error code: ", local_ierr + end if + end if end subroutine subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index ff36210..8a7d9c2 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -120,14 +120,15 @@ function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, n integer(c_int) :: c_mpi_irecv end function - subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) & - bind(C, name="mpi_allreduce_wrapper_real") - use iso_c_binding, only: c_int, c_double - real(c_double), intent(in) :: sendbuf - real(c_double), intent(out) :: recvbuf - integer(c_int), intent(in) :: count, datatype, op, comm - integer(c_int), intent(out), optional :: ierror - end subroutine + function c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm) & + bind(C, name="MPI_Allreduce") + use iso_c_binding, only: c_int, c_double, c_ptr + type(c_ptr), value :: sendbuf + type(c_ptr), value :: recvbuf + integer(c_int), value :: count + type(c_ptr), value :: datatype, op, comm + integer(c_int) :: c_mpi_allreduce_scalar + end function subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) & bind(C, name="mpi_allreduce_wrapper_real") From 44e618fdb444a73913fcc8845ebbb17b1c3a5700 Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Wed, 2 Apr 2025 18:55:14 +0530 Subject: [PATCH 2/2] Handle MPI_IN_PLACE --- src/mpi.f90 | 8 ++++++-- src/mpi_c_bindings.f90 | 6 ++++++ src/mpi_wrapper.c | 4 ++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 918a98c..f4c6a62 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -400,7 +400,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c + use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -408,7 +408,11 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm integer(c_int) :: local_ierr - sendbuf_ptr = c_loc(sendbuf) + if (sendbuf == MPI_IN_PLACE) then + sendbuf_ptr = c_mpi_in_place_f2c(sendbuf) + else + sendbuf_ptr = c_loc(sendbuf) + end if recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 8a7d9c2..e13303a 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -38,6 +38,12 @@ function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") type(c_ptr) :: c_mpi_info_f2c end function c_mpi_info_f2c + function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran") + use iso_c_binding, only: c_double, c_ptr + real(c_double), value :: in_place_f + type(c_ptr) :: c_mpi_in_place_f2c + end function c_mpi_in_place_f2c + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 8ebe295..5137a66 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -57,6 +57,10 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { } } +void* get_c_mpi_inplace_from_fortran(double sendbuf) { + return MPI_IN_PLACE; +} + void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f);