Skip to content

simplify C wrapper of MPI_Op_f2c #107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ module mpi

contains

integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op)
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum
integer, intent(in) :: op_f
if (op_f == MPI_SUM) then
c_op = c_mpi_sum()
else
c_op = c_mpi_op_f2c(op_f)
end if
end function

integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm)
use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c, c_mpi_comm_world
integer, intent(in) :: comm_f
Expand Down Expand Up @@ -439,7 +449,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr

subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c
real(8), intent(in), target :: sendbuf
real(8), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -455,7 +465,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier
end if
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)

Expand All @@ -472,7 +482,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier

subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c
real(8), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -489,7 +499,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com

recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)

Expand All @@ -506,7 +516,7 @@ end subroutine MPI_Allreduce_1D_recv_proc

subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c
real(8), dimension(:), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -518,7 +528,7 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com
sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)

Expand All @@ -535,7 +545,7 @@ end subroutine MPI_Allreduce_1D_real_proc

subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, &
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, &
c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world
integer, dimension(:), intent(in), target :: sendbuf
integer, dimension(:), intent(out), target :: recvbuf
Expand All @@ -548,7 +558,7 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm
sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)

Expand Down Expand Up @@ -914,7 +924,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror)
end subroutine

subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, comm, ierror)
use mpi_c_bindings, only: c_mpi_reduce, c_mpi_datatype_f2c, c_mpi_op_f2c
use mpi_c_bindings, only: c_mpi_reduce, c_mpi_datatype_f2c
use iso_c_binding, only: c_int, c_ptr, c_loc
integer, target, intent(in) :: sendbuf
integer, target, intent(out) :: recvbuf
Expand All @@ -928,7 +938,7 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co
c_comm = handle_mpi_comm_f2c(comm)

c_dtype = c_mpi_datatype_f2c(datatype)
c_op = c_mpi_op_f2c(op)
c_op = handle_mpi_op_f2c(op)

! Pass pointer to the actual data
c_sendbuf = c_loc(sendbuf)
Expand Down
6 changes: 5 additions & 1 deletion src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,17 @@ function c_mpi_info_null() bind(C, name="get_c_MPI_INFO_NULL")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_null
end function c_mpi_info_null

function c_mpi_sum() bind(C, name="get_c_MPI_SUM")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_sum
end function c_mpi_sum

function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: datatype
integer(kind=MPI_HANDLE_KIND) :: c_mpi_datatype_f2c
end function c_mpi_datatype_f2c

function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran")
function c_mpi_op_f2c(op_f) bind(C, name="MPI_Op_f2c")
use iso_c_binding, only: c_ptr, c_int
integer(c_int), value :: op_f
integer(kind=MPI_HANDLE_KIND) :: c_mpi_op_f2c
Expand Down
8 changes: 2 additions & 6 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,8 @@ MPI_Info get_c_MPI_INFO_NULL() {
return MPI_INFO_NULL;
}

MPI_Op get_c_op_from_fortran(int op) {
if (op == FORTRAN_MPI_SUM) {
return MPI_SUM;
} else {
return MPI_Op_f2c(op);
}
MPI_Op get_c_MPI_SUM() {
return MPI_SUM;
}

MPI_Comm get_c_MPI_COMM_WORLD() {
Expand Down