diff --git a/src/mpi.f90 b/src/mpi.f90 index 0f02606..f4c6a62 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -399,12 +399,34 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr end subroutine subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) - use mpi_c_bindings, only: c_mpi_allreduce_scalar - real(8), intent(in) :: sendbuf - real(8), intent(out) :: recvbuf + use iso_c_binding, only: c_int, c_ptr, c_loc + use mpi_c_bindings, only: c_mpi_allreduce_scalar, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_comm_f2c, c_mpi_in_place_f2c + real(8), intent(in), target :: sendbuf + real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - call c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + integer(c_int) :: local_ierr + + if (sendbuf == MPI_IN_PLACE) then + sendbuf_ptr = c_mpi_in_place_f2c(sendbuf) + else + sendbuf_ptr = c_loc(sendbuf) + end if + recvbuf_ptr = c_loc(recvbuf) + c_datatype = c_mpi_datatype_f2c(datatype) + c_op = c_mpi_op_f2c(op) + c_comm = c_mpi_comm_f2c(comm) + + local_ierr = c_mpi_allreduce_scalar(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Allreduce_scalar failed with error code: ", local_ierr + end if + end if end subroutine subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index ff36210..e13303a 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -38,6 +38,12 @@ function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") type(c_ptr) :: c_mpi_info_f2c end function c_mpi_info_f2c + function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran") + use iso_c_binding, only: c_double, c_ptr + real(c_double), value :: in_place_f + type(c_ptr) :: c_mpi_in_place_f2c + end function c_mpi_in_place_f2c + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified @@ -120,14 +126,15 @@ function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, n integer(c_int) :: c_mpi_irecv end function - subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) & - bind(C, name="mpi_allreduce_wrapper_real") - use iso_c_binding, only: c_int, c_double - real(c_double), intent(in) :: sendbuf - real(c_double), intent(out) :: recvbuf - integer(c_int), intent(in) :: count, datatype, op, comm - integer(c_int), intent(out), optional :: ierror - end subroutine + function c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm) & + bind(C, name="MPI_Allreduce") + use iso_c_binding, only: c_int, c_double, c_ptr + type(c_ptr), value :: sendbuf + type(c_ptr), value :: recvbuf + integer(c_int), value :: count + type(c_ptr), value :: datatype, op, comm + integer(c_int) :: c_mpi_allreduce_scalar + end function subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) & bind(C, name="mpi_allreduce_wrapper_real") diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 8ebe295..5137a66 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -57,6 +57,10 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { } } +void* get_c_mpi_inplace_from_fortran(double sendbuf) { + return MPI_IN_PLACE; +} + void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f);