diff --git a/src/mpi.f90 b/src/mpi.f90 index 7fdcf54..81a5bcc 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -605,7 +605,7 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror end if end if - end subroutine MPI_Comm_split_type_proc + end subroutine MPI_Comm_split_type_proc subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc @@ -636,14 +636,14 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st end if if (present(ierror)) then - ierror = local_ierr + ierror = local_ierr else if (local_ierr /= MPI_SUCCESS) then - print *, "MPI_Recv failed with error code: ", local_ierr + print *, "MPI_Recv failed with error code: ", local_ierr end if - end subroutine MPI_Recv_StatusArray_proc + end subroutine MPI_Recv_StatusArray_proc - subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror) + subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f real(8), dimension(*), intent(inout), target :: buf @@ -672,21 +672,54 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s end if if (present(ierror)) then - ierror = local_ierr + ierror = local_ierr else if (local_ierr /= MPI_SUCCESS) then - print *, "MPI_Recv failed with error code: ", local_ierr + print *, "MPI_Recv failed with error code: ", local_ierr end if - - end subroutine MPI_Recv_StatusIgnore_proc + + end subroutine MPI_Recv_StatusIgnore_proc subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror) - use mpi_c_bindings, only: c_mpi_waitall + use iso_c_binding, only: c_int, c_ptr + use mpi_c_bindings, only: c_mpi_waitall, c_mpi_request_f2c, c_mpi_request_c2f, c_mpi_status_c2f, c_mpi_statuses_ignore integer, intent(in) :: count - integer, intent(inout) :: array_of_requests(count) - integer, intent(out) :: array_of_statuses(*) + integer, dimension(count), intent(inout) :: array_of_requests + integer, dimension(*), intent(out) :: array_of_statuses integer, optional, intent(out) :: ierror - call c_mpi_waitall(count, array_of_requests, array_of_statuses, ierror) - end subroutine + integer :: arr_request_item_kind_4 + integer(kind=MPI_HANDLE_KIND) :: arr_request_item_kind_mpi_handle_kind + + integer(c_int) :: local_ierr, status_ierr + integer :: i + + ! Allocate temporary arrays for the C representations. + integer(kind=MPI_HANDLE_KIND), dimension(count) :: c_requests + type(c_ptr) :: MPI_STATUSES_IGNORE_from_c + + MPI_STATUSES_IGNORE_from_c = c_mpi_statuses_ignore() + + ! Convert Fortran requests to C requests. + do i = 1, count + arr_request_item_kind_4 = array_of_requests(i) + c_requests(i) = c_mpi_request_f2c(arr_request_item_kind_4) + end do + + ! Call the native MPI_Waitall. + local_ierr = c_mpi_waitall(count, c_requests, MPI_STATUSES_IGNORE_from_c) + + ! Convert the C requests back to Fortran handles. + do i = 1, count + arr_request_item_kind_mpi_handle_kind = c_requests(i) + array_of_requests(i) = c_mpi_request_c2f(arr_request_item_kind_mpi_handle_kind) + end do + + if (present(ierror)) then + ierror = local_ierr + else if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Waitall failed with error code: ", local_ierr + end if + + end subroutine MPI_Waitall_proc subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) use iso_c_binding, only: c_int, c_ptr diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index c81e85f..7738aab 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -26,6 +26,12 @@ function c_mpi_request_c2f(request) bind(C, name="MPI_Request_c2f") integer(c_int) :: c_mpi_request_c2f end function + function c_mpi_request_f2c(request) bind(C, name="MPI_Request_f2c") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: request + integer(kind=MPI_HANDLE_KIND) :: c_mpi_request_f2c + end function c_mpi_request_f2c + function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: datatype @@ -44,13 +50,18 @@ function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") integer(c_int) :: f_status(*) ! assumed-size array integer(c_int) :: c_mpi_status_c2f end function c_mpi_status_c2f - + function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: info_f integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_f2c end function c_mpi_info_f2c + function c_mpi_statuses_ignore() bind(C, name="get_c_MPI_STATUSES_IGNORE") + use iso_c_binding, only: c_ptr + type(c_ptr) :: c_mpi_statuses_ignore + end function c_mpi_statuses_ignore + function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran") use iso_c_binding, only: c_double, c_ptr real(c_double), value :: in_place_f @@ -189,13 +200,13 @@ function c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, status) bind(C, na integer(c_int) :: c_mpi_recv end function c_mpi_recv - subroutine c_mpi_waitall(count, array_of_requests, array_of_statuses, ierror) bind(C, name="mpi_waitall_wrapper") - use iso_c_binding, only: c_int - integer(c_int), intent(in) :: count - integer(c_int), intent(inout) :: array_of_requests(count) - integer(c_int) :: array_of_statuses(*) - integer(c_int), optional, intent(out) :: ierror - end subroutine + function c_mpi_waitall(count, requests, statuses) bind(C, name="MPI_Waitall") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: count + integer(kind=MPI_HANDLE_KIND), dimension(*), intent(inout) :: requests + type(c_ptr), value :: statuses + integer(c_int) :: c_mpi_waitall + end function c_mpi_waitall function c_mpi_ssend(buf, count, datatype, dest, tag, comm) bind(C, name="MPI_Ssend") use iso_c_binding, only: c_int, c_double, c_ptr diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index e8fbe87..7e29380 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -61,29 +61,6 @@ void* get_c_mpi_inplace_from_fortran(double sendbuf) { return MPI_IN_PLACE; } -void mpi_waitall_wrapper(int *count, int *array_of_requests_f, - int *array_of_statuses_f, int *ierror) { - MPI_Request *array_of_requests; - MPI_Status *array_of_statuses; - array_of_requests = (MPI_Request *)malloc((*count) * sizeof(MPI_Request)); - array_of_statuses = (MPI_Status *)malloc((*count) * sizeof(MPI_Status)); - if (array_of_requests == NULL || array_of_statuses == NULL) { - *ierror = MPI_ERR_NO_MEM; - return; - } - for (int i = 0; i < *count; i++) { - array_of_requests[i] = MPI_Request_f2c(array_of_requests_f[i]); - } - - *ierror = MPI_Waitall(*count, array_of_requests, array_of_statuses); - for (int i = 0; i < *count; i++) { - array_of_requests_f[i] = MPI_Request_c2f(array_of_requests[i]); - } - - for (int i = 0; i < *count; i++) { - MPI_Status_c2f(&array_of_statuses[i], &array_of_statuses_f[i * MPI_STATUS_SIZE]); - } - - free(array_of_requests); - free(array_of_statuses); -} +MPI_Status* get_c_MPI_STATUSES_IGNORE(){ + return MPI_STATUSES_IGNORE; +} \ No newline at end of file