Skip to content

Feat: Implement Wrappers of MPI_COMM_CREATE and MPI_GROUP_RANGE_INCL #132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module mpi
integer, parameter :: MPI_SUCCESS = 0

integer, parameter :: MPI_COMM_WORLD = -1000
integer, parameter :: MPI_COMM_NULL = -1001
real(8), parameter :: MPI_IN_PLACE = -1002
integer, parameter :: MPI_SUM = -2300
integer, parameter :: MPI_MAX = -2301
Expand Down Expand Up @@ -49,6 +50,10 @@ module mpi
module procedure MPI_Comm_Group_proc
end interface MPI_Comm_Group

interface MPI_Comm_create
module procedure MPI_Comm_create_proc
end interface MPI_Comm_create

interface MPI_Group_free
module procedure MPI_Group_free_proc
end interface MPI_Group_free
Expand All @@ -57,6 +62,11 @@ module mpi
module procedure MPI_Group_size_proc
end interface MPI_Group_size

interface MPI_Group_range_incl
module procedure MPI_Group_range_incl_proc
end interface MPI_Group_range_incl


interface MPI_Comm_dup
module procedure MPI_Comm_dup_proc
end interface MPI_Comm_dup
Expand Down Expand Up @@ -175,6 +185,16 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm
end if
end function handle_mpi_comm_f2c

integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_c2f(comm_c) result(f_comm)
use mpi_c_bindings, only: c_mpi_comm_c2f, c_mpi_comm_null
integer(kind=mpi_handle_kind), intent(in) :: comm_c
if (comm_c == c_mpi_comm_null) then
f_comm = MPI_COMM_NULL
else
f_comm = c_mpi_comm_c2f(comm_c)
end if
end function handle_mpi_comm_c2f

integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info)
use mpi_c_bindings, only: c_mpi_info_f2c, c_mpi_info_null
integer, intent(in) :: info_f
Expand Down Expand Up @@ -350,6 +370,51 @@ subroutine MPI_Group_free_proc(group, ierror)
end if
end subroutine MPI_Group_free_proc

subroutine MPI_Group_range_incl_proc(group, n, ranks, newgroup, ierror)
use mpi_c_bindings, only: c_mpi_group_range_incl, c_mpi_group_f2c, c_mpi_comm_c2f, c_mpi_group_c2f
use iso_c_binding, only: c_int, c_ptr
integer, intent(in) :: group
integer, intent(in) :: n
integer, dimension(:,:), intent(in) :: ranks
integer, intent(out) :: newgroup
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_group, c_newgroup
integer(c_int) :: local_ierr

c_group = c_mpi_group_f2c(group)
local_ierr = c_mpi_group_range_incl(c_group, n, ranks, c_newgroup)
newgroup = c_mpi_group_c2f(c_newgroup)

if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Group_incl failed with error code: ", local_ierr
end if
end subroutine MPI_Group_range_incl_proc

subroutine MPI_Comm_create_proc(comm, group, newcomm, ierror)
use mpi_c_bindings, only: c_mpi_comm_create, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_group_f2c, c_mpi_comm_null
use iso_c_binding, only: c_int, c_ptr
integer, intent(in) :: comm
integer, intent(in) :: group
integer, intent(out) :: newcomm
integer, optional, intent(out) :: ierror
integer(kind=MPI_HANDLE_KIND) :: c_comm, c_group, c_newcomm
integer(c_int) :: local_ierr

c_comm = handle_mpi_comm_f2c(comm)
c_group = c_mpi_group_f2c(group)
local_ierr = c_mpi_comm_create(c_comm, c_group, c_newcomm)

newcomm = handle_mpi_comm_c2f(c_newcomm)

if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Comm_create failed with error code: ", local_ierr
end if
end subroutine MPI_Comm_create_proc

