From 6ebb7136493fb653fdecbf3aa6bf551000cdbd9e Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Wed, 2 Apr 2025 23:55:17 +0530 Subject: [PATCH] Remove C-wrappr for MPI_Allreduce_int_array --- src/mpi.f90 | 33 +++++++++++++++++++++++++-------- src/mpi_c_bindings.f90 | 19 ------------------- src/mpi_wrapper.c | 15 --------------- 3 files changed, 25 insertions(+), 42 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 0bae87d..df0d56f 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -62,7 +62,7 @@ module mpi module procedure MPI_Allreduce_scalar module procedure MPI_Allreduce_1D_recv_proc module procedure MPI_Allreduce_1D_real_proc - module procedure MPI_Allreduce_array_int + module procedure MPI_Allreduce_1D_int_proc end interface interface MPI_Wtime @@ -489,15 +489,32 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com end if end subroutine MPI_Allreduce_1D_real_proc - subroutine MPI_Allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) - use mpi_c_bindings, only: c_mpi_allreduce_array_int - ! Declare both send and recv as arrays: - integer, dimension(:), intent(in) :: sendbuf - integer, dimension(:), intent(out) :: recvbuf + subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + integer, dimension(:), intent(in), target :: sendbuf + integer, dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - call c_mpi_allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) - end subroutine MPI_Allreduce_array_int + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + integer(c_int) :: local_ierr + + sendbuf_ptr = c_loc(sendbuf) + recvbuf_ptr = c_loc(recvbuf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_op = c_mpi_op_f2c(op) + c_comm = c_mpi_comm_f2c(comm) + + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Allreduce_1D_int_proc function MPI_Wtime_proc() result(time) use mpi_c_bindings, only: c_mpi_wtime diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index b7e3ad9..1fc55c2 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -141,25 +141,6 @@ function c_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm) & type(c_ptr), value :: datatype, op, comm integer(c_int) :: c_mpi_allreduce end function - - - subroutine c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) & - bind(C, name="mpi_allreduce_wrapper_real") - use iso_c_binding, only: c_int, c_double - real(c_double), dimension(*), intent(in) :: sendbuf - real(c_double), dimension(*), intent(out) :: recvbuf - integer(c_int), intent(in) :: count, datatype, op, comm - integer(c_int), intent(out), optional :: ierror - end subroutine c_mpi_allreduce_array_real - - subroutine c_mpi_allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) & - bind(C, name="mpi_allreduce_wrapper_int") - use iso_c_binding, only: c_int, c_double - integer(c_int), dimension(*), intent(in) :: sendbuf - integer(c_int), dimension(*), intent(out) :: recvbuf - integer(c_int), intent(in) :: count, datatype, op, comm - integer(c_int), intent(out), optional :: ierror - end subroutine c_mpi_allreduce_array_int function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime") use iso_c_binding, only: c_double diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index f16cc95..e8fbe87 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -61,21 +61,6 @@ void* get_c_mpi_inplace_from_fortran(double sendbuf) { return MPI_IN_PLACE; } -void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count, - int *datatype_f, int *op_f, int *comm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - datatype = MPI_INT; - - MPI_Op op = MPI_Op_f2c(*op_f); - - if (*sendbuf == FORTRAN_MPI_IN_PLACE) { - *ierror = MPI_Allreduce(MPI_IN_PLACE , recvbuf, *count, datatype, MPI_SUM, comm); - } else { - *ierror = MPI_Allreduce(sendbuf , recvbuf, *count, datatype, MPI_SUM, comm); - } -} - void mpi_waitall_wrapper(int *count, int *array_of_requests_f, int *array_of_statuses_f, int *ierror) { MPI_Request *array_of_requests;