Skip to content

remove C wrapper of MPI_Allreduce_scalar #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,34 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr
end subroutine

subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_scalar
real(8), intent(in) :: sendbuf
real(8), intent(out) :: recvbuf
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c
real(8), intent(in), target :: sendbuf
real(8), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
integer, intent(out), optional :: ierror
call c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror)
type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm
integer(c_int) :: local_ierr

if (sendbuf == MPI_IN_PLACE) then
sendbuf_ptr = c_mpi_in_place_f2c(sendbuf)
else
sendbuf_ptr = c_loc(sendbuf)
end if
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_comm = c_mpi_comm_f2c(comm)

local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm)

if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Allreduce_scalar failed with error code: ", local_ierr
end if
end if
end subroutine

subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror)
Expand Down
23 changes: 15 additions & 8 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran")
type(c_ptr) :: c_mpi_info_f2c
end function c_mpi_info_f2c

function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran")
use iso_c_binding, only: c_double, c_ptr
real(c_double), value :: in_place_f
type(c_ptr) :: c_mpi_in_place_f2c
end function c_mpi_in_place_f2c

function c_mpi_init(argc, argv) bind(C, name="MPI_Init")
use iso_c_binding, only : c_int, c_ptr
!> TODO: is the intent need to be explicitly specified
Expand Down Expand Up @@ -120,14 +126,15 @@ function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, n
integer(c_int) :: c_mpi_irecv
end function

subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
use iso_c_binding, only: c_int, c_double
real(c_double), intent(in) :: sendbuf
real(c_double), intent(out) :: recvbuf
integer(c_int), intent(in) :: count, datatype, op, comm
integer(c_int), intent(out), optional :: ierror
end subroutine
function c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm) &
bind(C, name="MPI_Allreduce")
use iso_c_binding, only: c_int, c_double, c_ptr
type(c_ptr), value :: sendbuf
type(c_ptr), value :: recvbuf
integer(c_int), value :: count
type(c_ptr), value :: datatype, op, comm
integer(c_int) :: c_mpi_allreduce_scalar
end function

subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
Expand Down
4 changes: 4 additions & 0 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) {
}
}

void* get_c_mpi_inplace_from_fortran(double sendbuf) {
return MPI_IN_PLACE;
}

void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count,
int *datatype_f, int *op_f, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
Expand Down