Skip to content

Use MPI_Info_f2c from fortran Bind(C) #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm
end if
end function handle_mpi_comm_f2c

integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info)
use mpi_c_bindings, only: c_mpi_info_f2c, c_mpi_info_null
integer, intent(in) :: info_f
if (info_f == MPI_INFO_NULL) then
c_info = c_mpi_info_null()
else
c_info = c_mpi_info_f2c(info_f)
end if
end function handle_mpi_info_f2c

subroutine MPI_Init_proc(ierr)
use mpi_c_bindings, only: c_mpi_init
use iso_c_binding, only : c_int, c_ptr, c_null_ptr
Expand Down Expand Up @@ -439,7 +449,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier
integer(c_int) :: local_ierr

if (sendbuf == MPI_IN_PLACE) then
sendbuf_ptr = c_mpi_in_place_f2c(sendbuf)
sendbuf_ptr = c_mpi_in_place_f2c()
else
sendbuf_ptr = c_loc(sendbuf)
end if
Expand Down Expand Up @@ -472,7 +482,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com
integer(c_int) :: local_ierr

if (sendbuf == MPI_IN_PLACE) then
sendbuf_ptr = c_mpi_in_place_f2c(sendbuf)
sendbuf_ptr = c_mpi_in_place_f2c()
else
sendbuf_ptr = c_loc(sendbuf)
end if
Expand Down Expand Up @@ -603,7 +613,7 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror)

subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror)
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_c2f, c_mpi_info_f2c
use mpi_c_bindings, only: c_mpi_comm_split_type, c_mpi_comm_c2f
integer, intent(in) :: comm
integer, intent(in) :: split_type, key
integer, intent(in) :: info
Expand All @@ -614,21 +624,21 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror
integer(kind=MPI_HANDLE_KIND) :: c_comm, c_info, c_new_comm

c_comm = handle_mpi_comm_f2c(comm)
c_info = c_mpi_info_f2c(info)
c_info = handle_mpi_info_f2c(info)

! Call the native MPI_Comm_split_type.
local_ierr = c_mpi_comm_split_type(c_comm, split_type, key, c_info, c_new_comm)

! Convert the new communicator C handle back to a Fortran integer handle.
newcomm = c_mpi_comm_c2f(c_new_comm)

if (present(ierror)) then
ierror = local_ierr
ierror = local_ierr
else
if (local_ierr /= 0) then
print *, "MPI_Comm_split_type failed with error code: ", local_ierr
end if
if (local_ierr /= 0) then
print *, "MPI_Comm_split_type failed with error code: ", local_ierr
end if
end if

end subroutine MPI_Comm_split_type_proc

subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror)
Expand Down
17 changes: 10 additions & 7 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@ function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f")
integer(c_int) :: f_status(*) ! assumed-size array
integer(c_int) :: c_mpi_status_c2f
end function c_mpi_status_c2f
function c_mpi_comm_world() bind(C, name="get_c_mpi_comm_world")

function c_mpi_comm_world() bind(C, name="get_c_MPI_COMM_WORLD")
use iso_c_binding, only: c_ptr
integer(kind=MPI_HANDLE_KIND) :: c_mpi_comm_world
end function c_mpi_comm_world

function c_mpi_info_null() bind(C, name="get_c_MPI_INFO_NULL")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_null
end function c_mpi_info_null

function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: datatype
Expand All @@ -56,8 +60,8 @@ function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran")
integer(c_int), value :: op_f
integer(kind=MPI_HANDLE_KIND) :: c_mpi_op_f2c
end function c_mpi_op_f2c
function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran")

function c_mpi_info_f2c(info_f) bind(C, name="MPI_Info_f2c")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: info_f
integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_f2c
Expand All @@ -68,9 +72,8 @@ function c_mpi_statuses_ignore() bind(C, name="get_c_MPI_STATUSES_IGNORE")
type(c_ptr) :: c_mpi_statuses_ignore
end function c_mpi_statuses_ignore

function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran")
use iso_c_binding, only: c_double, c_ptr
real(c_double), value :: in_place_f
function c_mpi_in_place_f2c() bind(C,name="get_c_MPI_IN_PLACE")
use iso_c_binding, only: c_ptr
type(c_ptr) :: c_mpi_in_place_f2c
end function c_mpi_in_place_f2c

Expand Down
16 changes: 6 additions & 10 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ MPI_Datatype get_c_datatype_from_fortran(int datatype) {
return c_datatype;
}

MPI_Info get_c_info_from_fortran(int info) {
if (info == FORTRAN_MPI_INFO_NULL) {
return MPI_INFO_NULL;
} else {
return MPI_Info_f2c(info);
}
MPI_Info get_c_MPI_INFO_NULL() {
return MPI_INFO_NULL;
}

MPI_Op get_c_op_from_fortran(int op) {
Expand All @@ -49,14 +45,14 @@ MPI_Op get_c_op_from_fortran(int op) {
}
}

MPI_Comm get_c_mpi_comm_world() {
MPI_Comm get_c_MPI_COMM_WORLD() {
return MPI_COMM_WORLD;
}

void* get_c_mpi_inplace_from_fortran(double sendbuf) {
void* get_c_MPI_IN_PLACE() {
return MPI_IN_PLACE;
}

MPI_Status* get_c_MPI_STATUSES_IGNORE(){
MPI_Status* get_c_MPI_STATUSES_IGNORE() {
return MPI_STATUSES_IGNORE;
}
}