diff --git a/src/mpi.f90 b/src/mpi.f90 index 7f7be74..cee58d9 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -178,17 +178,21 @@ subroutine MPI_Finalize_proc(ierr) end if end subroutine - subroutine MPI_Comm_size_proc(comm, size, ierr) - use mpi_c_bindings, only: c_mpi_comm_size + subroutine MPI_Comm_size_proc(comm, size, ierror) + use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c + use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out) :: size - integer, optional, intent(out) :: ierr + integer, optional, intent(out) :: ierror integer :: local_ierr - if (present(ierr)) then - call c_mpi_comm_size(comm, size, ierr) + type(c_ptr) :: c_comm + + c_comm = c_mpi_comm_f2c(comm) + local_ierr = c_mpi_comm_size(c_comm, size) + if (present(ierror)) then + ierror = local_ierr else - call c_mpi_comm_size(comm, size, local_ierr) - if (local_ierr /= 0) then + if (local_ierr /= MPI_SUCCESS) then print *, "MPI_Comm_size failed with error code: ", local_ierr end if end if diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index b827e2a..19eb336 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -2,6 +2,12 @@ module mpi_c_bindings implicit none interface + function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: comm_f + type(c_ptr) :: c_mpi_comm_f2c ! MPI_Comm as pointer + end function c_mpi_comm_f2c + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified @@ -25,12 +31,12 @@ integer(c_int) function c_mpi_finalize() bind(C, name="MPI_Finalize") use iso_c_binding, only : c_int end function c_mpi_finalize - subroutine c_mpi_comm_size(comm, size, ierr) bind(C, name="mpi_comm_size_wrapper") - use iso_c_binding, only: c_int - integer(c_int), intent(in) :: comm + function c_mpi_comm_size(comm, size) bind(C, name="MPI_Comm_size") + use iso_c_binding, only: c_int, c_ptr + type(c_ptr), value :: comm integer(c_int), intent(out) :: size - integer(c_int), intent(out) :: ierr - end subroutine c_mpi_comm_size + integer(c_int) :: c_mpi_comm_size + end function c_mpi_comm_size subroutine c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_int_wrapper") use iso_c_binding, only: c_int @@ -132,11 +138,6 @@ function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime") use iso_c_binding, only: c_double real(c_double) :: time end function - function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran") - use iso_c_binding, only: c_int, c_ptr - integer(c_int), value :: comm_f - type(c_ptr) :: c_mpi_comm_f2c ! MPI_Comm as pointer - end function c_mpi_comm_f2c function c_mpi_barrier(comm) bind(C, name="MPI_Barrier") use iso_c_binding, only: c_ptr, c_int diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 8f0d80e..15084eb 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -32,11 +32,6 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { } } -void mpi_comm_size_wrapper(int *comm_f, int *size, int *ierr) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - *ierr = MPI_Comm_size(comm, size); -} - void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); MPI_Datatype datatype;