diff --git a/src/mpi.f90 b/src/mpi.f90 index 80ed121..75fef4d 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -448,14 +448,36 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror) end subroutine subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror) - use mpi_c_bindings, only: c_mpi_comm_split_type - integer :: comm + use iso_c_binding, only: c_int, c_ptr + use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_info_f2c + integer, intent(in) :: comm integer, intent(in) :: split_type, key integer, intent(in) :: info integer, intent(out) :: newcomm integer, optional, intent(out) :: ierror - call c_mpi_comm_split_type(comm, split_type, key, info, newcomm, ierror) - end subroutine + + integer(c_int) :: local_ierr + type(c_ptr) :: c_comm, c_info, c_new_comm + + ! Convert Fortran communicator and info handles to C pointers. + c_comm = c_mpi_comm_f2c(comm) + c_info = c_mpi_info_f2c(info) + + ! Call the native MPI_Comm_split_type. + local_ierr = c_mpi_comm_split_type(c_comm, split_type, key, c_info, c_new_comm) + + ! Convert the new communicator C handle back to a Fortran integer handle. + newcomm = c_mpi_comm_c2f(c_new_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= 0) then + print *, "MPI_Comm_split_type failed with error code: ", local_ierr + end if + end if + + end subroutine MPI_Comm_split_type_proc subroutine MPI_Recv_proc(buf, count, datatype, source, tag, comm, status, ierror) use mpi_c_bindings, only: c_mpi_recv diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index a332b7d..36bdd37 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -32,6 +32,12 @@ function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran") type(c_ptr) :: c_mpi_op_f2c end function c_mpi_op_f2c + function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: info_f + type(c_ptr) :: c_mpi_info_f2c + end function c_mpi_info_f2c + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified @@ -168,14 +174,15 @@ function c_mpi_comm_rank(comm, rank) bind(C, name="MPI_Comm_rank") integer(c_int) :: c_mpi_comm_rank end function c_mpi_comm_rank - subroutine c_mpi_comm_split_type(comm, split_type, key, info, newcomm, ierror) bind(C, name="mpi_comm_split_type_wrapper") - use iso_c_binding, only: c_int - integer(c_int) :: comm - integer(c_int), intent(in) :: split_type, key - integer(c_int), intent(in) :: info - integer(c_int), intent(out) :: newcomm - integer(c_int), optional, intent(out) :: ierror - end subroutine + function c_mpi_comm_split_type(c_comm, split_type, key, c_info, new_comm) bind(C, name="MPI_Comm_split_type") + use iso_c_binding, only: c_ptr, c_int + type(c_ptr), value :: c_comm + integer(c_int), value :: split_type + integer(c_int), value :: key + type(c_ptr), value :: c_info + type(c_ptr) :: new_comm + integer(c_int) :: c_mpi_comm_split_type + end function c_mpi_comm_split_type subroutine c_mpi_recv(buf, count, datatype, source, tag, comm, status, ierror) bind(C, name="mpi_recv_wrapper") use iso_c_binding, only: c_int, c_double diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 525157d..ee8bfec 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -103,15 +103,6 @@ void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count, } } -void mpi_comm_split_type_wrapper(int *comm_f, int *split_type, int *key, - int *info_f, int *newcomm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Info info = get_c_info_from_fortran(*info_f); - MPI_Comm newcomm; - *ierror = MPI_Comm_split_type( comm, *split_type, *key , info, &newcomm); - *newcomm_f = MPI_Comm_c2f(newcomm); -} - void mpi_recv_wrapper(double *buf, int *count, int *datatype_f, int *source, int *tag, int *comm_f, int *status_f, int *ierror) { MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);