Skip to content

use MPI_* as global variable with bind(C) #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op)
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum
integer, intent(in) :: op_f
if (op_f == MPI_SUM) then
c_op = c_mpi_sum()
c_op = c_mpi_sum
else
c_op = c_mpi_op_f2c(op_f)
end if
Expand All @@ -141,7 +141,7 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm
use mpi_c_bindings, only: c_mpi_comm_size, c_mpi_comm_f2c, c_mpi_comm_world
integer, intent(in) :: comm_f
if (comm_f == MPI_COMM_WORLD) then
c_comm = c_mpi_comm_world()
c_comm = c_mpi_comm_world
else
c_comm = c_mpi_comm_f2c(comm_f)
end if
Expand All @@ -151,7 +151,7 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info
use mpi_c_bindings, only: c_mpi_info_f2c, c_mpi_info_null
integer, intent(in) :: info_f
if (info_f == MPI_INFO_NULL) then
c_info = c_mpi_info_null()
c_info = c_mpi_info_null
else
c_info = c_mpi_info_f2c(info_f)
end if
Expand All @@ -161,11 +161,11 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) resul
use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int
integer, intent(in) :: datatype_f
if (datatype_f == MPI_REAL4) then
c_datatype = c_mpi_float()
c_datatype = c_mpi_float
else if (datatype_f == MPI_REAL8 .OR. datatype_f == MPI_DOUBLE_PRECISION) then
c_datatype = c_mpi_double()
c_datatype = c_mpi_double
else if (datatype_f == MPI_INTEGER) then
c_datatype = c_mpi_int()
c_datatype = c_mpi_int
end if
end function

Expand Down Expand Up @@ -280,7 +280,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror)
end subroutine MPI_Bcast_int_scalar

subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror)
use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_comm_world
use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c
use iso_c_binding, only: c_int, c_ptr, c_loc
real(8), dimension(:, :), target :: buffer
integer, intent(in) :: count, root
Expand Down Expand Up @@ -461,7 +461,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr

subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place
real(8), intent(in), target :: sendbuf
real(8), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -471,7 +471,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier
integer(c_int) :: local_ierr

if (sendbuf == MPI_IN_PLACE) then
sendbuf_ptr = c_mpi_in_place_f2c()
sendbuf_ptr = c_mpi_in_place
else
sendbuf_ptr = c_loc(sendbuf)
end if
Expand All @@ -494,7 +494,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier

subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place
real(8), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -504,7 +504,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com
integer(c_int) :: local_ierr

if (sendbuf == MPI_IN_PLACE) then
sendbuf_ptr = c_mpi_in_place_f2c()
sendbuf_ptr = c_mpi_in_place
else
sendbuf_ptr = c_loc(sendbuf)
end if
Expand All @@ -528,7 +528,7 @@ end subroutine MPI_Allreduce_1D_recv_proc

subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce
real(8), dimension(:), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand Down Expand Up @@ -557,8 +557,7 @@ end subroutine MPI_Allreduce_1D_real_proc

subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, &
c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_comm_f2c
integer, dimension(:), intent(in), target :: sendbuf
integer, dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand Down Expand Up @@ -754,7 +753,7 @@ subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror)
integer(kind=MPI_HANDLE_KIND), dimension(count) :: c_requests
type(c_ptr) :: MPI_STATUSES_IGNORE_from_c

MPI_STATUSES_IGNORE_from_c = c_mpi_statuses_ignore()
MPI_STATUSES_IGNORE_from_c = c_mpi_statuses_ignore

! Convert Fortran requests to C requests.
do i = 1, count
Expand Down Expand Up @@ -859,19 +858,15 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror)

subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ierror)
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_cart_shift, c_mpi_comm_f2c, c_mpi_comm_world
use mpi_c_bindings, only: c_mpi_cart_shift, c_mpi_comm_f2c
integer, intent(in) :: comm
integer, intent(in) :: direction, disp
integer, intent(out) :: rank_source, rank_dest
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_comm
integer(c_int) :: local_ierr

if (comm == MPI_COMM_WORLD) then
c_comm = c_mpi_comm_world()
else
c_comm = c_mpi_comm_f2c(comm)
end if
c_comm = handle_mpi_comm_f2c(comm)

