diff --git a/src/mpi.f90 b/src/mpi.f90 index b058d6e..4736454 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -2,7 +2,7 @@ module mpi implicit none integer, parameter :: MPI_THREAD_FUNNELED = 1 ! not sure if this is correct really - integer, parameter :: MPI_INTEGER = 0 + integer, parameter :: MPI_INTEGER = 2 integer, parameter :: MPI_REAL4 = 0 integer, parameter :: MPI_REAL8 = 1 integer, parameter :: MPI_COMM_TYPE_SHARED = 1 @@ -58,7 +58,8 @@ module mpi interface MPI_Allreduce module procedure MPI_Allreduce_scalar module procedure MPI_Allreduce_1d - module procedure MPI_Allreduce_array + module procedure MPI_Allreduce_array_real + module procedure MPI_Allreduce_array_int end interface interface MPI_Wtime @@ -282,15 +283,25 @@ subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) call c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) end subroutine - subroutine MPI_Allreduce_array(sendbuf, recvbuf, count, datatype, op, comm, ierror) - use mpi_c_bindings, only: c_mpi_allreduce_array + subroutine MPI_Allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) + use mpi_c_bindings, only: c_mpi_allreduce_array_real ! Declare both send and recv as arrays: real(8), dimension(:), intent(in) :: sendbuf real(8), dimension(:), intent(out) :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - call c_mpi_allreduce_array(sendbuf, recvbuf, count, datatype, op, comm, ierror) - end subroutine MPI_Allreduce_array + call c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) + end subroutine MPI_Allreduce_array_real + + subroutine MPI_Allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) + use mpi_c_bindings, only: c_mpi_allreduce_array_int + ! Declare both send and recv as arrays: + integer, dimension(:), intent(in) :: sendbuf + integer, dimension(:), intent(out) :: recvbuf + integer, intent(in) :: count, datatype, op, comm + integer, intent(out), optional :: ierror + call c_mpi_allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) + end subroutine MPI_Allreduce_array_int function MPI_Wtime_proc() result(time) use mpi_c_bindings, only: c_mpi_wtime diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 48251f7..92a4188 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -92,7 +92,8 @@ subroutine c_mpi_irecv(buf, count, datatype, source, tag, comm, request, ierror) integer(c_int), optional, intent(out) :: ierror end subroutine - subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) bind(C, name="mpi_allreduce_wrapper") + subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) & + bind(C, name="mpi_allreduce_wrapper_real") use iso_c_binding, only: c_int, c_double real(c_double), intent(in) :: sendbuf real(c_double), intent(out) :: recvbuf @@ -100,7 +101,8 @@ subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, i integer(c_int), intent(out), optional :: ierror end subroutine - subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) bind(C, name="mpi_allreduce_wrapper") + subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) & + bind(C, name="mpi_allreduce_wrapper_real") use iso_c_binding, only: c_int, c_double real(c_double), intent(in) :: sendbuf real(c_double), dimension(*), intent(out) :: recvbuf @@ -108,13 +110,23 @@ subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierro integer(c_int), intent(out), optional :: ierror end subroutine - subroutine c_mpi_allreduce_array(sendbuf, recvbuf, count, datatype, op, comm, ierror) bind(C, name="mpi_allreduce_wrapper") + subroutine c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) & + bind(C, name="mpi_allreduce_wrapper_real") use iso_c_binding, only: c_int, c_double real(c_double), dimension(*), intent(in) :: sendbuf real(c_double), dimension(*), intent(out) :: recvbuf integer(c_int), intent(in) :: count, datatype, op, comm integer(c_int), intent(out), optional :: ierror - end subroutine c_mpi_allreduce_array + end subroutine c_mpi_allreduce_array_real + + subroutine c_mpi_allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) & + bind(C, name="mpi_allreduce_wrapper_int") + use iso_c_binding, only: c_int, c_double + integer(c_int), dimension(*), intent(in) :: sendbuf + integer(c_int), dimension(*), intent(out) :: recvbuf + integer(c_int), intent(in) :: count, datatype, op, comm + integer(c_int), intent(out), optional :: ierror + end subroutine c_mpi_allreduce_array_int function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime") use iso_c_binding, only: c_double diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index c9dc5bb..6fa37ef 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -19,13 +19,13 @@ void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, MPI_Comm comm = MPI_Comm_f2c(*comm_f); MPI_Datatype datatype; switch (*datatype_f) { - case 0: + case 2: datatype = MPI_INT; break; - case 1: + case 0: datatype = MPI_FLOAT; break; - case 2: + case 1: datatype = MPI_DOUBLE; break; default: @@ -59,7 +59,7 @@ void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype MPI_Datatype sendtype, recvtype; switch (*sendtype_f) { - case 0: + case 2: sendtype = MPI_INT; break; default: @@ -68,7 +68,7 @@ void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype } switch (*recvtype_f) { - case 0: + case 2: recvtype = MPI_INT; break; default: @@ -158,7 +158,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f, *request_f = MPI_Request_c2f(request); } -void mpi_allreduce_wrapper(const double *sendbuf, double *recvbuf, int *count, +void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = MPI_Comm_f2c(*comm_f); MPI_Datatype datatype; @@ -191,6 +191,21 @@ void mpi_allreduce_wrapper(const double *sendbuf, double *recvbuf, int *count, } } +void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count, + int *datatype_f, int *op_f, int *comm_f, int *ierror) { + MPI_Comm comm = MPI_Comm_f2c(*comm_f); + MPI_Datatype datatype; + datatype = MPI_INT; + + MPI_Op op = MPI_Op_f2c(*op_f); + + if (*sendbuf == -1) { + *ierror = MPI_Allreduce(MPI_IN_PLACE , recvbuf, *count, datatype, MPI_SUM, comm); + } else { + *ierror = MPI_Allreduce(sendbuf , recvbuf, *count, datatype, MPI_SUM, comm); + } +} + void mpi_barrier_wrapper(int *comm_f, int *ierror) { MPI_Comm comm = MPI_Comm_f2c(*comm_f); *ierror = MPI_Barrier(comm);