diff --git a/src/mpi.f90 b/src/mpi.f90 index 95a2c85..a341f1d 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -15,6 +15,7 @@ module mpi real(8), parameter :: MPI_IN_PLACE = -1002 integer, parameter :: MPI_SUM = -2300 integer, parameter :: MPI_INFO_NULL = -2000 + integer, parameter :: MPI_STATUS_SIZE = 5 integer :: MPI_STATUS_IGNORE = 0 ! NOTE: I've no idea for how to implement this, refer ! see section 2.5.4 page 21 of mpi40-report.pdf @@ -81,7 +82,8 @@ module mpi end interface interface MPI_Recv - module procedure MPI_Recv_proc + module procedure MPI_Recv_StatusArray_proc + module procedure MPI_Recv_StatusIgnore_proc end interface interface MPI_Waitall @@ -516,16 +518,75 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror end subroutine MPI_Comm_split_type_proc - subroutine MPI_Recv_proc(buf, count, datatype, source, tag, comm, status, ierror) - use mpi_c_bindings, only: c_mpi_recv - real(8), dimension(:) :: buf - integer, intent(in) :: count, source, tag - integer, intent(in) :: datatype - integer, intent(in) :: comm + subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f + real(8), dimension(*), intent(inout) :: buf + integer, intent(in) :: count, source, tag, datatype, comm + integer, intent(out) :: status(MPI_STATUS_SIZE) + integer, optional, intent(out) :: ierror + + integer(c_int) :: local_ierr, status_ierr + type(c_ptr) :: c_dtype, c_comm, c_status + integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status + + ! Convert Fortran handles to C handles. + c_dtype = c_mpi_datatype_f2c(datatype) + c_comm = c_mpi_comm_f2c(comm) + + ! Use a local temporary MPI_Status (as an array of c_int) + c_status = c_loc(tmp_status) + + ! Call the native MPI_Recv. + local_ierr = c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, c_status) + + ! Convert the C MPI_Status to Fortran status. + if (local_ierr == MPI_SUCCESS) then + status_ierr = c_mpi_status_c2f(c_status, status) + end if + + if (present(ierror)) then + ierror = local_ierr + else if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Recv failed with error code: ", local_ierr + end if + + end subroutine MPI_Recv_StatusArray_proc + + subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror) + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f + real(8), dimension(*), intent(inout) :: buf + integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status integer, optional, intent(out) :: ierror - call c_mpi_recv(buf, count, datatype, source, tag, comm, status, ierror) - end subroutine + + integer(c_int) :: local_ierr, status_ierr + type(c_ptr) :: c_dtype, c_comm, c_status + integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status + + ! Convert Fortran handles to C handles. + c_dtype = c_mpi_datatype_f2c(datatype) + c_comm = c_mpi_comm_f2c(comm) + + ! Use a local temporary MPI_Status (as an array of c_int) + c_status = c_loc(tmp_status) + + ! Call the native MPI_Recv. + local_ierr = c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, c_status) + + ! Convert the C MPI_Status to Fortran status. + if (local_ierr == MPI_SUCCESS) then + ! status_ierr = c_mpi_status_c2f(c_status, status) + end if + + if (present(ierror)) then + ierror = local_ierr + else if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Recv failed with error code: ", local_ierr + end if + + end subroutine MPI_Recv_StatusIgnore_proc subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror) use mpi_c_bindings, only: c_mpi_waitall diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index d4b9617..21151ab 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -32,6 +32,12 @@ function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran") type(c_ptr) :: c_mpi_op_f2c end function c_mpi_op_f2c + function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") + use iso_c_binding, only: c_ptr, c_int + type(c_ptr) :: c_status + integer(c_int) :: f_status(*) ! assumed-size array + integer(c_int) :: c_mpi_status_c2f + end function c_mpi_status_c2f function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: info_f @@ -184,15 +190,17 @@ function c_mpi_comm_split_type(c_comm, split_type, key, c_info, new_comm) bind(C integer(c_int) :: c_mpi_comm_split_type end function c_mpi_comm_split_type - subroutine c_mpi_recv(buf, count, datatype, source, tag, comm, status, ierror) bind(C, name="mpi_recv_wrapper") - use iso_c_binding, only: c_int, c_double - real(c_double), dimension(*) :: buf - integer(c_int), intent(in) :: count, source, tag - integer(c_int), intent(in) :: datatype - integer(c_int), intent(in) :: comm - integer(c_int), intent(out) :: status - integer(c_int), optional, intent(out) :: ierror - end subroutine + function c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, status) bind(C, name="MPI_Recv") + use iso_c_binding, only: c_ptr, c_int, c_double + real(c_double), dimension(*), intent(out) :: buf + integer(c_int), value :: count + type(c_ptr), value :: c_dtype + integer(c_int), value :: source + integer(c_int), value :: tag + type(c_ptr), value :: c_comm + type(c_ptr) :: status + integer(c_int) :: c_mpi_recv + end function c_mpi_recv subroutine c_mpi_waitall(count, array_of_requests, array_of_statuses, ierror) bind(C, name="mpi_waitall_wrapper") use iso_c_binding, only: c_int diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 25be0b4..0c9ca03 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -94,18 +94,6 @@ void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count, } } -void mpi_recv_wrapper(double *buf, int *count, int *datatype_f, int *source, - int *tag, int *comm_f, int *status_f, int *ierror) { - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Status status; - *ierror = MPI_Recv(buf, *count, datatype, *source, *tag, comm, &status); - if (*ierror == MPI_SUCCESS) { - MPI_Status_c2f(&status, status_f); - } -} - void mpi_waitall_wrapper(int *count, int *array_of_requests_f, int *array_of_statuses_f, int *ierror) { MPI_Request *array_of_requests; diff --git a/tests/recv_1.f90 b/tests/recv_1.f90 new file mode 100644 index 0000000..e745ac5 --- /dev/null +++ b/tests/recv_1.f90 @@ -0,0 +1,34 @@ +program test_recv_1 + use mpi + implicit none + + integer :: ierr, rank, size, comm, tag + real(8), dimension(5) :: buf + integer, dimension(MPI_STATUS_SIZE) :: status + + comm = MPI_COMM_WORLD + tag = 100 + + call MPI_Init(ierr) + call MPI_Comm_rank(comm, rank, ierr) + call MPI_Comm_size(comm, size, ierr) + + if (size < 2) then + print *, "This test works best with at least 2 MPI processes." + else + if (rank == 0) then + ! Rank 0: Prepare data and send synchronously. + buf = (/ 1.0d0, 2.0d0, 3.0d0, 4.0d0, 5.0d0 /) + call MPI_Ssend(buf, 5, MPI_REAL8, 1, tag, comm, ierr) + print *, "Rank 0 sent data." + else if (rank == 1) then + ! Rank 1: Receive the data. + buf = 0.0d0 ! initialize to zeros + call MPI_Recv(buf, 5, MPI_REAL8, 0, tag, comm, status, ierr) + print *, "Rank 1 received data: ", buf + end if + end if + + call MPI_Finalize(ierr) + +end program test_recv_1 \ No newline at end of file diff --git a/tests/ssend_1.f90 b/tests/ssend_1.f90 index 6a2afd0..fbedd5e 100644 --- a/tests/ssend_1.f90 +++ b/tests/ssend_1.f90 @@ -6,6 +6,7 @@ program ssend_example integer :: rank, size, ierr, i real(8) :: buffer(10) integer :: tag = 100 + integer, dimension(MPI_STATUS_SIZE) :: status ! allocate(buffer(10)) ! Initialize MPI environment @@ -47,7 +48,7 @@ program ssend_example buffer = 0 ! Receive the message - call MPI_Recv(buffer, 10, MPI_REAL8, 0, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE, ierr) + call MPI_Recv(buffer, 10, MPI_REAL8, 0, tag, MPI_COMM_WORLD, status, ierr) print *, "Process 1: Received data:" ! write(*, '(10I5)') buffer