Skip to content

Remove C-Wrapper for MPI_Init_Thread #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -127,21 +127,32 @@ subroutine MPI_Init_proc(ierr)
end subroutine

subroutine MPI_Init_thread_proc(required, provided, ierr)
use mpi_c_bindings, only : c_mpi_init_thread
use iso_c_binding, only: c_int
use mpi_c_bindings, only: c_mpi_init_thread
use iso_c_binding, only: c_int, c_ptr, c_null_ptr
integer, intent(in) :: required
integer, intent(out) :: provided
integer, optional, intent(out) :: ierr
integer :: local_ierr
integer(c_int) :: local_ierr
integer(c_int) :: argc = 0
type(c_ptr) :: argv = c_null_ptr
integer(c_int) :: c_required
integer(c_int) :: c_provided

! Map Fortran MPI_THREAD_FUNNELED to C MPI_THREAD_FUNNELED if needed
c_required = int(required, c_int)

! Call C MPI_Init_thread directly
local_ierr = c_mpi_init_thread(argc, argv, required, provided)

! Copy output values back to Fortran
provided = int(c_provided)

if (present(ierr)) then
call c_mpi_init_thread(required, provided, ierr)
else
call c_mpi_init_thread(required, provided, local_ierr)
if (local_ierr /= 0) then
print *, "MPI_Init_thread failed with error code: ", local_ierr
end if
ierr = int(local_ierr)
else if (local_ierr /= 0) then
print *, "MPI_Init_thread failed with error code: ", local_ierr
end if
end subroutine
end subroutine MPI_Init_thread_proc

subroutine MPI_Finalize_proc(ierr)
use mpi_c_bindings, only: c_mpi_finalize
Expand Down
12 changes: 7 additions & 5 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ subroutine c_mpi_init(ierr) bind(C, name="mpi_init_wrapper")
integer(c_int), intent(out) :: ierr
end subroutine c_mpi_init

subroutine c_mpi_init_thread(required, provided, ierr) bind(C, name="mpi_init_thread_wrapper")
use iso_c_binding, only: c_int
integer(c_int), intent(in) :: required
function c_mpi_init_thread(argc, argv, required, provided) bind(C, name="MPI_Init_thread")
use iso_c_binding, only: c_int, c_ptr
integer(c_int) :: argc
type(c_ptr) :: argv
integer(c_int), value :: required
integer(c_int), intent(out) :: provided
integer(c_int), intent(out) :: ierr
end subroutine c_mpi_init_thread
integer(c_int) :: c_mpi_init_thread
end function c_mpi_init_thread

integer(c_int) function c_mpi_finalize() bind(C, name="MPI_Finalize")
use iso_c_binding, only : c_int
Expand Down
8 changes: 0 additions & 8 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,6 @@ void mpi_init_wrapper(int *ierr) {
*ierr = MPI_Init(&argc, &argv);
}

void mpi_init_thread_wrapper(int *required, int *provided, int *ierr) {
int argc = 0;
char **argv = NULL;

int thread_support = (*required == 1) ? MPI_THREAD_FUNNELED : *required;
*ierr = MPI_Init_thread(&argc, &argv, thread_support, provided);
}

void mpi_comm_size_wrapper(int *comm_f, int *size, int *ierr) {
MPI_Comm comm = MPI_Comm_f2c(*comm_f);
*ierr = MPI_Comm_size(comm, size);
Expand Down
18 changes: 18 additions & 0 deletions tests/init_2.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
program init_2
use mpi, only: MPI_Init_thread, MPI_Finalize, MPI_THREAD_FUNNELED
implicit none

integer :: provided, ierr

! Initialize MPI with thread support
call MPI_Init_thread(MPI_THREAD_FUNNELED, provided, ierr)

if (ierr /= 0) then
print *, "Error initializing MPI with threads"
error stop
end if
print *, "Running MPI with thread support"

! Finalize MPI
call MPI_Finalize(ierr)
end program init_2