local_ierr = c_mpi_cart_shift(c_comm, direction, disp, rank_source, rank_dest)

Expand Down
47 changes: 11 additions & 36 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module mpi_c_bindings
use iso_c_binding, only: c_ptr
implicit none

#ifdef OPEN_MPI
Expand All @@ -7,8 +8,17 @@ module mpi_c_bindings
#define MPI_HANDLE_KIND 4
#endif

type(c_ptr), bind(C, name="c_MPI_STATUSES_IGNORE") :: c_mpi_statuses_ignore
type(c_ptr), bind(C, name="c_MPI_IN_PLACE") :: c_mpi_in_place
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INFO_NULL") :: c_mpi_info_null
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_DOUBLE") :: c_mpi_double
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_FLOAT") :: c_mpi_float
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum

interface

function c_mpi_comm_f2c(comm_f) bind(C, name="MPI_Comm_f2c")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: comm_f
Expand Down Expand Up @@ -40,31 +50,6 @@ function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f")
integer(c_int) :: c_mpi_status_c2f
end function c_mpi_status_c2f

function c_mpi_comm_world() bind(C, name="get_c_MPI_COMM_WORLD")
use iso_c_binding, only: c_ptr
integer(kind=MPI_HANDLE_KIND) :: c_mpi_comm_world
end function c_mpi_comm_world

function c_mpi_info_null() bind(C, name="get_c_MPI_INFO_NULL")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_null
end function c_mpi_info_null

function c_mpi_sum() bind(C, name="get_c_MPI_SUM")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_sum
end function c_mpi_sum

function c_mpi_float() bind(C, name="get_c_MPI_FLOAT")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_float
end function

function c_mpi_double() bind(C, name="get_c_MPI_DOUBLE")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_double
end function

function c_mpi_int() bind(C, name="get_c_MPI_INT")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_int
end function

function c_mpi_op_f2c(op_f) bind(C, name="MPI_Op_f2c")
use iso_c_binding, only: c_ptr, c_int
integer(c_int), value :: op_f
Expand All @@ -77,16 +62,6 @@ function c_mpi_info_f2c(info_f) bind(C, name="MPI_Info_f2c")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_f2c
end function c_mpi_info_f2c

function c_mpi_statuses_ignore() bind(C, name="get_c_MPI_STATUSES_IGNORE")
use iso_c_binding, only: c_ptr
type(c_ptr) :: c_mpi_statuses_ignore
end function c_mpi_statuses_ignore

function c_mpi_in_place_f2c() bind(C,name="get_c_MPI_IN_PLACE")
use iso_c_binding, only: c_ptr
type(c_ptr) :: c_mpi_in_place_f2c
end function c_mpi_in_place_f2c

function c_mpi_init(argc, argv) bind(C, name="MPI_Init")
use iso_c_binding, only : c_int, c_ptr
!> TODO: is the intent need to be explicitly specified
Expand Down
32 changes: 8 additions & 24 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
@@ -1,33 +1,17 @@
#include <mpi.h>

MPI_Datatype get_c_MPI_DOUBLE() {
return MPI_DOUBLE;
}
MPI_Status* c_MPI_STATUSES_IGNORE = MPI_STATUSES_IGNORE;

MPI_Datatype get_c_MPI_FLOAT() {
return MPI_FLOAT;
}
MPI_Info c_MPI_INFO_NULL = MPI_INFO_NULL;

MPI_Datatype get_c_MPI_INT() {
return MPI_INT;
}
MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD;

MPI_Info get_c_MPI_INFO_NULL() {
return MPI_INFO_NULL;
}
MPI_Datatype c_MPI_DOUBLE = MPI_DOUBLE;

MPI_Op get_c_MPI_SUM() {
return MPI_SUM;
}
MPI_Datatype c_MPI_FLOAT = MPI_FLOAT;

MPI_Comm get_c_MPI_COMM_WORLD() {
return MPI_COMM_WORLD;
}
MPI_Datatype c_MPI_INT = MPI_INT;

void* get_c_MPI_IN_PLACE() {
return MPI_IN_PLACE;
}
void* c_MPI_IN_PLACE = MPI_IN_PLACE;

MPI_Status* get_c_MPI_STATUSES_IGNORE() {
return MPI_STATUSES_IGNORE;
}
MPI_Op c_MPI_SUM = MPI_SUM;