Skip to content

Remove C-wrapper for MPI_BCAST #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -200,24 +200,56 @@ subroutine MPI_Comm_size_proc(comm, size, ierror)
end subroutine

subroutine MPI_Bcast_int(buffer, count, datatype, root, comm, ierror)
use mpi_c_bindings, only: c_mpi_bcast_int
integer :: buffer
use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c
use iso_c_binding, only: c_int, c_ptr, c_loc
integer, target :: buffer
integer, intent(in) :: count, root
integer, intent(in) :: datatype
integer, intent(in) :: comm
integer, optional, intent(out) :: ierror
call c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror)
end subroutine
type(c_ptr) :: c_comm, c_datatype
integer :: local_ierr
type(c_ptr) :: buffer_ptr

c_comm = c_mpi_comm_f2c(comm)
c_datatype = c_mpi_datatype_f2c(datatype)
buffer_ptr = c_loc(buffer)
local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm)

if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Bcast_int failed with error code: ", local_ierr
end if
end if
end subroutine MPI_Bcast_int

subroutine MPI_Bcast_real(buffer, count, datatype, root, comm, ierror)
use mpi_c_bindings, only: c_mpi_bcast_real
real(8), dimension(:, :) :: buffer
use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c
use iso_c_binding, only: c_int, c_ptr, c_loc
real(8), dimension(:, :), target :: buffer
integer, intent(in) :: count, root
integer, intent(in) :: datatype
integer, intent(in) :: comm
integer, optional, intent(out) :: ierror
call c_mpi_bcast_real(buffer, count, datatype, root, comm, ierror)
end subroutine
type(c_ptr) :: c_comm, c_datatype
integer :: local_ierr
type(c_ptr) :: buffer_ptr

c_comm = c_mpi_comm_f2c(comm)
c_datatype = c_mpi_datatype_f2c(datatype)
buffer_ptr = c_loc(buffer)
local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm)

if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Bcast_real failed with error code: ", local_ierr
end if
end if
end subroutine MPI_Bcast_real

subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror)
use mpi_c_bindings, only: c_mpi_allgather_int
Expand Down
32 changes: 15 additions & 17 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ function c_mpi_comm_c2f(comm_c) bind(C, name="MPI_Comm_c2f")
integer :: c_mpi_comm_c2f
end function

function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: datatype
type(c_ptr) :: c_mpi_datatype_f2c
end function c_mpi_datatype_f2c

function c_mpi_init(argc, argv) bind(C, name="MPI_Init")
use iso_c_binding, only : c_int, c_ptr
!> TODO: is the intent need to be explicitly specified
Expand Down Expand Up @@ -44,23 +50,15 @@ function c_mpi_comm_size(comm, size) bind(C, name="MPI_Comm_size")
integer(c_int) :: c_mpi_comm_size
end function c_mpi_comm_size

subroutine c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_int_wrapper")
use iso_c_binding, only: c_int
integer(c_int) :: buffer
integer(c_int), intent(in) :: count, root
integer(c_int), intent(in) :: datatype
integer(c_int), intent(in) :: comm
integer(c_int), optional, intent(out) :: ierror
end subroutine

subroutine c_mpi_bcast_real(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_real_wrapper")
use iso_c_binding, only : c_int, c_double
real(c_double), dimension(*) :: buffer
integer(c_int), intent(in) :: count, root
integer(c_int), intent(in) :: datatype
integer(c_int), intent(in) :: comm
integer(c_int), optional, intent(out) :: ierror
end subroutine
function c_mpi_bcast(buffer, count, datatype, root, comm) bind(C, name="MPI_Bcast")
use iso_c_binding, only : c_ptr, c_int
type(c_ptr), value :: buffer
integer(c_int), value :: count
type(c_ptr), value :: datatype
integer(c_int), value :: root
type(c_ptr), value :: comm
integer(c_int) :: c_mpi_bcast
end function c_mpi_bcast

subroutine c_mpi_allgather_int(sendbuf, sendcount, sendtype, recvbuf, &
recvcount, recvtype, comm, ierror) bind(C, name="mpi_allgather_int_wrapper")
Expand Down
14 changes: 1 addition & 13 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,6 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) {
}
}

void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);
*ierror = MPI_Bcast(buffer, *count, datatype, *root, comm);
}

void mpi_bcast_real_wrapper(double *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);
*ierror = MPI_Bcast(buffer, *count, datatype, *root, comm);
}

void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype_f,
int *recvbuf, int *recvcount, int *recvtype_f,
int *comm_f, int *ierror) {
Expand Down Expand Up @@ -200,7 +188,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f,

void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest,
int *tag, int *comm_f, int *ierror) {
MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f);
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
*ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm);
Expand Down
31 changes: 31 additions & 0 deletions tests/bcast_1.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
program test_bcast
use mpi
implicit none

integer :: ierror, rank, size, comm, root, n, i, d
comm = MPI_COMM_WORLD
d = 0
! Initialize MPI
call MPI_Init(ierror)

! Get our rank and the total number of processes
call MPI_Comm_rank(comm, rank, ierror)
call MPI_Comm_size(comm, size, ierror)

root = 0
if (rank == root) then
d = 1
end if

! Broadcast the integer from root=0 to all processes
call MPI_Bcast(d, 1, MPI_INTEGER, root, comm, ierror)
if (ierror /= MPI_SUCCESS) then
print *, "Error in MPI_Bcast:", ierror
end if

! Print result on each rank
print *, "Rank=", rank, " received integer=", d

! Finalize MPI
call MPI_Finalize(ierror)
end program test_bcast