Skip to content

Call MPI_AllReduce instead of C-wrapper for the 1D-recvbuf #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 2, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ module mpi

interface MPI_Allreduce
module procedure MPI_Allreduce_scalar
module procedure MPI_Allreduce_1d
module procedure MPI_Allreduce_1D_recv_proc
module procedure MPI_Allreduce_array_real
module procedure MPI_Allreduce_array_int
end interface
Expand Down Expand Up @@ -431,14 +431,36 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier
end if
end subroutine

subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_1d
real(8), intent(in) :: sendbuf
real(8), dimension(:), intent(out) :: recvbuf
subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c
real(8), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
integer, intent(out), optional :: ierror
call c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror)
end subroutine
type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm
integer(c_int) :: local_ierr

if (sendbuf == MPI_IN_PLACE) then
sendbuf_ptr = c_mpi_in_place_f2c(sendbuf)
else
sendbuf_ptr = c_loc(sendbuf)
end if
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_comm = c_mpi_comm_f2c(comm)

local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to avoid ambiguity, can we rename this to c_mpi_allreduce? (and similarly the call in MPI_Allreduce_scalar as well)


if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Allreduce_scalar failed with error code: ", local_ierr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we can rename this to MPI_Allreduce_1D_recv_proc maybe?

end if
end if
end subroutine MPI_Allreduce_1D_recv_proc

subroutine MPI_Allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_array_real
Expand Down