diff --git a/src/mpi.f90 b/src/mpi.f90 index b1eba4e..cebaa35 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -264,14 +264,31 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) end subroutine subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierror) - use mpi_c_bindings, only: c_mpi_irecv + use iso_c_binding, only: c_int, c_ptr + use mpi_c_bindings, only: c_mpi_irecv, c_mpi_comm_f2c, get_c_datatype_from_fortran, c_mpi_request_c2f real(8), dimension(:,:) :: buf integer, intent(in) :: count, source, tag integer, intent(in) :: datatype integer, intent(in) :: comm integer, intent(out) :: request integer, optional, intent(out) :: ierror - call c_mpi_irecv(buf, count, datatype, source, tag, comm, request, ierror) + type(c_ptr) :: c_comm + integer(c_int) :: local_ierr + type(c_ptr) :: c_datatype + type(c_ptr) :: c_request + + c_comm = c_mpi_comm_f2c(comm) + c_datatype = get_c_datatype_from_fortran(datatype) + local_ierr = c_mpi_irecv(buf, count, c_datatype, source, tag, c_comm, c_request) + request = c_mpi_request_c2f(c_request) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Irecv failed with error code: ", local_ierr + end if + end if end subroutine subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 4e5f8db..4b86dbd 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -11,7 +11,19 @@ end function c_mpi_comm_f2c function c_mpi_comm_c2f(comm_c) bind(C, name="MPI_Comm_c2f") use iso_c_binding, only: c_int, c_ptr type(c_ptr), value :: comm_c - integer :: c_mpi_comm_c2f + integer(c_int) :: c_mpi_comm_c2f + end function + + function get_c_datatype_from_fortran(datatype) bind(C, name="get_c_datatype_from_fortran") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: datatype + type(c_ptr) :: get_c_datatype_from_fortran + end function get_c_datatype_from_fortran + + function c_mpi_request_c2f(request) bind(C, name="MPI_Request_c2f") + use iso_c_binding, only: c_int, c_ptr + type(c_ptr), value :: request + integer(c_int) :: c_mpi_request_c2f end function function c_mpi_init(argc, argv) bind(C, name="MPI_Init") @@ -94,15 +106,15 @@ subroutine c_mpi_isend(buf, count, datatype, dest, tag, comm, request, ierror) b integer(c_int), optional, intent(out) :: ierror end subroutine - subroutine c_mpi_irecv(buf, count, datatype, source, tag, comm, request, ierror) bind(C, name="mpi_irecv_wrapper") - use iso_c_binding, only: c_int, c_double - real(c_double), dimension(*) :: buf - integer(c_int), intent(in) :: count, source, tag - integer(c_int), intent(in) :: datatype - integer(c_int), intent(in) :: comm - integer(c_int), intent(out) :: request - integer(c_int), optional, intent(out) :: ierror - end subroutine + function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, name="MPI_Irecv") + use iso_c_binding, only: c_int, c_double, c_ptr + real(c_double), dimension(*), intent(out) :: buf + integer(c_int), value :: count, source, tag + type(c_ptr), value :: datatype + type(c_ptr), value :: comm + type(c_ptr), intent(out) :: request + integer(c_int) :: c_mpi_irecv + end function subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) & bind(C, name="mpi_allreduce_wrapper_real") diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index cb4b91c..0d1a978 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -102,17 +102,6 @@ void mpi_isend_wrapper(const double *buf, int *count, int *datatype_f, *request_f = MPI_Request_c2f(request); } -void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f, - int *source, int *tag, int *comm_f, int *request_f, - int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - - MPI_Request request; - *ierror = MPI_Irecv(buf, *count, datatype, *source, *tag, comm, &request); - *request_f = MPI_Request_c2f(request); -} - void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); @@ -200,7 +189,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f, void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest, int *tag, int *comm_f, int *ierror) { - MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f); + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm);