Skip to content

remove C MPI wrapper function for MPI_Comm_size #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,21 @@ subroutine MPI_Finalize_proc(ierr)
end if
end subroutine

subroutine MPI_Comm_size_proc(comm, size, ierr)
use mpi_c_bindings, only: c_mpi_comm_size
subroutine MPI_Comm_size_proc(comm, size, ierror)
use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c
use iso_c_binding, only: c_int, c_ptr
integer, intent(in) :: comm
integer, intent(out) :: size
integer, optional, intent(out) :: ierr
integer, optional, intent(out) :: ierror
integer :: local_ierr
if (present(ierr)) then
call c_mpi_comm_size(comm, size, ierr)
type(c_ptr) :: c_comm

c_comm = c_mpi_comm_f2c(comm)
local_ierr = c_mpi_comm_size(c_comm, size)
if (present(ierror)) then
ierror = local_ierr
else
call c_mpi_comm_size(comm, size, local_ierr)
if (local_ierr /= 0) then
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Comm_size failed with error code: ", local_ierr
end if
end if
Expand Down
21 changes: 11 additions & 10 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ module mpi_c_bindings
implicit none

interface
function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: comm_f
type(c_ptr) :: c_mpi_comm_f2c ! MPI_Comm as pointer
end function c_mpi_comm_f2c

function c_mpi_init(argc, argv) bind(C, name="MPI_Init")
use iso_c_binding, only : c_int, c_ptr
!> TODO: is the intent need to be explicitly specified
Expand All @@ -25,12 +31,12 @@ integer(c_int) function c_mpi_finalize() bind(C, name="MPI_Finalize")
use iso_c_binding, only : c_int
end function c_mpi_finalize

subroutine c_mpi_comm_size(comm, size, ierr) bind(C, name="mpi_comm_size_wrapper")
use iso_c_binding, only: c_int
integer(c_int), intent(in) :: comm
function c_mpi_comm_size(comm, size) bind(C, name="MPI_Comm_size")
use iso_c_binding, only: c_int, c_ptr
type(c_ptr), value :: comm
integer(c_int), intent(out) :: size
integer(c_int), intent(out) :: ierr
end subroutine c_mpi_comm_size
integer(c_int) :: c_mpi_comm_size
end function c_mpi_comm_size

subroutine c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_int_wrapper")
use iso_c_binding, only: c_int
Expand Down Expand Up @@ -132,11 +138,6 @@ function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime")
use iso_c_binding, only: c_double
real(c_double) :: time
end function
function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: comm_f
type(c_ptr) :: c_mpi_comm_f2c ! MPI_Comm as pointer
end function c_mpi_comm_f2c

function c_mpi_barrier(comm) bind(C, name="MPI_Barrier")
use iso_c_binding, only: c_ptr, c_int
Expand Down
5 changes: 0 additions & 5 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) {
}
}

void mpi_comm_size_wrapper(int *comm_f, int *size, int *ierr) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
*ierr = MPI_Comm_size(comm, size);
}

void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype;
Expand Down