Skip to content

Feat: Add Wrappers for MPI_AllGatherv #124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ module mpi
module procedure MPI_Waitall_proc
end interface

interface MPI_Allgatherv
module procedure MPI_Allgatherv_int
module procedure MPI_Allgatherv_real
end interface MPI_Allgatherv

interface MPI_Ssend
module procedure MPI_Ssend_proc
end interface
Expand Down Expand Up @@ -773,6 +778,89 @@ subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror)

end subroutine MPI_Waitall_proc

subroutine MPI_Allgatherv_int(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
displs, recvtype, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allgatherv, c_mpi_in_place
integer, dimension(:), intent(in), target :: sendbuf
integer, intent(in) :: sendcount
integer, intent(in) :: sendtype
integer, dimension(:), intent(out), target :: recvbuf
integer, dimension(:), intent(in) :: recvcounts
integer, dimension(:), intent(in) :: displs
integer, intent(in) :: recvtype
integer, intent(in) :: comm
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
type(c_ptr) :: c_sendbuf, c_recvbuf
integer(c_int) :: local_ierr

! Handle sendbuf (support MPI_IN_PLACE)
if (sendbuf(1) == MPI_IN_PLACE) then
c_sendbuf = c_MPI_IN_PLACE
else
c_sendbuf = c_loc(sendbuf)
end if
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should've a utility function for this now, which should do this work for us now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we can add utility function like we have for comm_world and other things. Will make a subsequent PR for that.

c_recvbuf = c_loc(recvbuf)
c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
c_comm = handle_mpi_comm_f2c(comm)

! Call C MPI_Allgatherv
local_ierr = c_mpi_allgatherv(c_sendbuf, sendcount, c_sendtype, &
c_recvbuf, recvcounts, displs, c_recvtype, &
c_comm)

! Handle error
if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Allgatherv failed with error code: ", local_ierr
end if

end subroutine MPI_Allgatherv_int

subroutine MPI_Allgatherv_real(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
displs, recvtype, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allgatherv, c_mpi_in_place
real(8), dimension(:), intent(in), target :: sendbuf
integer, intent(in) :: sendcount
integer, intent(in) :: sendtype
real(8), dimension(:), intent(out), target :: recvbuf
integer, dimension(:), intent(in) :: recvcounts
integer, dimension(:), intent(in) :: displs
integer, intent(in) :: recvtype
integer, intent(in) :: comm
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype, c_comm
type(c_ptr) :: c_sendbuf, c_recvbuf
integer(c_int) :: local_ierr

if (sendbuf(1) == MPI_IN_PLACE) then
c_sendbuf = c_MPI_IN_PLACE
else
c_sendbuf = c_loc(sendbuf)
end if

c_recvbuf = c_loc(recvbuf)
c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
c_comm = handle_mpi_comm_f2c(comm)

! Call C MPI_Allgatherv
local_ierr = c_mpi_allgatherv(c_sendbuf, sendcount, c_sendtype, &
c_recvbuf, recvcounts, displs, c_recvtype, &
c_comm)

if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Allgatherv failed with error code: ", local_ierr
end if

end subroutine MPI_Allgatherv_real

subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror)
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_ssend
Expand Down
14 changes: 14 additions & 0 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,20 @@ function c_mpi_cart_create(comm_old, ndims, dims, periods, reorder, comm_cart) b
integer(c_int) :: c_mpi_cart_create
end function

function c_mpi_allgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, &
displs, recvtype, comm) bind(C, name="MPI_Allgatherv")
use iso_c_binding, only: c_int, c_ptr
type(c_ptr), value :: sendbuf
integer(c_int), value :: sendcount
integer(kind=MPI_HANDLE_KIND), value :: sendtype
type(c_ptr), value :: recvbuf
integer(c_int), dimension(*) :: recvcounts
integer(c_int), dimension(*) :: displs
integer(kind=MPI_HANDLE_KIND), value :: recvtype
integer(kind=MPI_HANDLE_KIND), value :: comm
integer(c_int) :: c_mpi_allgatherv
end function c_mpi_allgatherv

function c_mpi_cart_coords(comm, rank, maxdims, coords) bind(C, name="MPI_Cart_coords")
use iso_c_binding, only: c_int, c_ptr
integer(kind=MPI_HANDLE_KIND), value :: comm
Expand Down
59 changes: 59 additions & 0 deletions tests/allgatherv_1.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
program allgatherv_1
use mpi
implicit none
integer :: ierr, rank, size
integer, allocatable :: sendbuf(:), recvbuf(:)
integer, allocatable :: recvcounts(:), displs(:)
integer :: sendcount, i, total
logical :: error

! Initialize MPI
call MPI_Init(ierr)
call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr)
call MPI_Comm_size(MPI_COMM_WORLD, size, ierr)

! Each process sends 'rank + 1' integers
sendcount = rank + 1
allocate(sendbuf(sendcount))
do i = 1, sendcount
sendbuf(i) = rank * 100 + i ! Unique values per process
end do

! All processes allocate receive buffers
allocate(recvcounts(size))
allocate(displs(size))
total = 0
do i = 1, size
recvcounts(i) = i ! Process i-1 sends i elements
displs(i) = total ! Displacement in recvbuf
total = total + recvcounts(i)
end do
allocate(recvbuf(total))
recvbuf = 0

! Perform allgather
call MPI_Allgatherv(sendbuf, sendcount, MPI_INTEGER, recvbuf, recvcounts, &
displs, MPI_INTEGER, MPI_COMM_WORLD, ierr)

! Verify results on all processes
error = .false.
do i = 1, size
do sendcount = 1, i
if (recvbuf(displs(i) + sendcount) /= (i-1)*100 + sendcount) then
print *, "Rank ", rank, ": Error at source rank ", i-1, &
" index ", sendcount, ": expected ", (i-1)*100 + sendcount, &
", got ", recvbuf(displs(i) + sendcount)
error = .true.
end if
end do
end do
if (.not. error) then
print *, "MPI_Allgatherv test passed on rank ", rank
end if

! Clean up
deallocate(sendbuf, recvbuf, recvcounts, displs)
call MPI_Finalize(ierr)

if (error) stop 1
end program allgatherv_1