Skip to content

remove C MPI wrapper function for MPI_Irecv #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,31 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror)
end subroutine

subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierror)
use mpi_c_bindings, only: c_mpi_irecv
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_irecv, c_mpi_comm_f2c, get_c_datatype_from_fortran, c_mpi_request_c2f
real(8), dimension(:,:) :: buf
integer, intent(in) :: count, source, tag
integer, intent(in) :: datatype
integer, intent(in) :: comm
integer, intent(out) :: request
integer, optional, intent(out) :: ierror
call c_mpi_irecv(buf, count, datatype, source, tag, comm, request, ierror)
type(c_ptr) :: c_comm
integer(c_int) :: local_ierr
type(c_ptr) :: c_datatype
type(c_ptr) :: c_request

c_comm = c_mpi_comm_f2c(comm)
c_datatype = get_c_datatype_from_fortran(datatype)
local_ierr = c_mpi_irecv(buf, count, c_datatype, source, tag, c_comm, c_request)
request = c_mpi_request_c2f(c_request)

if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Irecv failed with error code: ", local_ierr
end if
end if
end subroutine

subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror)
Expand Down
32 changes: 22 additions & 10 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@ end function c_mpi_comm_f2c
function c_mpi_comm_c2f(comm_c) bind(C, name="MPI_Comm_c2f")
use iso_c_binding, only: c_int, c_ptr
type(c_ptr), value :: comm_c
integer :: c_mpi_comm_c2f
integer(c_int) :: c_mpi_comm_c2f
end function

function get_c_datatype_from_fortran(datatype) bind(C, name="get_c_datatype_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: datatype
type(c_ptr) :: get_c_datatype_from_fortran
end function get_c_datatype_from_fortran

function c_mpi_request_c2f(request) bind(C, name="MPI_Request_c2f")
use iso_c_binding, only: c_int, c_ptr
type(c_ptr), value :: request
integer(c_int) :: c_mpi_request_c2f
end function

function c_mpi_init(argc, argv) bind(C, name="MPI_Init")
Expand Down Expand Up @@ -94,15 +106,15 @@ subroutine c_mpi_isend(buf, count, datatype, dest, tag, comm, request, ierror) b
integer(c_int), optional, intent(out) :: ierror
end subroutine

subroutine c_mpi_irecv(buf, count, datatype, source, tag, comm, request, ierror) bind(C, name="mpi_irecv_wrapper")
use iso_c_binding, only: c_int, c_double
real(c_double), dimension(*) :: buf
integer(c_int), intent(in) :: count, source, tag
integer(c_int), intent(in) :: datatype
integer(c_int), intent(in) :: comm
integer(c_int), intent(out) :: request
integer(c_int), optional, intent(out) :: ierror
end subroutine
function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, name="MPI_Irecv")
use iso_c_binding, only: c_int, c_double, c_ptr
real(c_double), dimension(*), intent(out) :: buf
integer(c_int), value :: count, source, tag
type(c_ptr), value :: datatype
type(c_ptr), value :: comm
type(c_ptr), intent(out) :: request
integer(c_int) :: c_mpi_irecv
end function

subroutine c_mpi_allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror) &
bind(C, name="mpi_allreduce_wrapper_real")
Expand Down
13 changes: 1 addition & 12 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,6 @@ void mpi_isend_wrapper(const double *buf, int *count, int *datatype_f,
*request_f = MPI_Request_c2f(request);
}

void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f,
int *source, int *tag, int *comm_f, int *request_f,
int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

MPI_Request request;
*ierror = MPI_Irecv(buf, *count, datatype, *source, *tag, comm, &request);
*request_f = MPI_Request_c2f(request);
}

void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count,
int *datatype_f, int *op_f, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
Expand Down Expand Up @@ -200,7 +189,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f,

void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest,
int *tag, int *comm_f, int *ierror) {
MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f);
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
*ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm);
Expand Down