diff --git a/src/mpi.f90 b/src/mpi.f90 index 3325471..7f7be74 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -314,11 +314,25 @@ function MPI_Wtime_proc() result(time) end function subroutine MPI_Barrier_proc(comm, ierror) - use mpi_c_bindings, only: c_mpi_barrier + use mpi_c_bindings, only: c_mpi_barrier, c_mpi_comm_f2c + use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out), optional :: ierror - call c_mpi_barrier(comm, ierror) - end subroutine + type(c_ptr) :: c_comm + integer(c_int) :: local_ierr + + ! Convert Fortran handle to C handle + c_comm = c_mpi_comm_f2c(comm) + local_ierr = c_mpi_barrier(c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Barrier failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Barrier_proc subroutine MPI_Comm_rank_proc(comm, rank, ierror) use mpi_c_bindings, only: c_mpi_comm_rank diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 895da37..b827e2a 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -132,12 +132,17 @@ function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime") use iso_c_binding, only: c_double real(c_double) :: time end function - - subroutine c_mpi_barrier(comm, ierror) bind(C, name="mpi_barrier_wrapper") - use iso_c_binding, only: c_int - integer(c_int), intent(in) :: comm - integer(c_int), intent(out), optional :: ierror - end subroutine + function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: comm_f + type(c_ptr) :: c_mpi_comm_f2c ! MPI_Comm as pointer + end function c_mpi_comm_f2c + + function c_mpi_barrier(comm) bind(C, name="MPI_Barrier") + use iso_c_binding, only: c_ptr, c_int + type(c_ptr), value :: comm ! MPI_Comm as pointer + integer(c_int) :: c_mpi_barrier + end function c_mpi_barrier subroutine c_mpi_comm_rank(comm, rank, ierror) bind(C, name="mpi_comm_rank_wrapper") use iso_c_binding, only: c_int diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 2eef39e..8f0d80e 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -228,10 +228,10 @@ void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count, } } -void mpi_barrier_wrapper(int *comm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - *ierror = MPI_Barrier(comm); -} +// void mpi_barrier_wrapper(int *comm_f, int *ierror) { +// MPI_Comm comm = MPI_Comm_f2c(*comm_f); +// *ierror = MPI_Barrier(comm); +// } void mpi_comm_rank_wrapper(int *comm_f, int *rank, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f);