diff --git a/src/mpi.f90 b/src/mpi.f90 index 58edb1c..5996808 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -157,6 +157,18 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info end if end function handle_mpi_info_f2c + integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) result(c_datatype) + use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int + integer, intent(in) :: datatype_f + if (datatype_f == MPI_REAL4) then + c_datatype = c_mpi_float() + else if (datatype_f == MPI_REAL8 .OR. datatype_f == MPI_DOUBLE_PRECISION) then + c_datatype = c_mpi_double() + else if (datatype_f == MPI_INTEGER) then + c_datatype = c_mpi_int() + end if + end function + subroutine MPI_Init_proc(ierr) use mpi_c_bindings, only: c_mpi_init use iso_c_binding, only : c_int, c_ptr, c_null_ptr @@ -241,7 +253,7 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) end subroutine subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_bcast use iso_c_binding, only: c_int, c_ptr, c_loc integer, target :: buffer integer, intent(in) :: count, root @@ -254,7 +266,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) c_comm = handle_mpi_comm_f2c(comm) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) buffer_ptr = c_loc(buffer) local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) @@ -268,7 +280,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) end subroutine MPI_Bcast_int_scalar subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_comm_world use iso_c_binding, only: c_int, c_ptr, c_loc real(8), dimension(:, :), target :: buffer integer, intent(in) :: count, root @@ -281,7 +293,7 @@ subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) c_comm = handle_mpi_comm_f2c(comm) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) buffer_ptr = c_loc(buffer) local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) @@ -296,7 +308,7 @@ end subroutine MPI_Bcast_real_2D subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allgather_int, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_allgather_int integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:, :), intent(out), target :: recvbuf integer, intent(in) :: sendcount, recvcount @@ -310,8 +322,8 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r c_comm = handle_mpi_comm_f2c(comm) - c_sendtype = c_mpi_datatype_f2c(sendtype) - c_recvtype = c_mpi_datatype_f2c(recvtype) + c_sendtype = handle_mpi_datatype_f2c(sendtype) + c_recvtype = handle_mpi_datatype_f2c(recvtype) sendbuf_ptr = c_loc(sendbuf) recvbuf_ptr = c_loc(recvbuf) local_ierr = c_mpi_allgather_int(sendbuf_ptr, sendcount, c_sendtype, recvbuf_ptr, recvcount, c_recvtype, c_comm) @@ -327,7 +339,7 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allgather_real, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_allgather_real real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:, :), intent(out), target :: recvbuf integer, intent(in) :: sendcount, recvcount @@ -341,8 +353,8 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, c_comm = handle_mpi_comm_f2c(comm) - c_sendtype = c_mpi_datatype_f2c(sendtype) - c_recvtype = c_mpi_datatype_f2c(recvtype) + c_sendtype = handle_mpi_datatype_f2c(sendtype) + c_recvtype = handle_mpi_datatype_f2c(recvtype) sendbuf_ptr = c_loc(sendbuf) recvbuf_ptr = c_loc(recvbuf) local_ierr = c_mpi_allgather_real(sendbuf_ptr, sendcount, c_sendtype, recvbuf_ptr, recvcount, c_recvtype, c_comm) @@ -358,7 +370,7 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_request_c2f + use mpi_c_bindings, only: c_mpi_isend, c_mpi_request_c2f real(8), dimension(:, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -370,7 +382,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) integer(c_int) :: local_ierr buf_ptr = c_loc(buf) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) c_comm = handle_mpi_comm_f2c(comm) @@ -389,7 +401,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_request_c2f + use mpi_c_bindings, only: c_mpi_isend, c_mpi_request_c2f real(8), dimension(:, :, :), intent(in), target :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -401,7 +413,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) integer(c_int) :: local_ierr buf_ptr = c_loc(buf) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) c_comm = handle_mpi_comm_f2c(comm) @@ -420,7 +432,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_irecv, c_mpi_datatype_f2c, c_mpi_request_c2f + use mpi_c_bindings, only: c_mpi_irecv, c_mpi_request_c2f real(8), dimension(:,:) :: buf integer, intent(in) :: count, source, tag integer, intent(in) :: datatype @@ -434,7 +446,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr c_comm = handle_mpi_comm_f2c(comm) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) local_ierr = c_mpi_irecv(buf, count, c_datatype, source, tag, c_comm, c_request) request = c_mpi_request_c2f(c_request) @@ -449,7 +461,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -464,7 +476,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier sendbuf_ptr = c_loc(sendbuf) end if recvbuf_ptr = c_loc(recvbuf) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -482,7 +494,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -498,7 +510,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com end if recvbuf_ptr = c_loc(recvbuf) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -516,7 +528,7 @@ end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -527,7 +539,7 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com sendbuf_ptr = c_loc(sendbuf) recvbuf_ptr = c_loc(recvbuf) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -545,7 +557,7 @@ end subroutine MPI_Allreduce_1D_real_proc subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, & + use mpi_c_bindings, only: c_mpi_allreduce, & c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:), intent(out), target :: recvbuf @@ -557,7 +569,7 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm sendbuf_ptr = c_loc(sendbuf) recvbuf_ptr = c_loc(recvbuf) - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -653,7 +665,7 @@ end subroutine MPI_Comm_split_type_proc subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_recv, c_mpi_datatype_f2c, c_mpi_status_c2f + use mpi_c_bindings, only: c_mpi_recv, c_mpi_status_c2f real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status(MPI_STATUS_SIZE) @@ -665,7 +677,7 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status ! Convert Fortran handles to C handles. - c_dtype = c_mpi_datatype_f2c(datatype) + c_dtype = handle_mpi_datatype_f2c(datatype) c_comm = handle_mpi_comm_f2c(comm) @@ -690,7 +702,7 @@ end subroutine MPI_Recv_StatusArray_proc subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_recv, c_mpi_datatype_f2c, c_mpi_status_c2f + use mpi_c_bindings, only: c_mpi_recv, c_mpi_status_c2f real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status @@ -702,7 +714,7 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status ! Convert Fortran handles to C handles. - c_dtype = c_mpi_datatype_f2c(datatype) + c_dtype = handle_mpi_datatype_f2c(datatype) c_comm = handle_mpi_comm_f2c(comm) @@ -769,7 +781,7 @@ end subroutine MPI_Waitall_proc subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_ssend, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_ssend real(8), dimension(*), intent(in) :: buf integer, intent(in) :: count, dest, tag integer, intent(in) :: datatype @@ -778,7 +790,7 @@ subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_comm integer :: local_ierr - c_datatype = c_mpi_datatype_f2c(datatype) + c_datatype = handle_mpi_datatype_f2c(datatype) c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_ssend(buf, count, c_datatype, dest, tag, c_comm) end subroutine @@ -924,7 +936,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) end subroutine subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_reduce, c_mpi_datatype_f2c + use mpi_c_bindings, only: c_mpi_reduce use iso_c_binding, only: c_int, c_ptr, c_loc integer, target, intent(in) :: sendbuf integer, target, intent(out) :: recvbuf @@ -937,7 +949,7 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co c_comm = handle_mpi_comm_f2c(comm) - c_dtype = c_mpi_datatype_f2c(datatype) + c_dtype = handle_mpi_datatype_f2c(datatype) c_op = handle_mpi_op_f2c(op) ! Pass pointer to the actual data diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index f2fb558..283f1a1 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -53,11 +53,17 @@ function c_mpi_sum() bind(C, name="get_c_MPI_SUM") integer(kind=MPI_HANDLE_KIND) :: c_mpi_sum end function c_mpi_sum - function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") - use iso_c_binding, only: c_int, c_ptr - integer(c_int), value :: datatype - integer(kind=MPI_HANDLE_KIND) :: c_mpi_datatype_f2c - end function c_mpi_datatype_f2c + function c_mpi_float() bind(C, name="get_c_MPI_FLOAT") + integer(kind=MPI_HANDLE_KIND) :: c_mpi_float + end function + + function c_mpi_double() bind(C, name="get_c_MPI_DOUBLE") + integer(kind=MPI_HANDLE_KIND) :: c_mpi_double + end function + + function c_mpi_int() bind(C, name="get_c_MPI_INT") + integer(kind=MPI_HANDLE_KIND) :: c_mpi_int + end function function c_mpi_op_f2c(op_f) bind(C, name="MPI_Op_f2c") use iso_c_binding, only: c_ptr, c_int diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index e61ea52..4e11a30 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -1,36 +1,15 @@ #include -#include -#include - -#define MPI_STATUS_SIZE 5 - -#define FORTRAN_MPI_COMM_WORLD -1000 -#define FORTRAN_MPI_INFO_NULL -2000 -#define FORTRAN_MPI_IN_PLACE -1002 - -#define FORTRAN_MPI_SUM -2300 - -#define FORTRAN_MPI_INTEGER -10002 -#define FORTRAN_MPI_DOUBLE_PRECISION -10004 -#define FORTRAN_MPI_REAL4 -10013 -#define FORTRAN_MPI_REAL8 -10014 - - -MPI_Datatype get_c_datatype_from_fortran(int datatype) { - MPI_Datatype c_datatype; - switch (datatype) { - case FORTRAN_MPI_REAL4: - c_datatype = MPI_FLOAT; - break; - case FORTRAN_MPI_REAL8: - case FORTRAN_MPI_DOUBLE_PRECISION: - c_datatype = MPI_DOUBLE; - break; - case FORTRAN_MPI_INTEGER: - c_datatype = MPI_INT; - break; - } - return c_datatype; + +MPI_Datatype get_c_MPI_DOUBLE() { + return MPI_DOUBLE; +} + +MPI_Datatype get_c_MPI_FLOAT() { + return MPI_FLOAT; +} + +MPI_Datatype get_c_MPI_INT() { + return MPI_INT; } MPI_Info get_c_MPI_INFO_NULL() {