diff --git a/src/mpi.f90 b/src/mpi.f90 index 88505b8..1a13c6a 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -200,24 +200,56 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) end subroutine subroutine MPI_Bcast_int(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast_int - integer :: buffer + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c + use iso_c_binding, only: c_int, c_ptr, c_loc + integer, target :: buffer integer, intent(in) :: count, root integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - call c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror) - end subroutine + type(c_ptr) :: c_comm, c_datatype + integer :: local_ierr + type(c_ptr) :: buffer_ptr + + c_comm = c_mpi_comm_f2c(comm) + c_datatype = c_mpi_datatype_f2c(datatype) + buffer_ptr = c_loc(buffer) + local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Bcast_int failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Bcast_int subroutine MPI_Bcast_real(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast_real - real(8), dimension(:, :) :: buffer + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c + use iso_c_binding, only: c_int, c_ptr, c_loc + real(8), dimension(:, :), target :: buffer integer, intent(in) :: count, root integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - call c_mpi_bcast_real(buffer, count, datatype, root, comm, ierror) - end subroutine + type(c_ptr) :: c_comm, c_datatype + integer :: local_ierr + type(c_ptr) :: buffer_ptr + + c_comm = c_mpi_comm_f2c(comm) + c_datatype = c_mpi_datatype_f2c(datatype) + buffer_ptr = c_loc(buffer) + local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Bcast_real failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Bcast_real subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use mpi_c_bindings, only: c_mpi_allgather_int diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index df29b87..2db4fdb 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -14,6 +14,12 @@ function c_mpi_comm_c2f(comm_c) bind(C, name="MPI_Comm_c2f") integer :: c_mpi_comm_c2f end function + function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: datatype + type(c_ptr) :: c_mpi_datatype_f2c + end function c_mpi_datatype_f2c + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified @@ -44,23 +50,15 @@ function c_mpi_comm_size(comm, size) bind(C, name="MPI_Comm_size") integer(c_int) :: c_mpi_comm_size end function c_mpi_comm_size - subroutine c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_int_wrapper") - use iso_c_binding, only: c_int - integer(c_int) :: buffer - integer(c_int), intent(in) :: count, root - integer(c_int), intent(in) :: datatype - integer(c_int), intent(in) :: comm - integer(c_int), optional, intent(out) :: ierror - end subroutine - - subroutine c_mpi_bcast_real(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_real_wrapper") - use iso_c_binding, only : c_int, c_double - real(c_double), dimension(*) :: buffer - integer(c_int), intent(in) :: count, root - integer(c_int), intent(in) :: datatype - integer(c_int), intent(in) :: comm - integer(c_int), optional, intent(out) :: ierror - end subroutine + function c_mpi_bcast(buffer, count, datatype, root, comm) bind(C, name="MPI_Bcast") + use iso_c_binding, only : c_ptr, c_int + type(c_ptr), value :: buffer + integer(c_int), value :: count + type(c_ptr), value :: datatype + integer(c_int), value :: root + type(c_ptr), value :: comm + integer(c_int) :: c_mpi_bcast + end function c_mpi_bcast subroutine c_mpi_allgather_int(sendbuf, sendcount, sendtype, recvbuf, & recvcount, recvtype, comm, ierror) bind(C, name="mpi_allgather_int_wrapper") diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index eb78d02..bfb1bc5 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -55,18 +55,6 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { } } -void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); -} - -void mpi_bcast_real_wrapper(double *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); -} - void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype_f, int *recvbuf, int *recvcount, int *recvtype_f, int *comm_f, int *ierror) { @@ -200,7 +188,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f, void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest, int *tag, int *comm_f, int *ierror) { - MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f); + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm); diff --git a/tests/bcast_1.f90 b/tests/bcast_1.f90 new file mode 100644 index 0000000..f711d9d --- /dev/null +++ b/tests/bcast_1.f90 @@ -0,0 +1,31 @@ +program test_bcast + use mpi + implicit none + + integer :: ierror, rank, size, comm, root, n, i, d + comm = MPI_COMM_WORLD + d = 0 + ! Initialize MPI + call MPI_Init(ierror) + + ! Get our rank and the total number of processes + call MPI_Comm_rank(comm, rank, ierror) + call MPI_Comm_size(comm, size, ierror) + + root = 0 + if (rank == root) then + d = 1 + end if + + ! Broadcast the integer from root=0 to all processes + call MPI_Bcast(d, 1, MPI_INTEGER, root, comm, ierror) + if (ierror /= MPI_SUCCESS) then + print *, "Error in MPI_Bcast:", ierror + end if + + ! Print result on each rank + print *, "Rank=", rank, " received integer=", d + + ! Finalize MPI + call MPI_Finalize(ierror) +end program test_bcast \ No newline at end of file