diff --git a/src/mpi.f90 b/src/mpi.f90 index 7f7be74..b7c6f02 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -335,11 +335,24 @@ subroutine MPI_Barrier_proc(comm, ierror) end subroutine MPI_Barrier_proc subroutine MPI_Comm_rank_proc(comm, rank, ierror) - use mpi_c_bindings, only: c_mpi_comm_rank + use iso_c_binding, only: c_int, c_ptr + use mpi_c_bindings, only: c_mpi_comm_rank, c_mpi_comm_f2c integer, intent(in) :: comm integer, intent(out) :: rank integer, optional, intent(out) :: ierror - call c_mpi_comm_rank(comm, rank, ierror) + type(c_ptr) :: c_comm + integer(c_int) :: local_ierr + + c_comm = c_mpi_comm_f2c(comm) + local_ierr = c_mpi_comm_rank(c_comm, rank) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Comm_rank failed with error code: ", local_ierr + end if + end if end subroutine subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index b827e2a..275d6a2 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -144,12 +144,12 @@ function c_mpi_barrier(comm) bind(C, name="MPI_Barrier") integer(c_int) :: c_mpi_barrier end function c_mpi_barrier - subroutine c_mpi_comm_rank(comm, rank, ierror) bind(C, name="mpi_comm_rank_wrapper") - use iso_c_binding, only: c_int - integer(c_int), intent(in) :: comm + function c_mpi_comm_rank(comm, rank) bind(C, name="MPI_Comm_rank") + use iso_c_binding, only: c_int, c_ptr + type(c_ptr), value :: comm integer(c_int), intent(out) :: rank - integer(c_int), optional, intent(out) :: ierror - end subroutine + integer(c_int) :: c_mpi_comm_rank + end function c_mpi_comm_rank subroutine c_mpi_comm_split_type(comm, split_type, key, info, newcomm, ierror) bind(C, name="mpi_comm_split_type_wrapper") use iso_c_binding, only: c_int diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 8f0d80e..a94d315 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -228,16 +228,6 @@ void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count, } } -// void mpi_barrier_wrapper(int *comm_f, int *ierror) { -// MPI_Comm comm = MPI_Comm_f2c(*comm_f); -// *ierror = MPI_Barrier(comm); -// } - -void mpi_comm_rank_wrapper(int *comm_f, int *rank, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - *ierror = MPI_Comm_rank(comm, rank); -} - void mpi_comm_split_type_wrapper(int *comm_f, int *split_type, int *key, int *info_f, int *newcomm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f);