diff --git a/src/mpi.f90 b/src/mpi.f90 index d498c91..38ec15c 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -113,18 +113,21 @@ module mpi subroutine MPI_Init_proc(ierr) use mpi_c_bindings, only: c_mpi_init - use iso_c_binding, only : c_int + use iso_c_binding, only : c_int, c_ptr, c_null_ptr integer, optional, intent(out) :: ierr - integer :: local_ierr + integer(c_int) :: local_ierr + integer(c_int) :: argc + type(c_ptr) :: argv = c_null_ptr + argc = 0 + ! Call C MPI_Init directly with argc=0, argv=NULL + local_ierr = c_mpi_init(argc, argv) + if (present(ierr)) then - call c_mpi_init(ierr) - else - call c_mpi_init(local_ierr) - if (local_ierr /= 0) then - print *, "MPI_Init failed with error code: ", local_ierr - end if + ierr = int(local_ierr) + else if (local_ierr /= 0) then + print *, "MPI_Init failed with error code: ", local_ierr end if - end subroutine + end subroutine MPI_Init_proc subroutine MPI_Init_thread_proc(required, provided, ierr) use mpi_c_bindings, only : c_mpi_init_thread diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index c7ca2cf..cebeedd 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -2,10 +2,15 @@ module mpi_c_bindings implicit none interface - subroutine c_mpi_init(ierr) bind(C, name="mpi_init_wrapper") - use iso_c_binding, only: c_int - integer(c_int), intent(out) :: ierr - end subroutine c_mpi_init + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") + use iso_c_binding, only : c_int, c_ptr + !> TODO: is the intent need to be explicitly specified + !> as 'intent(inout)'? Though, currently LFortran + !> errors with this + integer(c_int) :: argc + type(c_ptr) :: argv + integer(c_int) :: c_mpi_init + end function c_mpi_init subroutine c_mpi_init_thread(required, provided, ierr) bind(C, name="mpi_init_thread_wrapper") use iso_c_binding, only: c_int diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 4ab4ee8..a62d9e0 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -4,11 +4,11 @@ #define MPI_STATUS_SIZE 5 -void mpi_init_wrapper(int *ierr) { - int argc = 0; - char **argv = NULL; - *ierr = MPI_Init(&argc, &argv); -} +// void mpi_init_wrapper(int *ierr) { +// int argc = 0; +// char **argv = NULL; +// *ierr = MPI_Init(&argc, &argv); +// } void mpi_init_thread_wrapper(int *required, int *provided, int *ierr) { int argc = 0; diff --git a/tests/init_1.f90 b/tests/init_1.f90 new file mode 100644 index 0000000..e415161 --- /dev/null +++ b/tests/init_1.f90 @@ -0,0 +1,36 @@ +program test_mpi_init_rank_size + use mpi + implicit none + integer :: ierr, rank, size + + ! Initialize MPI + call MPI_Init(ierr) + if (ierr /= MPI_SUCCESS) then + print *, "MPI_Init failed with error code: ", ierr + stop 1 + end if + + ! Get rank and size + call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr) + if (ierr /= MPI_SUCCESS) then + print *, "MPI_Comm_rank failed with error code: ", ierr + call MPI_Finalize(ierr) + stop 1 + end if + + call MPI_Comm_size(MPI_COMM_WORLD, size, ierr) + if (ierr /= MPI_SUCCESS) then + print *, "MPI_Comm_size failed with error code: ", ierr + call MPI_Finalize(ierr) + stop 1 + end if + + print *, "Hello from rank ", rank, " of ", size + + ! Finalize MPI + call MPI_Finalize(ierr) + if (ierr /= MPI_SUCCESS) then + print *, "MPI_Finalize failed with error code: ", ierr + stop 1 + end if +end program test_mpi_init_rank_size \ No newline at end of file