diff --git a/src/mpi.f90 b/src/mpi.f90 index 48d1e13..25e16d6 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -612,12 +612,25 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) end subroutine subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ierror) - use mpi_c_bindings, only: c_mpi_cart_shift + use iso_c_binding, only: c_int, c_ptr + use mpi_c_bindings, only: c_mpi_cart_shift, c_mpi_comm_f2c integer, intent(in) :: comm integer, intent(in) :: direction, disp integer, intent(out) :: rank_source, rank_dest integer, optional, intent(out) :: ierror - call c_mpi_cart_shift(comm, direction, disp, rank_source, rank_dest, ierror) + type(c_ptr) :: c_comm + integer(c_int) :: local_ierr + + c_comm = c_mpi_comm_f2c(comm) + local_ierr = c_mpi_cart_shift(c_comm, direction, disp, rank_source, rank_dest) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Cart_shift failed with error code: ", local_ierr + end if + end if end subroutine subroutine MPI_Dims_create_proc(nnodes, ndims, dims, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index ff36210..d4b9617 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -228,11 +228,13 @@ function c_mpi_cart_coords(comm, rank, maxdims, coords) bind(C, name="MPI_Cart_c integer(c_int) :: c_mpi_cart_coords end function - subroutine c_mpi_cart_shift(comm, direction, disp, rank_source, rank_dest, ierror) bind(C, name="mpi_cart_shift_wrapper") - use iso_c_binding, only: c_int - integer(c_int), intent(in) :: comm, direction, disp - integer(c_int), intent(out) :: rank_source, rank_dest, ierror - end subroutine + function c_mpi_cart_shift(comm, direction, disp, rank_source, rank_dest) bind(C, name="MPI_Cart_shift") + use iso_c_binding, only: c_int, c_ptr + type(c_ptr), value :: comm + integer(c_int), value :: direction, disp + integer(c_int), intent(out) :: rank_source, rank_dest + integer(c_int) :: c_mpi_cart_shift + end function function c_mpi_dims_create(nnodes, ndims, dims) bind(C, name="MPI_Dims_create") use iso_c_binding, only: c_int, c_ptr diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 785192d..5fe6429 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -130,9 +130,3 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f, free(array_of_requests); free(array_of_statuses); } - -void mpi_cart_shift_wrapper(int * comm_f, int * dir, int * disp, int * rank_source, int * rank_dest, int * ierror) -{ - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - *ierror = MPI_Cart_shift(comm, *dir, *disp, rank_source, rank_dest); -} diff --git a/tests/cart_shift_1.f90 b/tests/cart_shift_1.f90 new file mode 100644 index 0000000..aea3c17 --- /dev/null +++ b/tests/cart_shift_1.f90 @@ -0,0 +1,50 @@ +program cart_shift_1 + use mpi + implicit none + + integer :: ierr, rank, size + integer :: comm_cart + integer :: dims(2) + logical :: periods(2) + integer :: coords(2) + integer :: left, right, up, down + + call MPI_INIT(ierr) + call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr) + call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr) + + dims(1) = 0 + dims(2) = 0 + periods(1) = .true. + periods(2) = .false. + + call MPI_DIMS_CREATE(size, 2, dims, ierr) + call MPI_CART_CREATE(MPI_COMM_WORLD, 2, dims, periods, & + .true., comm_cart, ierr) + + call MPI_CART_COORDS(comm_cart, rank, 2, coords, ierr) + + call MPI_CART_SHIFT(comm_cart, 0, 1, left, right, ierr) + + call MPI_CART_SHIFT(comm_cart, 1, 1, up, down, ierr) + + ! TODO: enable these checks in the future + ! if (coords(2) == 0 .and. up /= MPI_PROC_NULL) then + ! print *, 'Error: Rank ', rank, ' at (', coords(1), ',', coords(2), & + ! ') should have up = MPI_PROC_NULL but got ', up + ! error stop + ! end if + + ! if (coords(2) == dims(2)-1 .and. down /= MPI_PROC_NULL) then + ! print *, 'Error: Rank ', rank, ' at (', coords(1), ',', coords(2), & + ! ') should have down = MPI_PROC_NULL but got ', down + ! error stop + ! end if + + print *, & + 'Rank ', rank, ' at coords (', coords(1), ',', coords(2), & + ') neighbors: left=', left, ' right=', right, & + ' up=', up, ' down=', down + + call MPI_FINALIZE(ierr) +end program cart_shift_1