diff --git a/src/mpi.f90 b/src/mpi.f90 index 5f152bd..3349557 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -537,18 +537,36 @@ subroutine MPI_Dims_create_proc(nnodes, ndims, dims, ierror) end subroutine subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) - use mpi_c_bindings, only: c_mpi_cart_sub + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_cart_sub, c_mpi_comm_f2c, c_mpi_comm_c2f integer, intent(in) :: comm logical, intent(in) :: remain_dims(:) integer, intent(out) :: newcomm integer, optional, intent(out) :: ierror - integer :: remain_dims_i(size(remain_dims)) + integer, target :: remain_dims_i(size(remain_dims)) + type(c_ptr) :: c_comm, c_newcomm + integer :: local_ierr + type(c_ptr) :: remain_dims_i_ptr + + c_comm = c_mpi_comm_f2c(comm) + where (remain_dims) remain_dims_i = 1 elsewhere remain_dims_i = 0 end where - call c_mpi_cart_sub(comm, remain_dims_i, newcomm, ierror) + remain_dims_i_ptr = c_loc(remain_dims_i) + local_ierr = c_mpi_cart_sub(c_comm, remain_dims_i_ptr, c_newcomm) + + newcomm = c_mpi_comm_c2f(c_newcomm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Cart_sub failed with error code: ", local_ierr + end if + end if end subroutine subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, comm, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index bc59ad1..47c8529 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -234,12 +234,13 @@ function c_mpi_dims_create(nnodes, ndims, dims) bind(C, name="MPI_Dims_create") integer(c_int) :: c_mpi_dims_create end function - subroutine c_mpi_cart_sub(comm, remain_dims, newcomm, ierror) bind(C, name ="mpi_cart_sub_wrapper") - use iso_c_binding, only: c_int - integer(c_int), intent(in) :: comm - integer(c_int), intent(in) :: remain_dims(*) - integer(c_int), intent(out) :: newcomm, ierror - end subroutine + function c_mpi_cart_sub(comm, remain_dims, newcomm) bind(C, name ="MPI_Cart_sub") + use iso_c_binding, only: c_int, c_ptr + type(c_ptr), value :: comm + type(c_ptr), value :: remain_dims + type(c_ptr), intent(out) :: newcomm + integer(c_int) :: c_mpi_cart_sub + end function function c_mpi_reduce(sendbuf, recvbuf, count, c_dtype, c_op, root, c_comm) & bind(C, name="MPI_Reduce") diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 713b5b6..ed0baa6 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -188,10 +188,3 @@ void mpi_cart_shift_wrapper(int * comm_f, int * dir, int * disp, int * rank_sour MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Cart_shift(comm, *dir, *disp, rank_source, rank_dest); } - -void mpi_cart_sub_wrapper(int * comm_f, int * rmains_dims, int * newcomm_f, int * ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Comm newcomm = MPI_COMM_NULL; - *ierror = MPI_Cart_sub(comm, rmains_dims, &newcomm); - *newcomm_f = MPI_Comm_c2f(newcomm); -} diff --git a/tests/cart_sub.f90 b/tests/cart_sub.f90 index 4cf4b20..c2e8b09 100644 --- a/tests/cart_sub.f90 +++ b/tests/cart_sub.f90 @@ -1,4 +1,4 @@ -program main +program cart_sub use mpi implicit none @@ -34,7 +34,7 @@ program main call MPI_Init_thread(MPI_THREAD_FUNNELED, tcheck, ierr) if (ierr /= MPI_SUCCESS) then print *, "Error initializing MPI" - stop + error stop end if ! Get rank and size in the global communicator @@ -45,14 +45,14 @@ program main call MPI_Dims_create(size, 2, dims, ierr) if (ierr /= MPI_SUCCESS) then print *, "Error creating dimensions" - stop + error stop end if ! Create a Cartesian communicator call MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, reorder, comm_cart, ierr) if (ierr /= MPI_SUCCESS) then print *, "Error creating Cartesian communicator" - stop + error stop end if ! Get new rank in the Cartesian communicator @@ -68,7 +68,7 @@ program main call MPI_Cart_sub(comm_cart, remain_dims, comm_new, ierr) if (ierr /= MPI_SUCCESS) then print *, "Error creating sub-communicator" - stop + error stop end if ! Get the size of the new communicator @@ -81,7 +81,7 @@ program main call MPI_Finalize(errs) if (errs /= MPI_SUCCESS) then print *, "Error finalizing MPI" - stop + error stop end if -end program main \ No newline at end of file +end program cart_sub