From d171f13b55ef0b46b28d4b00e053ebd62b6c7bc1 Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Thu, 10 Apr 2025 22:32:03 +0530 Subject: [PATCH 1/4] Remove get_c_comm_from_fortran and keep just get_c_mpi_comm_world --- src/mpi_wrapper.c | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 7e29380..53985cc 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -49,12 +49,8 @@ MPI_Op get_c_op_from_fortran(int op) { } } -MPI_Comm get_c_comm_from_fortran(int comm_f) { - if (comm_f == FORTRAN_MPI_COMM_WORLD) { - return MPI_COMM_WORLD; - } else { - return MPI_Comm_f2c(comm_f); - } +MPI_Comm get_c_mpi_comm_world() { + return MPI_COMM_WORLD; } void* get_c_mpi_inplace_from_fortran(double sendbuf) { From 8b3c748f7322f87f35d41775de3d60b532bf0c26 Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Thu, 10 Apr 2025 22:32:14 +0530 Subject: [PATCH 2/4] Use MPI_Comm_f2c from fortran bind(C) interface --- src/mpi.f90 | 206 ++++++++++++++++++++++++++++++++--------- src/mpi_c_bindings.f90 | 22 +++-- 2 files changed, 174 insertions(+), 54 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 81a5bcc..90201ec 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -190,7 +190,7 @@ subroutine MPI_Finalize_proc(ierr) end subroutine subroutine MPI_Comm_size_proc(comm, size, ierror) - use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c + use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c, c_mpi_comm_world use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out) :: size @@ -198,7 +198,12 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) integer :: local_ierr integer(kind=MPI_HANDLE_KIND) :: c_comm - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + local_ierr = c_mpi_comm_size(c_comm, size) if (present(ierror)) then ierror = local_ierr @@ -210,7 +215,7 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) end subroutine subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world use iso_c_binding, only: c_int, c_ptr, c_loc integer, target :: buffer integer, intent(in) :: count, root @@ -221,7 +226,12 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) integer :: local_ierr type(c_ptr) :: buffer_ptr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + c_datatype = c_mpi_datatype_f2c(datatype) buffer_ptr = c_loc(buffer) local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) @@ -236,7 +246,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) end subroutine MPI_Bcast_int_scalar subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world use iso_c_binding, only: c_int, c_ptr, c_loc real(8), dimension(:, :), target :: buffer integer, intent(in) :: count, root @@ -247,7 +257,12 @@ subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) integer :: local_ierr type(c_ptr) :: buffer_ptr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + c_datatype = c_mpi_datatype_f2c(datatype) buffer_ptr = c_loc(buffer) local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) @@ -263,7 +278,7 @@ end subroutine MPI_Bcast_real_2D subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allgather_int, c_mpi_comm_f2c, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_allgather_int, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:, :), intent(out), target :: recvbuf integer, intent(in) :: sendcount, recvcount @@ -275,7 +290,12 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + c_sendtype = c_mpi_datatype_f2c(sendtype) c_recvtype = c_mpi_datatype_f2c(recvtype) sendbuf_ptr = c_loc(sendbuf) @@ -293,7 +313,7 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allgather_real, c_mpi_comm_f2c, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_allgather_real, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:, :), intent(out), target :: recvbuf integer, intent(in) :: sendcount, recvcount @@ -305,7 +325,12 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + c_sendtype = c_mpi_datatype_f2c(sendtype) c_recvtype = c_mpi_datatype_f2c(recvtype) sendbuf_ptr = c_loc(sendbuf) @@ -323,7 +348,7 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f + use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f, c_mpi_comm_world real(8), dimension(:, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -336,7 +361,13 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) buf_ptr = c_loc(buf) c_datatype = c_mpi_datatype_f2c(datatype) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + local_ierr = c_mpi_isend(buf_ptr, count, c_datatype, dest, tag, c_comm, c_request) if (present(ierror)) then @@ -352,7 +383,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f + use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f, c_mpi_comm_world real(8), dimension(:, :, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -365,7 +396,13 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) buf_ptr = c_loc(buf) c_datatype = c_mpi_datatype_f2c(datatype) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + local_ierr = c_mpi_isend(buf_ptr, count, c_datatype, dest, tag, c_comm, c_request) if (present(ierror)) then @@ -381,7 +418,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_irecv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_request_c2f + use mpi_c_bindings, only: c_mpi_irecv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_request_c2f, c_mpi_comm_world real(8), dimension(:,:) :: buf integer, intent(in) :: count, source, tag integer, intent(in) :: datatype @@ -393,7 +430,12 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr integer(kind=MPI_HANDLE_KIND) :: c_datatype integer(kind=MPI_HANDLE_KIND) :: c_request - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + c_datatype = c_mpi_datatype_f2c(datatype) local_ierr = c_mpi_irecv(buf, count, c_datatype, source, tag, c_comm, c_request) request = c_mpi_request_c2f(c_request) @@ -409,7 +451,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -426,7 +468,12 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) @@ -441,7 +488,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world real(8), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -458,7 +505,12 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) @@ -473,7 +525,7 @@ end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c,c_mpi_comm_world real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -486,7 +538,12 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) @@ -501,7 +558,7 @@ end subroutine MPI_Allreduce_1D_real_proc subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -514,7 +571,12 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) @@ -534,7 +596,7 @@ function MPI_Wtime_proc() result(time) end function subroutine MPI_Barrier_proc(comm, ierror) - use mpi_c_bindings, only: c_mpi_barrier, c_mpi_comm_f2c + use mpi_c_bindings, only: c_mpi_barrier, c_mpi_comm_f2c, c_mpi_comm_world use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out), optional :: ierror @@ -542,7 +604,12 @@ subroutine MPI_Barrier_proc(comm, ierror) integer(c_int) :: local_ierr ! Convert Fortran handle to C handle - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + local_ierr = c_mpi_barrier(c_comm) if (present(ierror)) then @@ -556,14 +623,19 @@ end subroutine MPI_Barrier_proc subroutine MPI_Comm_rank_proc(comm, rank, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_comm_rank, c_mpi_comm_f2c + use mpi_c_bindings, only: c_mpi_comm_rank, c_mpi_comm_f2c, c_mpi_comm_world integer, intent(in) :: comm integer, intent(out) :: rank integer, optional, intent(out) :: ierror integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + local_ierr = c_mpi_comm_rank(c_comm, rank) if (present(ierror)) then @@ -577,7 +649,7 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror) subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_info_f2c + use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_info_f2c, c_mpi_comm_world integer, intent(in) :: comm integer, intent(in) :: split_type, key integer, intent(in) :: info @@ -588,7 +660,12 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror integer(kind=MPI_HANDLE_KIND) :: c_comm, c_info, c_new_comm ! Convert Fortran communicator and info handles to C pointers. - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + c_info = c_mpi_info_f2c(info) ! Call the native MPI_Comm_split_type. @@ -609,7 +686,7 @@ end subroutine MPI_Comm_split_type_proc subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f + use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f, c_mpi_comm_world real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status(MPI_STATUS_SIZE) @@ -622,7 +699,12 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st ! Convert Fortran handles to C handles. c_dtype = c_mpi_datatype_f2c(datatype) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if ! Use a local temporary MPI_Status (as an array of c_int) c_status = c_loc(tmp_status) @@ -645,7 +727,7 @@ end subroutine MPI_Recv_StatusArray_proc subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f + use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f, c_mpi_comm_world real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status @@ -658,7 +740,12 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s ! Convert Fortran handles to C handles. c_dtype = c_mpi_datatype_f2c(datatype) - c_comm = c_mpi_comm_f2c(comm) + + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if ! Use a local temporary MPI_Status (as an array of c_int) c_status = c_loc(tmp_status) @@ -723,7 +810,7 @@ end subroutine MPI_Waitall_proc subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_ssend, c_mpi_datatype_f2c, c_mpi_comm_f2c + use mpi_c_bindings, only: c_mpi_ssend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_comm_world real(8), dimension(*), intent(in) :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -733,13 +820,17 @@ subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) integer :: local_ierr c_datatype = c_mpi_datatype_f2c(datatype) - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if local_ierr = c_mpi_ssend(buf, count, c_datatype, dest, tag, c_comm) end subroutine subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_cart, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_cart_create, c_mpi_comm_f2c, c_mpi_comm_c2f + use mpi_c_bindings, only: c_mpi_cart_create, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_comm_world integer, intent(in) :: ndims, dims(ndims) logical, intent(in) :: periods(ndims), reorder integer, intent(in) :: comm_old @@ -750,7 +841,11 @@ subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_ca integer(kind=MPI_HANDLE_KIND) :: c_comm_cart integer(c_int) :: local_ierr - c_comm_old = c_mpi_comm_f2c(comm_old) + if (comm_old == MPI_COMM_WORLD) then + c_comm_old= c_mpi_comm_world() + else + c_comm_old = c_mpi_comm_f2c(comm_old) + end if ndims_c = ndims if (reorder) then reorder_c = 1 @@ -777,7 +872,7 @@ subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_ca subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_cart_coords, c_mpi_comm_f2c + use mpi_c_bindings, only: c_mpi_cart_coords, c_mpi_comm_f2c, c_mpi_comm_world integer, intent(in) :: comm integer, intent(in) :: rank, maxdims integer, intent(out) :: coords(maxdims) @@ -785,7 +880,12 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + local_ierr = c_mpi_cart_coords(c_comm, rank, maxdims, coords) if (present(ierror)) then @@ -799,7 +899,7 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_cart_shift, c_mpi_comm_f2c + use mpi_c_bindings, only: c_mpi_cart_shift, c_mpi_comm_f2c, c_mpi_comm_world integer, intent(in) :: comm integer, intent(in) :: direction, disp integer, intent(out) :: rank_source, rank_dest @@ -807,7 +907,12 @@ subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ie integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + local_ierr = c_mpi_cart_shift(c_comm, direction, disp, rank_source, rank_dest) if (present(ierror)) then @@ -839,7 +944,7 @@ subroutine MPI_Dims_create_proc(nnodes, ndims, dims, ierror) subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_cart_sub, c_mpi_comm_f2c, c_mpi_comm_c2f + use mpi_c_bindings, only: c_mpi_cart_sub, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_comm_world integer, intent(in) :: comm logical, intent(in) :: remain_dims(:) integer, intent(out) :: newcomm @@ -849,7 +954,11 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) integer :: local_ierr type(c_ptr) :: remain_dims_i_ptr - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if where (remain_dims) remain_dims_i = 1 @@ -871,7 +980,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) end subroutine subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_reduce, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_op_f2c + use mpi_c_bindings, only: c_mpi_reduce, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_world use iso_c_binding, only: c_int, c_ptr, c_loc integer, target, intent(in) :: sendbuf integer, target, intent(out) :: recvbuf @@ -883,7 +992,12 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co integer(c_int) :: local_ierr ! Convert Fortran integer handles => C pointers - c_comm = c_mpi_comm_f2c(comm) + if (comm == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm) + end if + c_dtype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 7738aab..bdcfd8b 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -8,7 +8,8 @@ module mpi_c_bindings #endif interface - function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran") + + function c_mpi_comm_f2c(comm_f) bind(C, name="MPI_Comm_f2c") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: comm_f integer(kind=MPI_HANDLE_KIND) :: c_mpi_comm_f2c @@ -32,6 +33,18 @@ function c_mpi_request_f2c(request) bind(C, name="MPI_Request_f2c") integer(kind=MPI_HANDLE_KIND) :: c_mpi_request_f2c end function c_mpi_request_f2c + function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") + use iso_c_binding, only: c_ptr, c_int + type(c_ptr) :: c_status + integer(c_int) :: f_status(*) ! assumed-size array + integer(c_int) :: c_mpi_status_c2f + end function c_mpi_status_c2f + + function c_mpi_comm_world() bind(C, name="get_c_mpi_comm_world") + use iso_c_binding, only: c_ptr + integer(kind=MPI_HANDLE_KIND) :: c_mpi_comm_world + end function c_mpi_comm_world + function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: datatype @@ -43,13 +56,6 @@ function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran") integer(c_int), value :: op_f integer(kind=MPI_HANDLE_KIND) :: c_mpi_op_f2c end function c_mpi_op_f2c - - function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") - use iso_c_binding, only: c_ptr, c_int - type(c_ptr) :: c_status - integer(c_int) :: f_status(*) ! assumed-size array - integer(c_int) :: c_mpi_status_c2f - end function c_mpi_status_c2f function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") use iso_c_binding, only: c_int, c_ptr From 208bfe936e1a8202724669ce11eef5ebc628fa9e Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Thu, 10 Apr 2025 22:38:58 +0530 Subject: [PATCH 3/4] Fix Line Truncation error --- src/mpi.f90 | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 90201ec..244a6c6 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -451,7 +451,8 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & + c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -488,7 +489,8 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & + c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world real(8), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -525,7 +527,8 @@ end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c,c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & + c_mpi_comm_f2c, c_mpi_in_place_f2c,c_mpi_comm_world real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -558,7 +561,8 @@ end subroutine MPI_Allreduce_1D_real_proc subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & + c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -649,7 +653,8 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror) subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_info_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_f2c, c_mpi_comm_c2f, & + c_mpi_info_f2c, c_mpi_comm_world integer, intent(in) :: comm integer, intent(in) :: split_type, key integer, intent(in) :: info From 90e6a989d8f1db299a6cc3065fd0dd243da6283c Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Fri, 11 Apr 2025 15:37:18 +0530 Subject: [PATCH 4/4] use a function to get integer handle for MPI_Comm_f2c --- src/mpi.f90 | 240 +++++++++++++++++----------------------------------- 1 file changed, 77 insertions(+), 163 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 244a6c6..86c54ef 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -125,7 +125,17 @@ module mpi module procedure MPI_Reduce_scalar_int end interface - contains + contains + + integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm) + use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c, c_mpi_comm_world + integer, intent(in) :: comm_f + if (comm_f == MPI_COMM_WORLD) then + c_comm = c_mpi_comm_world() + else + c_comm = c_mpi_comm_f2c(comm_f) + end if + end function handle_mpi_comm_f2c subroutine MPI_Init_proc(ierr) use mpi_c_bindings, only: c_mpi_init @@ -190,7 +200,7 @@ subroutine MPI_Finalize_proc(ierr) end subroutine subroutine MPI_Comm_size_proc(comm, size, ierror) - use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_comm_size use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out) :: size @@ -198,11 +208,7 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) integer :: local_ierr integer(kind=MPI_HANDLE_KIND) :: c_comm - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_comm_size(c_comm, size) if (present(ierror)) then @@ -215,7 +221,7 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) end subroutine subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_datatype_f2c use iso_c_binding, only: c_int, c_ptr, c_loc integer, target :: buffer integer, intent(in) :: count, root @@ -226,11 +232,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) integer :: local_ierr type(c_ptr) :: buffer_ptr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) c_datatype = c_mpi_datatype_f2c(datatype) buffer_ptr = c_loc(buffer) @@ -257,11 +259,7 @@ subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) integer :: local_ierr type(c_ptr) :: buffer_ptr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) c_datatype = c_mpi_datatype_f2c(datatype) buffer_ptr = c_loc(buffer) @@ -278,7 +276,7 @@ end subroutine MPI_Bcast_real_2D subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allgather_int, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allgather_int, c_mpi_datatype_f2c integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:, :), intent(out), target :: recvbuf integer, intent(in) :: sendcount, recvcount @@ -290,11 +288,7 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) c_sendtype = c_mpi_datatype_f2c(sendtype) c_recvtype = c_mpi_datatype_f2c(recvtype) @@ -313,7 +307,7 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allgather_real, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allgather_real, c_mpi_datatype_f2c real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:, :), intent(out), target :: recvbuf integer, intent(in) :: sendcount, recvcount @@ -325,11 +319,7 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) c_sendtype = c_mpi_datatype_f2c(sendtype) c_recvtype = c_mpi_datatype_f2c(recvtype) @@ -348,7 +338,7 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_request_c2f real(8), dimension(:, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -361,12 +351,8 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) buf_ptr = c_loc(buf) c_datatype = c_mpi_datatype_f2c(datatype) - - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_isend(buf_ptr, count, c_datatype, dest, tag, c_comm, c_request) @@ -383,7 +369,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_request_c2f, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_request_c2f real(8), dimension(:, :, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -396,12 +382,8 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) buf_ptr = c_loc(buf) c_datatype = c_mpi_datatype_f2c(datatype) - - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_isend(buf_ptr, count, c_datatype, dest, tag, c_comm, c_request) @@ -418,7 +400,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_irecv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_request_c2f, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_irecv, c_mpi_datatype_f2c, c_mpi_request_c2f real(8), dimension(:,:) :: buf integer, intent(in) :: count, source, tag integer, intent(in) :: datatype @@ -430,11 +412,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr integer(kind=MPI_HANDLE_KIND) :: c_datatype integer(kind=MPI_HANDLE_KIND) :: c_request - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) c_datatype = c_mpi_datatype_f2c(datatype) local_ierr = c_mpi_irecv(buf, count, c_datatype, source, tag, c_comm, c_request) @@ -451,8 +429,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & - c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -469,12 +446,8 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) @@ -489,8 +462,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & - c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -504,15 +476,12 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com else sendbuf_ptr = c_loc(sendbuf) end if + recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) @@ -527,8 +496,7 @@ end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & - c_mpi_comm_f2c, c_mpi_in_place_f2c,c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -541,13 +509,9 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) + local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) if (present(ierror)) then @@ -576,11 +540,7 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm c_datatype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op) - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) @@ -600,20 +560,15 @@ function MPI_Wtime_proc() result(time) end function subroutine MPI_Barrier_proc(comm, ierror) - use mpi_c_bindings, only: c_mpi_barrier, c_mpi_comm_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_barrier use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out), optional :: ierror integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - ! Convert Fortran handle to C handle - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if - + c_comm = handle_mpi_comm_f2c(comm) + local_ierr = c_mpi_barrier(c_comm) if (present(ierror)) then @@ -627,19 +582,14 @@ end subroutine MPI_Barrier_proc subroutine MPI_Comm_rank_proc(comm, rank, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_comm_rank, c_mpi_comm_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_comm_rank integer, intent(in) :: comm integer, intent(out) :: rank integer, optional, intent(out) :: ierror integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if - + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_comm_rank(c_comm, rank) if (present(ierror)) then @@ -653,8 +603,7 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror) subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_f2c, c_mpi_comm_c2f, & - c_mpi_info_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_c2f, c_mpi_info_f2c integer, intent(in) :: comm integer, intent(in) :: split_type, key integer, intent(in) :: info @@ -663,16 +612,9 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror integer(c_int) :: local_ierr integer(kind=MPI_HANDLE_KIND) :: c_comm, c_info, c_new_comm - - ! Convert Fortran communicator and info handles to C pointers. - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) c_info = c_mpi_info_f2c(info) - ! Call the native MPI_Comm_split_type. local_ierr = c_mpi_comm_split_type(c_comm, split_type, key, c_info, c_new_comm) @@ -691,7 +633,7 @@ end subroutine MPI_Comm_split_type_proc subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_recv, c_mpi_datatype_f2c, c_mpi_status_c2f real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status(MPI_STATUS_SIZE) @@ -701,27 +643,23 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st integer(kind=MPI_HANDLE_KIND) :: c_dtype, c_comm type(c_ptr) :: c_status integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status - + ! Convert Fortran handles to C handles. c_dtype = c_mpi_datatype_f2c(datatype) - - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if - + + c_comm = handle_mpi_comm_f2c(comm) + ! Use a local temporary MPI_Status (as an array of c_int) c_status = c_loc(tmp_status) - + ! Call the native MPI_Recv. local_ierr = c_mpi_recv(c_loc(buf), count, c_dtype, source, tag, c_comm, c_status) - + ! Convert the C MPI_Status to Fortran status. if (local_ierr == MPI_SUCCESS) then status_ierr = c_mpi_status_c2f(c_status, status) end if - + if (present(ierror)) then ierror = local_ierr else if (local_ierr /= MPI_SUCCESS) then @@ -732,7 +670,7 @@ end subroutine MPI_Recv_StatusArray_proc subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_recv, c_mpi_datatype_f2c, c_mpi_status_c2f real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status @@ -746,11 +684,7 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s ! Convert Fortran handles to C handles. c_dtype = c_mpi_datatype_f2c(datatype) - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) ! Use a local temporary MPI_Status (as an array of c_int) c_status = c_loc(tmp_status) @@ -815,7 +749,7 @@ end subroutine MPI_Waitall_proc subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_ssend, c_mpi_datatype_f2c, c_mpi_comm_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_ssend, c_mpi_datatype_f2c real(8), dimension(*), intent(in) :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -825,17 +759,13 @@ subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) integer :: local_ierr c_datatype = c_mpi_datatype_f2c(datatype) - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_ssend(buf, count, c_datatype, dest, tag, c_comm) end subroutine subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_cart, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_cart_create, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_cart_create, c_mpi_comm_c2f integer, intent(in) :: ndims, dims(ndims) logical, intent(in) :: periods(ndims), reorder integer, intent(in) :: comm_old @@ -846,23 +776,20 @@ subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_ca integer(kind=MPI_HANDLE_KIND) :: c_comm_cart integer(c_int) :: local_ierr - if (comm_old == MPI_COMM_WORLD) then - c_comm_old= c_mpi_comm_world() + c_comm_old = handle_mpi_comm_f2c(comm_old) + + ndims_c = ndims + if (reorder) then + reorder_c = 1 else - c_comm_old = c_mpi_comm_f2c(comm_old) + reorder_c = 0 end if - ndims_c = ndims - if (reorder) then - reorder_c = 1 - else - reorder_c = 0 - end if - dims_c = dims - where (periods) - periods_c = 1 - elsewhere - periods_c = 0 - end where + dims_c = dims + where (periods) + periods_c = 1 + elsewhere + periods_c = 0 + end where local_ierr = c_mpi_cart_create(c_comm_old, ndims, dims_c, periods_c, reorder_c, c_comm_cart) comm_cart = c_mpi_comm_c2f(c_comm_cart) @@ -877,7 +804,7 @@ subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_ca subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_cart_coords, c_mpi_comm_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_cart_coords integer, intent(in) :: comm integer, intent(in) :: rank, maxdims integer, intent(out) :: coords(maxdims) @@ -885,11 +812,7 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_cart_coords(c_comm, rank, maxdims, coords) @@ -949,7 +872,7 @@ subroutine MPI_Dims_create_proc(nnodes, ndims, dims, ierror) subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_cart_sub, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_cart_sub, c_mpi_comm_c2f integer, intent(in) :: comm logical, intent(in) :: remain_dims(:) integer, intent(out) :: newcomm @@ -959,11 +882,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) integer :: local_ierr type(c_ptr) :: remain_dims_i_ptr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) where (remain_dims) remain_dims_i = 1 @@ -985,7 +904,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) end subroutine subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_reduce, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_reduce, c_mpi_datatype_f2c, c_mpi_op_f2c use iso_c_binding, only: c_int, c_ptr, c_loc integer, target, intent(in) :: sendbuf integer, target, intent(out) :: recvbuf @@ -996,12 +915,7 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co type(c_ptr) :: c_sendbuf, c_recvbuf integer(c_int) :: local_ierr - ! Convert Fortran integer handles => C pointers - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) c_dtype = c_mpi_datatype_f2c(datatype) c_op = c_mpi_op_f2c(op)