Skip to content

Remove MPI_Allreduce_array_real C-wrapper #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ module mpi
interface MPI_Allreduce
module procedure MPI_Allreduce_scalar
module procedure MPI_Allreduce_1D_recv_proc
module procedure MPI_Allreduce_array_real
module procedure MPI_Allreduce_1D_real_proc
module procedure MPI_Allreduce_array_int
end interface

Expand Down Expand Up @@ -462,15 +462,32 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com
end if
end subroutine MPI_Allreduce_1D_recv_proc

subroutine MPI_Allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_array_real
! Declare both send and recv as arrays:
real(8), dimension(:), intent(in) :: sendbuf
real(8), dimension(:), intent(out) :: recvbuf
subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c
real(8), dimension(:), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
integer, intent(out), optional :: ierror
call c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror)
end subroutine MPI_Allreduce_array_real
type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm
integer(c_int) :: local_ierr

sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_comm = c_mpi_comm_f2c(comm)

local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm)

if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr
end if
end if
end subroutine MPI_Allreduce_1D_real_proc

subroutine MPI_Allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_array_int
Expand Down
10 changes: 1 addition & 9 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,7 @@ function c_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm) &
type(c_ptr), value :: datatype, op, comm
integer(c_int) :: c_mpi_allreduce
end function

subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
use iso_c_binding, only: c_int, c_double
real(c_double), intent(in) :: sendbuf
real(c_double), dimension(*), intent(out) :: recvbuf
integer(c_int), intent(in) :: count, datatype, op, comm
integer(c_int), intent(out), optional :: ierror
end subroutine


subroutine c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
Expand Down
22 changes: 0 additions & 22 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,6 @@ void* get_c_mpi_inplace_from_fortran(double sendbuf) {
return MPI_IN_PLACE;
}

void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count,
int *datatype_f, int *op_f, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

// I'm a little doubtful: as how would it identify that this part
// is supposed to be for MPI_SUM?
MPI_Op op = MPI_Op_f2c(*op_f);
/*
hard-code values here:
1. We've hard-coded op as "MPI_SUM" for now, as in POT3D codebase, it's always
used as MPI_SUM
2. the first argument (i.e. sendbuf) as "MPI_IN_PLACE" for now as it's always
used as such in POT3D codebase
*/
if (*sendbuf == FORTRAN_MPI_IN_PLACE) {
*ierror = MPI_Allreduce(MPI_IN_PLACE , recvbuf, *count, datatype, MPI_SUM, comm);
} else {
*ierror = MPI_Allreduce(sendbuf , recvbuf, *count, datatype, MPI_SUM, comm);
}
}

void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count,
int *datatype_f, int *op_f, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
Expand Down