From 5efe0b7ff328e1f8093733fd47e76e3b5ff47e1e Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Fri, 11 Apr 2025 20:42:05 +0530 Subject: [PATCH] simplify C wrapper --- src/mpi.f90 | 30 ++++++++++++++++++++---------- src/mpi_c_bindings.f90 | 6 +++++- src/mpi_wrapper.c | 8 ++------ 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 7a6bed4..58edb1c 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -127,6 +127,16 @@ module mpi contains + integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op) + use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum + integer, intent(in) :: op_f + if (op_f == MPI_SUM) then + c_op = c_mpi_sum() + else + c_op = c_mpi_op_f2c(op_f) + end if + end function + integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm) use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c, c_mpi_comm_world integer, intent(in) :: comm_f @@ -439,7 +449,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -455,7 +465,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier end if recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) - c_op = c_mpi_op_f2c(op) + c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -472,7 +482,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c real(8), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -489,7 +499,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) - c_op = c_mpi_op_f2c(op) + c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -506,7 +516,7 @@ end subroutine MPI_Allreduce_1D_recv_proc subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, c_mpi_in_place_f2c + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c real(8), dimension(:), intent(in), target :: sendbuf real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm @@ -518,7 +528,7 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com sendbuf_ptr = c_loc(sendbuf) recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) - c_op = c_mpi_op_f2c(op) + c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -535,7 +545,7 @@ end subroutine MPI_Allreduce_1D_real_proc subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc - use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_op_f2c, & + use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, & c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world integer, dimension(:), intent(in), target :: sendbuf integer, dimension(:), intent(out), target :: recvbuf @@ -548,7 +558,7 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm sendbuf_ptr = c_loc(sendbuf) recvbuf_ptr = c_loc(recvbuf) c_datatype = c_mpi_datatype_f2c(datatype) - c_op = c_mpi_op_f2c(op) + c_op = handle_mpi_op_f2c(op) c_comm = handle_mpi_comm_f2c(comm) @@ -914,7 +924,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) end subroutine subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, comm, ierror) - use mpi_c_bindings, only: c_mpi_reduce, c_mpi_datatype_f2c, c_mpi_op_f2c + use mpi_c_bindings, only: c_mpi_reduce, c_mpi_datatype_f2c use iso_c_binding, only: c_int, c_ptr, c_loc integer, target, intent(in) :: sendbuf integer, target, intent(out) :: recvbuf @@ -928,7 +938,7 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co c_comm = handle_mpi_comm_f2c(comm) c_dtype = c_mpi_datatype_f2c(datatype) - c_op = c_mpi_op_f2c(op) + c_op = handle_mpi_op_f2c(op) ! Pass pointer to the actual data c_sendbuf = c_loc(sendbuf) diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 2c0fa2f..f2fb558 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -49,13 +49,17 @@ function c_mpi_info_null() bind(C, name="get_c_MPI_INFO_NULL") integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_null end function c_mpi_info_null + function c_mpi_sum() bind(C, name="get_c_MPI_SUM") + integer(kind=MPI_HANDLE_KIND) :: c_mpi_sum + end function c_mpi_sum + function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: datatype integer(kind=MPI_HANDLE_KIND) :: c_mpi_datatype_f2c end function c_mpi_datatype_f2c - function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran") + function c_mpi_op_f2c(op_f) bind(C, name="MPI_Op_f2c") use iso_c_binding, only: c_ptr, c_int integer(c_int), value :: op_f integer(kind=MPI_HANDLE_KIND) :: c_mpi_op_f2c diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index d37b6cb..e61ea52 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -37,12 +37,8 @@ MPI_Info get_c_MPI_INFO_NULL() { return MPI_INFO_NULL; } -MPI_Op get_c_op_from_fortran(int op) { - if (op == FORTRAN_MPI_SUM) { - return MPI_SUM; - } else { - return MPI_Op_f2c(op); - } +MPI_Op get_c_MPI_SUM() { + return MPI_SUM; } MPI_Comm get_c_MPI_COMM_WORLD() {