diff --git a/src/mpi.f90 b/src/mpi.f90 index 5996808..5dee153 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -131,7 +131,7 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op) use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum integer, intent(in) :: op_f if (op_f == MPI_SUM) then - c_op = c_mpi_sum() + c_op = c_mpi_sum else c_op = c_mpi_op_f2c(op_f) end if @@ -141,7 +141,7 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c, c_mpi_comm_world integer, intent(in) :: comm_f if (comm_f == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() + c_comm = c_mpi_comm_world else c_comm = c_mpi_comm_f2c(comm_f) end if @@ -151,7 +151,7 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info use mpi_c_bindings, only: c_mpi_info_f2c, c_mpi_info_null integer, intent(in) :: info_f if (info_f == MPI_INFO_NULL) then - c_info = c_mpi_info_null() + c_info = c_mpi_info_null else c_info = c_mpi_info_f2c(info_f) end if @@ -161,11 +161,11 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) resul use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int integer, intent(in) :: datatype_f if (datatype_f == MPI_REAL4) then - c_datatype = c_mpi_float() + c_datatype = c_mpi_float else if (datatype_f == MPI_REAL8 .OR. datatype_f == MPI_DOUBLE_PRECISION) then - c_datatype = c_mpi_double() + c_datatype = c_mpi_double else if (datatype_f == MPI_INTEGER) then - c_datatype = c_mpi_int() + c_datatype = c_mpi_int end if end function @@ -280,7 +280,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) end subroutine MPI_Bcast_int_scalar subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c use iso_c_binding, only: c_int, c_ptr, c_loc real(8), dimension(:, :), target :: buffer integer, intent(in) :: count, root @@ -461,7 +461,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -471,7 +471,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then - sendbuf_ptr = c_mpi_in_place_f2c() + sendbuf_ptr = c_mpi_in_place else sendbuf_ptr = c_loc(sendbuf) end if @@ -494,7 +494,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place real(8), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -504,7 +504,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then - sendbuf_ptr = c_mpi_in_place_f2c() + sendbuf_ptr = c_mpi_in_place else sendbuf_ptr = c_loc(sendbuf) end if @@ -528,7 +528,7 @@ end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -557,8 +557,7 @@ end subroutine MPI_Allreduce_1D_real_proc subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, & - c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_comm_f2c integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -754,7 +753,7 @@ subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror) integer(kind=MPI_HANDLE_KIND), dimension(count) :: c_requests type(c_ptr) :: MPI_STATUSES_IGNORE_from_c - MPI_STATUSES_IGNORE_from_c = c_mpi_statuses_ignore() + MPI_STATUSES_IGNORE_from_c = c_mpi_statuses_ignore ! Convert Fortran requests to C requests. do i = 1, count @@ -859,7 +858,7 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_cart_shift, c_mpi_comm_f2c, c_mpi_comm_world + use mpi_c_bindings, only: c_mpi_cart_shift, c_mpi_comm_f2c integer, intent(in) :: comm integer, intent(in) :: direction, disp integer, intent(out) :: rank_source, rank_dest @@ -867,11 +866,7 @@ subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ie integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - if (comm == MPI_COMM_WORLD) then - c_comm = c_mpi_comm_world() - else - c_comm = c_mpi_comm_f2c(comm) - end if + c_comm = handle_mpi_comm_f2c(comm) local_ierr = c_mpi_cart_shift(c_comm, direction, disp, rank_source, rank_dest) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 283f1a1..eddcc06 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -1,4 +1,5 @@ module mpi_c_bindings + use iso_c_binding, only: c_ptr implicit none #ifdef OPEN_MPI @@ -7,8 +8,17 @@ module mpi_c_bindings #define MPI_HANDLE_KIND 4 #endif + type(c_ptr), bind(C, name="c_MPI_STATUSES_IGNORE") :: c_mpi_statuses_ignore + type(c_ptr), bind(C, name="c_MPI_IN_PLACE") :: c_mpi_in_place + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INFO_NULL") :: c_mpi_info_null + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_DOUBLE") :: c_mpi_double + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_FLOAT") :: c_mpi_float + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum + interface - + function c_mpi_comm_f2c(comm_f) bind(C, name="MPI_Comm_f2c") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: comm_f @@ -40,31 +50,6 @@ function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") integer(c_int) :: c_mpi_status_c2f end function c_mpi_status_c2f - function c_mpi_comm_world() bind(C, name="get_c_MPI_COMM_WORLD") - use iso_c_binding, only: c_ptr - integer(kind=MPI_HANDLE_KIND) :: c_mpi_comm_world - end function c_mpi_comm_world - - function c_mpi_info_null() bind(C, name="get_c_MPI_INFO_NULL") - integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_null - end function c_mpi_info_null - - function c_mpi_sum() bind(C, name="get_c_MPI_SUM") - integer(kind=MPI_HANDLE_KIND) :: c_mpi_sum - end function c_mpi_sum - - function c_mpi_float() bind(C, name="get_c_MPI_FLOAT") - integer(kind=MPI_HANDLE_KIND) :: c_mpi_float - end function - - function c_mpi_double() bind(C, name="get_c_MPI_DOUBLE") - integer(kind=MPI_HANDLE_KIND) :: c_mpi_double - end function - - function c_mpi_int() bind(C, name="get_c_MPI_INT") - integer(kind=MPI_HANDLE_KIND) :: c_mpi_int - end function - function c_mpi_op_f2c(op_f) bind(C, name="MPI_Op_f2c") use iso_c_binding, only: c_ptr, c_int integer(c_int), value :: op_f @@ -77,16 +62,6 @@ function c_mpi_info_f2c(info_f) bind(C, name="MPI_Info_f2c") integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_f2c end function c_mpi_info_f2c - function c_mpi_statuses_ignore() bind(C, name="get_c_MPI_STATUSES_IGNORE") - use iso_c_binding, only: c_ptr - type(c_ptr) :: c_mpi_statuses_ignore - end function c_mpi_statuses_ignore - - function c_mpi_in_place_f2c() bind(C,name="get_c_MPI_IN_PLACE") - use iso_c_binding, only: c_ptr - type(c_ptr) :: c_mpi_in_place_f2c - end function c_mpi_in_place_f2c - function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 4e11a30..66502f1 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -1,33 +1,17 @@ #include -MPI_Datatype get_c_MPI_DOUBLE() { - return MPI_DOUBLE; -} +MPI_Status* c_MPI_STATUSES_IGNORE = MPI_STATUSES_IGNORE; -MPI_Datatype get_c_MPI_FLOAT() { - return MPI_FLOAT; -} +MPI_Info c_MPI_INFO_NULL = MPI_INFO_NULL; -MPI_Datatype get_c_MPI_INT() { - return MPI_INT; -} +MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD; -MPI_Info get_c_MPI_INFO_NULL() { - return MPI_INFO_NULL; -} +MPI_Datatype c_MPI_DOUBLE = MPI_DOUBLE; -MPI_Op get_c_MPI_SUM() { - return MPI_SUM; -} +MPI_Datatype c_MPI_FLOAT = MPI_FLOAT; -MPI_Comm get_c_MPI_COMM_WORLD() { - return MPI_COMM_WORLD; -} +MPI_Datatype c_MPI_INT = MPI_INT; -void* get_c_MPI_IN_PLACE() { - return MPI_IN_PLACE; -} +void* c_MPI_IN_PLACE = MPI_IN_PLACE; -MPI_Status* get_c_MPI_STATUSES_IGNORE() { - return MPI_STATUSES_IGNORE; -} +MPI_Op c_MPI_SUM = MPI_SUM;