Skip to content

feat: Implement wrappers for MPI_Gatherv #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ module mpi
module procedure MPI_Allreduce_1D_int_proc
end interface

interface MPI_Gatherv
module procedure MPI_Gatherv_int
module procedure MPI_Gatherv_real
end interface MPI_Gatherv

interface MPI_Wtime
module procedure MPI_Wtime_proc
end interface
Expand Down Expand Up @@ -731,6 +736,88 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s

end subroutine MPI_Recv_StatusIgnore_proc

subroutine MPI_Gatherv_int(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
displs, recvtype, root, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_gatherv, c_mpi_in_place
integer, dimension(:), intent(in), target :: sendbuf
integer, intent(in) :: sendcount
integer, intent(in) :: sendtype
integer, dimension(:), intent(out), target :: recvbuf
integer, dimension(:), intent(in) :: recvcounts
integer, dimension(:), intent(in) :: displs
integer, intent(in) :: recvtype
integer, intent(in) :: root
integer, intent(in) :: comm
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
type(c_ptr) :: c_sendbuf, c_recvbuf
integer(c_int) :: local_ierr

if (sendbuf(1) == MPI_IN_PLACE) then
c_sendbuf = c_MPI_IN_PLACE
else
c_sendbuf = c_loc(sendbuf)
end if

c_recvbuf = c_loc(recvbuf)
c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
c_comm = handle_mpi_comm_f2c(comm)

! Call C MPI_Gatherv
local_ierr = c_mpi_gatherv(c_sendbuf, sendcount, c_sendtype, &
c_recvbuf, recvcounts, displs, c_recvtype, &
root, c_comm)

if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Gatherv failed with error code: ", local_ierr
end if
end subroutine MPI_Gatherv_int

subroutine MPI_Gatherv_real(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
displs, recvtype, root, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_gatherv, c_mpi_in_place
real(8), dimension(:), intent(in), target :: sendbuf
integer, intent(in) :: sendcount
integer, intent(in) :: sendtype
real(8), dimension(:), intent(out), target :: recvbuf
integer, dimension(:), intent(in) :: recvcounts
integer, dimension(:), intent(in) :: displs
integer, intent(in) :: recvtype
integer, intent(in) :: root
integer, intent(in) :: comm
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
type(c_ptr) :: c_sendbuf, c_recvbuf
integer(c_int) :: local_ierr

if (sendbuf(1) == MPI_IN_PLACE) then
c_sendbuf = c_MPI_IN_PLACE
else
c_sendbuf = c_loc(sendbuf)
end if

c_recvbuf = c_loc(recvbuf)
c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
c_comm = handle_mpi_comm_f2c(comm)

! Call C MPI_Gatherv
local_ierr = c_mpi_gatherv(c_sendbuf, sendcount, c_sendtype, &
c_recvbuf, recvcounts, displs, c_recvtype, &
root, c_comm)

if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Gatherv failed with error code: ", local_ierr
end if
end subroutine MPI_Gatherv_real

subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror)
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_waitall, c_mpi_request_f2c, c_mpi_request_c2f, c_mpi_status_c2f, c_mpi_statuses_ignore
Expand Down
15 changes: 15 additions & 0 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,21 @@ function c_mpi_ssend(buf, count, datatype, dest, tag, comm) bind(C, name="MPI_Ss
integer(c_int) :: c_mpi_ssend
end function

function c_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
displs, recvtype, root, comm) bind(C, name="MPI_Gatherv")
use iso_c_binding, only: c_int, c_ptr
type(c_ptr), value :: sendbuf
integer(c_int), value :: sendcount
integer(kind=MPI_HANDLE_KIND), value :: sendtype
type(c_ptr), value :: recvbuf
integer(c_int), dimension(*), intent(in) :: recvcounts
integer(c_int), dimension(*), intent(in) :: displs
integer(kind=MPI_HANDLE_KIND), value :: recvtype
integer(c_int), value :: root
integer(kind=MPI_HANDLE_KIND), value :: comm
integer(c_int) :: c_mpi_gatherv
end function c_mpi_gatherv

function c_mpi_cart_create(comm_old, ndims, dims, periods, reorder, comm_cart) bind(C, name="MPI_Cart_create")
use iso_c_binding, only: c_int, c_ptr
integer(kind=MPI_HANDLE_KIND), value :: comm_old
Expand Down
69 changes: 69 additions & 0 deletions tests/gatherv_1.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
program gatherv_1
use mpi
implicit none
integer :: ierr, rank, size, root
integer, allocatable :: sendbuf(:), recvbuf(:)
integer, allocatable :: recvcounts(:), displs(:)
integer :: sendcount, i, total
logical :: error

! Initialize MPI
call MPI_Init(ierr)
call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
call MPI_Comm_size(MPI_COMM_WORLD, size, ierr)

! Root process
root = 0

! Each process sends 'rank + 1' integers
sendcount = rank + 1
allocate(sendbuf(sendcount))
do i = 1, sendcount
sendbuf(i) = rank * 100 + i ! Unique values per process
end do

! Allocate receive buffers on root
if (rank == root) then
allocate(recvcounts(size))
allocate(displs(size))
total = 0
do i = 1, size
recvcounts(i) = i ! Process i-1 sends i elements
displs(i) = total ! Displacement in recvbuf
total = total + recvcounts(i)
end do
allocate(recvbuf(total))
recvbuf = 0
else
allocate(recvcounts(1), displs(1), recvbuf(1)) ! Dummy allocations for non-root
end if

! Perform gather
call MPI_Gatherv(sendbuf, sendcount, MPI_INTEGER, recvbuf, recvcounts, &
displs, MPI_INTEGER, root, MPI_COMM_WORLD, ierr)

! Verify results on root
error = .false.
if (rank == root) then
do i = 1, size
do sendcount = 1, i
if (recvbuf(displs(i) + sendcount) /= (i-1)*100 + sendcount) then
print *, "Error at rank ", i-1, " index ", sendcount, &
": expected ", (i-1)*100 + sendcount, &
", got ", recvbuf(displs(i) + sendcount)
error = .true.
error stop
end if
end do
end do
if (.not. error) then
print *, "MPI_Gatherv test passed on root"
end if
end if

! Clean up
deallocate(sendbuf, recvbuf, recvcounts, displs)
call MPI_Finalize(ierr)

if (error) stop 1
end program gatherv_1