From 2f594e7fda068c39721671c0b8e6fcb3ee3677ef Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Mon, 31 Mar 2025 15:15:06 +0530 Subject: [PATCH 1/4] Test: Add MPI_Bcast --- tests/bcast_1.f90 | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/bcast_1.f90 diff --git a/tests/bcast_1.f90 b/tests/bcast_1.f90 new file mode 100644 index 0000000..e34e245 --- /dev/null +++ b/tests/bcast_1.f90 @@ -0,0 +1,39 @@ +program test_bcast + use mpi + implicit none + + integer :: ierror, rank, size, comm, root, n, i + real(8), allocatable :: arr(:,:) + comm = MPI_COMM_WORLD + + ! Initialize MPI + call MPI_Init(ierror) + + ! Get our rank and the total number of processes + call MPI_Comm_rank(comm, rank, ierror) + call MPI_Comm_size(comm, size, ierror) + + ! Decide how large an array to broadcast + n = 5 + allocate(arr(n,n)) + + ! Only the root (rank=0) initializes the data + root = 0 + if (rank == root) then + arr = reshape([(i, i=1,n*n)], shape=[n, n]) ! Initialize as a 2D array + else + arr = -999 ! fill with dummy values to see if Bcast overwrites them + end if + + ! Broadcast the array from root=0 to all processes + call MPI_Bcast(arr, n, MPI_REAL8, root, comm, ierror) + if (ierror /= MPI_SUCCESS) then + print *, "Error in MPI_Bcast:", ierror + end if + + ! Print result on each rank + print *, "Rank=", rank, " received arr=", arr + + ! Finalize MPI + call MPI_Finalize(ierror) +end program test_bcast \ No newline at end of file From 53ff8a03143a24321f6a5c7ca28c767f50e0edea Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Mon, 31 Mar 2025 15:15:20 +0530 Subject: [PATCH 2/4] Use unified c_mpi_bcast for all datatypes - This is due to the fact that we would transfer the type(C_PTR) which will corresponds to the void* in the MPI_Bcast args -Also create C-binding wrapper for get_c_datatype_from_fortran --- src/mpi_c_bindings.f90 | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index df29b87..2db4fdb 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -14,6 +14,12 @@ function c_mpi_comm_c2f(comm_c) bind(C, name="MPI_Comm_c2f") integer :: c_mpi_comm_c2f end function + function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: datatype + type(c_ptr) :: c_mpi_datatype_f2c + end function c_mpi_datatype_f2c + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified @@ -44,23 +50,15 @@ function c_mpi_comm_size(comm, size) bind(C, name="MPI_Comm_size") integer(c_int) :: c_mpi_comm_size end function c_mpi_comm_size - subroutine c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_int_wrapper") - use iso_c_binding, only: c_int - integer(c_int) :: buffer - integer(c_int), intent(in) :: count, root - integer(c_int), intent(in) :: datatype - integer(c_int), intent(in) :: comm - integer(c_int), optional, intent(out) :: ierror - end subroutine - - subroutine c_mpi_bcast_real(buffer, count, datatype, root, comm, ierror) bind(C, name="mpi_bcast_real_wrapper") - use iso_c_binding, only : c_int, c_double - real(c_double), dimension(*) :: buffer - integer(c_int), intent(in) :: count, root - integer(c_int), intent(in) :: datatype - integer(c_int), intent(in) :: comm - integer(c_int), optional, intent(out) :: ierror - end subroutine + function c_mpi_bcast(buffer, count, datatype, root, comm) bind(C, name="MPI_Bcast") + use iso_c_binding, only : c_ptr, c_int + type(c_ptr), value :: buffer + integer(c_int), value :: count + type(c_ptr), value :: datatype + integer(c_int), value :: root + type(c_ptr), value :: comm + integer(c_int) :: c_mpi_bcast + end function c_mpi_bcast subroutine c_mpi_allgather_int(sendbuf, sendcount, sendtype, recvbuf, & recvcount, recvtype, comm, ierror) bind(C, name="mpi_allgather_int_wrapper") From df811e8a00a5f045dfe22e90f21efb672555b245 Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Mon, 31 Mar 2025 15:15:53 +0530 Subject: [PATCH 3/4] Remove C-wrapper for MPI_BCAST --- src/mpi.f90 | 48 +++++++++++++++++++++++++++++++++++++++-------- src/mpi_wrapper.c | 14 +------------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 88505b8..1a13c6a 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -200,24 +200,56 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) end subroutine subroutine MPI_Bcast_int(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast_int - integer :: buffer + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c + use iso_c_binding, only: c_int, c_ptr, c_loc + integer, target :: buffer integer, intent(in) :: count, root integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - call c_mpi_bcast_int(buffer, count, datatype, root, comm, ierror) - end subroutine + type(c_ptr) :: c_comm, c_datatype + integer :: local_ierr + type(c_ptr) :: buffer_ptr + + c_comm = c_mpi_comm_f2c(comm) + c_datatype = c_mpi_datatype_f2c(datatype) + buffer_ptr = c_loc(buffer) + local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Bcast_int failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Bcast_int subroutine MPI_Bcast_real(buffer, count, datatype, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_bcast_real - real(8), dimension(:, :) :: buffer + use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c + use iso_c_binding, only: c_int, c_ptr, c_loc + real(8), dimension(:, :), target :: buffer integer, intent(in) :: count, root integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - call c_mpi_bcast_real(buffer, count, datatype, root, comm, ierror) - end subroutine + type(c_ptr) :: c_comm, c_datatype + integer :: local_ierr + type(c_ptr) :: buffer_ptr + + c_comm = c_mpi_comm_f2c(comm) + c_datatype = c_mpi_datatype_f2c(datatype) + buffer_ptr = c_loc(buffer) + local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Bcast_real failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Bcast_real subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror) use mpi_c_bindings, only: c_mpi_allgather_int diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index eb78d02..bfb1bc5 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -55,18 +55,6 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { } } -void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); -} - -void mpi_bcast_real_wrapper(double *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { - MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); - *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); -} - void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype_f, int *recvbuf, int *recvcount, int *recvtype_f, int *comm_f, int *ierror) { @@ -200,7 +188,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f, void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest, int *tag, int *comm_f, int *ierror) { - MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f); + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm); From 3772493f94e658c827e8dccb86766a252898935a Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Mon, 31 Mar 2025 15:36:58 +0530 Subject: [PATCH 4/4] Use integer for bcast --- tests/bcast_1.f90 | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/bcast_1.f90 b/tests/bcast_1.f90 index e34e245..f711d9d 100644 --- a/tests/bcast_1.f90 +++ b/tests/bcast_1.f90 @@ -2,10 +2,9 @@ program test_bcast use mpi implicit none - integer :: ierror, rank, size, comm, root, n, i - real(8), allocatable :: arr(:,:) + integer :: ierror, rank, size, comm, root, n, i, d comm = MPI_COMM_WORLD - + d = 0 ! Initialize MPI call MPI_Init(ierror) @@ -13,26 +12,19 @@ program test_bcast call MPI_Comm_rank(comm, rank, ierror) call MPI_Comm_size(comm, size, ierror) - ! Decide how large an array to broadcast - n = 5 - allocate(arr(n,n)) - - ! Only the root (rank=0) initializes the data root = 0 if (rank == root) then - arr = reshape([(i, i=1,n*n)], shape=[n, n]) ! Initialize as a 2D array - else - arr = -999 ! fill with dummy values to see if Bcast overwrites them + d = 1 end if - ! Broadcast the array from root=0 to all processes - call MPI_Bcast(arr, n, MPI_REAL8, root, comm, ierror) + ! Broadcast the integer from root=0 to all processes + call MPI_Bcast(d, 1, MPI_INTEGER, root, comm, ierror) if (ierror /= MPI_SUCCESS) then print *, "Error in MPI_Bcast:", ierror end if ! Print result on each rank - print *, "Rank=", rank, " received arr=", arr + print *, "Rank=", rank, " received integer=", d ! Finalize MPI call MPI_Finalize(ierror)