diff --git a/src/mpi.f90 b/src/mpi.f90 index 75fef4d..48d1e13 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -312,25 +312,61 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, end subroutine subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) - use mpi_c_bindings, only: c_mpi_isend - real(8), dimension(:, :), intent(in) :: buf + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f + real(8), dimension(:, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype integer, intent(in) :: comm integer, intent(out) :: request integer, optional, intent(out) :: ierror - call c_mpi_isend(buf, count, datatype, dest, tag, comm, request, ierror) + type(c_ptr) :: buf_ptr + type(c_ptr) :: c_datatype, c_comm, c_request + integer(c_int) :: local_ierr + + buf_ptr = c_loc(buf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_comm = c_mpi_comm_f2c(comm) + local_ierr = c_mpi_isend(buf_ptr, count, c_datatype, dest, tag, c_comm, c_request) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Isend_2d failed with error code: ", local_ierr + end if + end if + + request = c_mpi_request_c2f(c_request) end subroutine subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) - use mpi_c_bindings, only: c_mpi_isend - real(8), dimension(:, :, :), intent(in) :: buf + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f + real(8), dimension(:, :, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype integer, intent(in) :: comm integer, intent(out) :: request integer, optional, intent(out) :: ierror - call c_mpi_isend(buf, count, datatype, dest, tag, comm, request, ierror) + type(c_ptr) :: buf_ptr + type(c_ptr) :: c_datatype, c_comm, c_request + integer(c_int) :: local_ierr + + buf_ptr = c_loc(buf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_comm = c_mpi_comm_f2c(comm) + local_ierr = c_mpi_isend(buf_ptr, count, c_datatype, dest, tag, c_comm, c_request) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Isend_2d failed with error code: ", local_ierr + end if + end if + + request = c_mpi_request_c2f(c_request) end subroutine subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 36bdd37..ff36210 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -100,15 +100,15 @@ function c_mpi_allgather_real(sendbuf, sendcount, sendtype, recvbuf, & integer(c_int) :: c_mpi_allgather_real end function - subroutine c_mpi_isend(buf, count, datatype, dest, tag, comm, request, ierror) bind(C, name="mpi_isend_wrapper") - use iso_c_binding, only: c_int, c_double - real(c_double), dimension(*), intent(in) :: buf - integer(c_int), intent(in) :: count, dest, tag - integer(c_int), intent(in) :: datatype - integer(c_int), intent(in) :: comm - integer(c_int), intent(out) :: request - integer(c_int), optional, intent(out) :: ierror - end subroutine + function c_mpi_isend(buf, count, datatype, dest, tag, comm, request) bind(C, name="MPI_Isend") + use iso_c_binding, only: c_int, c_double, c_ptr + type(c_ptr), value :: buf + integer(c_int), value :: count, dest, tag + type(c_ptr), value :: datatype + type(c_ptr), value :: comm + type(c_ptr), intent(out) :: request + integer(c_int) :: c_mpi_isend + end function function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, name="MPI_Irecv") use iso_c_binding, only: c_int, c_double, c_ptr diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index ee8bfec..785192d 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -55,17 +55,6 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { } } -void mpi_isend_wrapper(const double *buf, int *count, int *datatype_f, - int *dest, int *tag, int *comm_f, int *request_f, - int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - - MPI_Request request; - *ierror = MPI_Isend(buf, *count, datatype, *dest, *tag, comm, &request); - *request_f = MPI_Request_c2f(request); -} - void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); diff --git a/tests/isend_1.f90 b/tests/isend_1.f90 new file mode 100644 index 0000000..7b7e46b --- /dev/null +++ b/tests/isend_1.f90 @@ -0,0 +1,43 @@ +program isend_1 + use mpi + implicit none + + integer, parameter :: NROWS = 100 + integer, parameter :: NCOLS = 50 + real(8) :: send_buf(NROWS, NCOLS) + integer :: rank, size, ierr + integer :: dest, tag + integer :: request + + call MPI_INIT(ierr) + + call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr) + call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr) + + if (size < 2) then + if (rank == 0) then + print *, 'This program requires at least 2 processes' + end if + call MPI_FINALIZE(ierr) + stop + end if + + send_buf = real(rank, 8) + 0.1d0 + + dest = mod(rank + 1, size) + tag = 0 + + if (rank == 0) then + print *, 'Starting non-blocking send example without wait' + end if + + call MPI_ISEND(send_buf, NROWS*NCOLS, MPI_REAL8, dest, & + tag, MPI_COMM_WORLD, request, ierr) + + print *, 'Rank ', rank, ' continuing work while sending to ', dest + + print *, 'Rank ', rank, ' finished work, send may still be in progress' + + call MPI_FINALIZE(ierr) + +end program isend_1