Skip to content

Handle Integer type arrays for MPI_Allreduce #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module mpi
implicit none
integer, parameter :: MPI_THREAD_FUNNELED = 1
! not sure if this is correct really
integer, parameter :: MPI_INTEGER = 0
integer, parameter :: MPI_INTEGER = 2
integer, parameter :: MPI_REAL4 = 0
integer, parameter :: MPI_REAL8 = 1
integer, parameter :: MPI_COMM_TYPE_SHARED = 1
Expand Down Expand Up @@ -58,7 +58,8 @@ module mpi
interface MPI_Allreduce
module procedure MPI_Allreduce_scalar
module procedure MPI_Allreduce_1d
module procedure MPI_Allreduce_array
module procedure MPI_Allreduce_array_real
module procedure MPI_Allreduce_array_int
end interface

interface MPI_Wtime
Expand Down Expand Up @@ -282,15 +283,25 @@ subroutine MPI_Allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror)
call c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror)
end subroutine

subroutine MPI_Allreduce_array(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_array
subroutine MPI_Allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_array_real
! Declare both send and recv as arrays:
real(8), dimension(:), intent(in) :: sendbuf
real(8), dimension(:), intent(out) :: recvbuf
integer, intent(in) :: count, datatype, op, comm
integer, intent(out), optional :: ierror
call c_mpi_allreduce_array(sendbuf, recvbuf, count, datatype, op, comm, ierror)
end subroutine MPI_Allreduce_array
call c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror)
end subroutine MPI_Allreduce_array_real

subroutine MPI_Allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use mpi_c_bindings, only: c_mpi_allreduce_array_int
! Declare both send and recv as arrays:
integer, dimension(:), intent(in) :: sendbuf
integer, dimension(:), intent(out) :: recvbuf
integer, intent(in) :: count, datatype, op, comm
integer, intent(out), optional :: ierror
call c_mpi_allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror)
end subroutine MPI_Allreduce_array_int

function MPI_Wtime_proc() result(time)
use mpi_c_bindings, only: c_mpi_wtime
Expand Down
20 changes: 16 additions & 4 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -92,29 +92,41 @@ subroutine c_mpi_irecv(buf, count, datatype, source, tag, comm, request, ierror)
integer(c_int), optional, intent(out) :: ierror
end subroutine

subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) bind(C, name="mpi_allreduce_wrapper")
subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
use iso_c_binding, only: c_int, c_double
real(c_double), intent(in) :: sendbuf
real(c_double), intent(out) :: recvbuf
integer(c_int), intent(in) :: count, datatype, op, comm
integer(c_int), intent(out), optional :: ierror
end subroutine

subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) bind(C, name="mpi_allreduce_wrapper")
subroutine c_mpi_allreduce_1d(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
use iso_c_binding, only: c_int, c_double
real(c_double), intent(in) :: sendbuf
real(c_double), dimension(*), intent(out) :: recvbuf
integer(c_int), intent(in) :: count, datatype, op, comm
integer(c_int), intent(out), optional :: ierror
end subroutine

subroutine c_mpi_allreduce_array(sendbuf, recvbuf, count, datatype, op, comm, ierror) bind(C, name="mpi_allreduce_wrapper")
subroutine c_mpi_allreduce_array_real(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
use iso_c_binding, only: c_int, c_double
real(c_double), dimension(*), intent(in) :: sendbuf
real(c_double), dimension(*), intent(out) :: recvbuf
integer(c_int), intent(in) :: count, datatype, op, comm
integer(c_int), intent(out), optional :: ierror
end subroutine c_mpi_allreduce_array
end subroutine c_mpi_allreduce_array_real

subroutine c_mpi_allreduce_array_int(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_int")
use iso_c_binding, only: c_int, c_double
integer(c_int), dimension(*), intent(in) :: sendbuf
integer(c_int), dimension(*), intent(out) :: recvbuf
integer(c_int), intent(in) :: count, datatype, op, comm
integer(c_int), intent(out), optional :: ierror
end subroutine c_mpi_allreduce_array_int

function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime")
use iso_c_binding, only: c_double
Expand Down
27 changes: 21 additions & 6 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root,
MPI_Comm comm = MPI_Comm_f2c(*comm_f);
MPI_Datatype datatype;
switch (*datatype_f) {
case 0:
case 2:
datatype = MPI_INT;
break;
case 1:
case 0:
datatype = MPI_FLOAT;
break;
case 2:
case 1:
datatype = MPI_DOUBLE;
break;
default:
Expand Down Expand Up @@ -59,7 +59,7 @@ void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype

MPI_Datatype sendtype, recvtype;
switch (*sendtype_f) {
case 0:
case 2:
sendtype = MPI_INT;
break;
default:
Expand All @@ -68,7 +68,7 @@ void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype
}

switch (*recvtype_f) {
case 0:
case 2:
recvtype = MPI_INT;
break;
default:
Expand Down Expand Up @@ -158,7 +158,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f,
*request_f = MPI_Request_c2f(request);
}

void mpi_allreduce_wrapper(const double *sendbuf, double *recvbuf, int *count,
void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count,
int *datatype_f, int *op_f, int *comm_f, int *ierror) {
MPI_Comm comm = MPI_Comm_f2c(*comm_f);
MPI_Datatype datatype;
Expand Down Expand Up @@ -191,6 +191,21 @@ void mpi_allreduce_wrapper(const double *sendbuf, double *recvbuf, int *count,
}
}

void mpi_allreduce_wrapper_int(const int *sendbuf, int *recvbuf, int *count,
int *datatype_f, int *op_f, int *comm_f, int *ierror) {
MPI_Comm comm = MPI_Comm_f2c(*comm_f);
MPI_Datatype datatype;
datatype = MPI_INT;

MPI_Op op = MPI_Op_f2c(*op_f);

if (*sendbuf == -1) {
*ierror = MPI_Allreduce(MPI_IN_PLACE , recvbuf, *count, datatype, MPI_SUM, comm);
} else {
*ierror = MPI_Allreduce(sendbuf , recvbuf, *count, datatype, MPI_SUM, comm);
}
}

void mpi_barrier_wrapper(int *comm_f, int *ierror) {
MPI_Comm comm = MPI_Comm_f2c(*comm_f);
*ierror = MPI_Barrier(comm);
Expand Down