diff --git a/src/mpi.f90 b/src/mpi.f90 index 86c54ef..7a6bed4 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -137,6 +137,16 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm end if end function handle_mpi_comm_f2c + integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info) + use mpi_c_bindings, only: c_mpi_info_f2c, c_mpi_info_null + integer, intent(in) :: info_f + if (info_f == MPI_INFO_NULL) then + c_info = c_mpi_info_null() + else + c_info = c_mpi_info_f2c(info_f) + end if + end function handle_mpi_info_f2c + subroutine MPI_Init_proc(ierr) use mpi_c_bindings, only: c_mpi_init use iso_c_binding, only : c_int, c_ptr, c_null_ptr @@ -439,7 +449,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then - sendbuf_ptr = c_mpi_in_place_f2c(sendbuf) + sendbuf_ptr = c_mpi_in_place_f2c() else sendbuf_ptr = c_loc(sendbuf) end if @@ -472,7 +482,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then - sendbuf_ptr = c_mpi_in_place_f2c(sendbuf) + sendbuf_ptr = c_mpi_in_place_f2c() else sendbuf_ptr = c_loc(sendbuf) end if @@ -603,7 +613,7 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror) subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror) use iso_c_binding, only: c_int, c_ptr - use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_c2f, c_mpi_info_f2c + use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_c2f integer, intent(in) :: comm integer, intent(in) :: split_type, key integer, intent(in) :: info @@ -614,21 +624,21 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror integer(kind=MPI_HANDLE_KIND) :: c_comm, c_info, c_new_comm c_comm = handle_mpi_comm_f2c(comm) - c_info = c_mpi_info_f2c(info) + c_info = handle_mpi_info_f2c(info) + ! Call the native MPI_Comm_split_type. local_ierr = c_mpi_comm_split_type(c_comm, split_type, key, c_info, c_new_comm) - + ! Convert the new communicator C handle back to a Fortran integer handle. newcomm = c_mpi_comm_c2f(c_new_comm) if (present(ierror)) then - ierror = local_ierr + ierror = local_ierr else - if (local_ierr /= 0) then - print *, "MPI_Comm_split_type failed with error code: ", local_ierr - end if + if (local_ierr /= 0) then + print *, "MPI_Comm_split_type failed with error code: ", local_ierr + end if end if - end subroutine MPI_Comm_split_type_proc subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index bdcfd8b..2c0fa2f 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -39,12 +39,16 @@ function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") integer(c_int) :: f_status(*) ! assumed-size array integer(c_int) :: c_mpi_status_c2f end function c_mpi_status_c2f - - function c_mpi_comm_world() bind(C, name="get_c_mpi_comm_world") + + function c_mpi_comm_world() bind(C, name="get_c_MPI_COMM_WORLD") use iso_c_binding, only: c_ptr integer(kind=MPI_HANDLE_KIND) :: c_mpi_comm_world end function c_mpi_comm_world + function c_mpi_info_null() bind(C, name="get_c_MPI_INFO_NULL") + integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_null + end function c_mpi_info_null + function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: datatype @@ -56,8 +60,8 @@ function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran") integer(c_int), value :: op_f integer(kind=MPI_HANDLE_KIND) :: c_mpi_op_f2c end function c_mpi_op_f2c - - function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") + + function c_mpi_info_f2c(info_f) bind(C, name="MPI_Info_f2c") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: info_f integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_f2c @@ -68,9 +72,8 @@ function c_mpi_statuses_ignore() bind(C, name="get_c_MPI_STATUSES_IGNORE") type(c_ptr) :: c_mpi_statuses_ignore end function c_mpi_statuses_ignore - function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran") - use iso_c_binding, only: c_double, c_ptr - real(c_double), value :: in_place_f + function c_mpi_in_place_f2c() bind(C,name="get_c_MPI_IN_PLACE") + use iso_c_binding, only: c_ptr type(c_ptr) :: c_mpi_in_place_f2c end function c_mpi_in_place_f2c diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 53985cc..d37b6cb 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -33,12 +33,8 @@ MPI_Datatype get_c_datatype_from_fortran(int datatype) { return c_datatype; } -MPI_Info get_c_info_from_fortran(int info) { - if (info == FORTRAN_MPI_INFO_NULL) { - return MPI_INFO_NULL; - } else { - return MPI_Info_f2c(info); - } +MPI_Info get_c_MPI_INFO_NULL() { + return MPI_INFO_NULL; } MPI_Op get_c_op_from_fortran(int op) { @@ -49,14 +45,14 @@ MPI_Op get_c_op_from_fortran(int op) { } } -MPI_Comm get_c_mpi_comm_world() { +MPI_Comm get_c_MPI_COMM_WORLD() { return MPI_COMM_WORLD; } -void* get_c_mpi_inplace_from_fortran(double sendbuf) { +void* get_c_MPI_IN_PLACE() { return MPI_IN_PLACE; } -MPI_Status* get_c_MPI_STATUSES_IGNORE(){ +MPI_Status* get_c_MPI_STATUSES_IGNORE() { return MPI_STATUSES_IGNORE; -} \ No newline at end of file +}