subroutine MPI_Comm_dup_proc(comm, newcomm, ierror)
use mpi_c_bindings, only: c_mpi_comm_dup, c_mpi_comm_c2f
integer, intent(in) :: comm
Expand Down
18 changes: 18 additions & 0 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module mpi_c_bindings
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_REAL") :: c_mpi_real
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical
Expand Down Expand Up @@ -338,5 +339,22 @@ function c_mpi_group_free(group) bind(C, name="MPI_Group_free")
integer(c_int) :: c_mpi_group_free
end function c_mpi_group_free

function c_mpi_group_range_incl(group, n, ranges, c_newgroup) bind(C, name="MPI_Group_range_incl")
use iso_c_binding, only: c_ptr, c_int
integer(kind=MPI_HANDLE_KIND), value :: group
integer(c_int), value :: n
integer(c_int), dimension(*) :: ranges
integer(kind=MPI_HANDLE_KIND) :: c_newgroup
integer(c_int) :: c_mpi_group_range_incl
end function c_mpi_group_range_incl

function c_mpi_comm_create(comm, group, newcomm) bind(C, name="MPI_Comm_create")
use iso_c_binding, only: c_ptr, c_int
integer(kind=MPI_HANDLE_KIND), value :: comm
integer(kind=MPI_HANDLE_KIND), value :: group
integer(kind=MPI_HANDLE_KIND), intent(out) :: newcomm
integer(c_int) :: c_mpi_comm_create
end function c_mpi_comm_create

end interface
end module mpi_c_bindings
18 changes: 13 additions & 5 deletions src/mpi_constants.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,30 @@ MPI_Status* c_MPI_STATUSES_IGNORE = MPI_STATUSES_IGNORE;

MPI_Info c_MPI_INFO_NULL = MPI_INFO_NULL;

MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD;
void* c_MPI_IN_PLACE = MPI_IN_PLACE;

// DataType Declarations

MPI_Datatype c_MPI_DOUBLE = MPI_DOUBLE;

MPI_Datatype c_MPI_FLOAT = MPI_FLOAT;

MPI_Datatype c_MPI_INT = MPI_INT;

void* c_MPI_IN_PLACE = MPI_IN_PLACE;
MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL;

MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER;

MPI_Datatype c_MPI_REAL = MPI_REAL;

// Operation Declarations

MPI_Op c_MPI_SUM = MPI_SUM;

MPI_Op c_MPI_MAX = MPI_MAX;

MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL;
// Communicators Declarations

MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER;
MPI_Comm c_MPI_COMM_NULL = MPI_COMM_NULL;

MPI_Datatype c_MPI_REAL = MPI_REAL;
MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD;
43 changes: 43 additions & 0 deletions tests/comm_create_1.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
program minimal_mre_range
use mpi
implicit none

integer :: ierr, rank, new_rank, size
integer :: group_world, group_range, new_comm
integer, dimension(1,3) :: range ! 1D array to define a single range
integer :: i

call MPI_INIT(ierr)
call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr)
call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr)

! Get the group of MPI_COMM_WORLD
call MPI_COMM_GROUP(MPI_COMM_WORLD, group_world, ierr)

! Define 1D range: start, end, stride
range(1,1) = 0 ! start
range(1,2) = size - 1 ! end
range(1,3) = 1 ! stride


! Create a new group that includes all ranks
call MPI_GROUP_RANGE_INCL(group_world, 1, range, group_range, ierr)

! Create new communicator
call MPI_COMM_CREATE(MPI_COMM_WORLD, group_range, new_comm, ierr)

! Print participation
if (new_comm /= MPI_COMM_NULL) then
call MPI_COMM_RANK(new_comm, new_rank, ierr)
if (ierr /= MPI_SUCCESS) error stop "MPI_COMM_RANK on new_comm failed"
print *, 'Global rank', rank, 'is in new_comm with local rank', new_rank
else
print *, 'Rank', rank, 'is NOT in the new communicator.'
end if

! Free groups (no comm_free)
call MPI_GROUP_FREE(group_range, ierr)
call MPI_GROUP_FREE(group_world, ierr)

call MPI_FINALIZE(ierr)
end program minimal_mre_range