Skip to content

Remoce C-wrapper for MPI_Comm_split #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -448,14 +448,36 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror)
end subroutine

subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror)
use mpi_c_bindings, only: c_mpi_comm_split_type
integer :: comm
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_info_f2c
integer, intent(in) :: comm
integer, intent(in) :: split_type, key
integer, intent(in) :: info
integer, intent(out) :: newcomm
integer, optional, intent(out) :: ierror
call c_mpi_comm_split_type(comm, split_type, key, info, newcomm, ierror)
end subroutine

integer(c_int) :: local_ierr
type(c_ptr) :: c_comm, c_info, c_new_comm

! Convert Fortran communicator and info handles to C pointers.
c_comm = c_mpi_comm_f2c(comm)
c_info = c_mpi_info_f2c(info)

! Call the native MPI_Comm_split_type.
local_ierr = c_mpi_comm_split_type(c_comm, split_type, key, c_info, c_new_comm)

! Convert the new communicator C handle back to a Fortran integer handle.
newcomm = c_mpi_comm_c2f(c_new_comm)

if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= 0) then
print *, "MPI_Comm_split_type failed with error code: ", local_ierr
end if
end if

end subroutine MPI_Comm_split_type_proc

subroutine MPI_Recv_proc(buf, count, datatype, source, tag, comm, status, ierror)
use mpi_c_bindings, only: c_mpi_recv
Expand Down
23 changes: 15 additions & 8 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran")
type(c_ptr) :: c_mpi_op_f2c
end function c_mpi_op_f2c

function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: info_f
type(c_ptr) :: c_mpi_info_f2c
end function c_mpi_info_f2c

function c_mpi_init(argc, argv) bind(C, name="MPI_Init")
use iso_c_binding, only : c_int, c_ptr
!> TODO: is the intent need to be explicitly specified
Expand Down Expand Up @@ -168,14 +174,15 @@ function c_mpi_comm_rank(comm, rank) bind(C, name="MPI_Comm_rank")
integer(c_int) :: c_mpi_comm_rank
end function c_mpi_comm_rank

subroutine c_mpi_comm_split_type(comm, split_type, key, info, newcomm, ierror) bind(C, name="mpi_comm_split_type_wrapper")
use iso_c_binding, only: c_int
integer(c_int) :: comm
integer(c_int), intent(in) :: split_type, key
integer(c_int), intent(in) :: info
integer(c_int), intent(out) :: newcomm
integer(c_int), optional, intent(out) :: ierror
end subroutine
function c_mpi_comm_split_type(c_comm, split_type, key, c_info, new_comm) bind(C, name="MPI_Comm_split_type")
use iso_c_binding, only: c_ptr, c_int
type(c_ptr), value :: c_comm
integer(c_int), value :: split_type
integer(c_int), value :: key
type(c_ptr), value :: c_info
type(c_ptr) :: new_comm
integer(c_int) :: c_mpi_comm_split_type
end function c_mpi_comm_split_type

subroutine c_mpi_recv(buf, count, datatype, source, tag, comm, status, ierror) bind(C, name="mpi_recv_wrapper")
use iso_c_binding, only: c_int, c_double
Expand Down
9 changes: 0 additions & 9 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,6 @@ void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count,
}
}

void mpi_comm_split_type_wrapper(int *comm_f, int *split_type, int *key,
int *info_f, int *newcomm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Info info = get_c_info_from_fortran(*info_f);
MPI_Comm newcomm;
*ierror = MPI_Comm_split_type( comm, *split_type, *key , info, &newcomm);
*newcomm_f = MPI_Comm_c2f(newcomm);
}

void mpi_recv_wrapper(double *buf, int *count, int *datatype_f, int *source,
int *tag, int *comm_f, int *status_f, int *ierror) {
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);
Expand Down