diff --git a/src/mpi.f90 b/src/mpi.f90 index d5896d5..d8e64f2 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -17,6 +17,7 @@ module mpi integer, parameter :: MPI_SUCCESS = 0 integer, parameter :: MPI_COMM_WORLD = -1000 + integer, parameter :: MPI_COMM_NULL = -1001 real(8), parameter :: MPI_IN_PLACE = -1002 integer, parameter :: MPI_SUM = -2300 integer, parameter :: MPI_MAX = -2301 @@ -49,6 +50,10 @@ module mpi module procedure MPI_Comm_Group_proc end interface MPI_Comm_Group + interface MPI_Comm_create + module procedure MPI_Comm_create_proc + end interface MPI_Comm_create + interface MPI_Group_free module procedure MPI_Group_free_proc end interface MPI_Group_free @@ -57,6 +62,11 @@ module mpi module procedure MPI_Group_size_proc end interface MPI_Group_size + interface MPI_Group_range_incl + module procedure MPI_Group_range_incl_proc + end interface MPI_Group_range_incl + + interface MPI_Comm_dup module procedure MPI_Comm_dup_proc end interface MPI_Comm_dup @@ -175,6 +185,16 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_f2c(comm_f) result(c_comm end if end function handle_mpi_comm_f2c + integer(kind=MPI_HANDLE_KIND) function handle_mpi_comm_c2f(comm_c) result(f_comm) + use mpi_c_bindings, only: c_mpi_comm_c2f, c_mpi_comm_null + integer(kind=mpi_handle_kind), intent(in) :: comm_c + if (comm_c == c_mpi_comm_null) then + f_comm = MPI_COMM_NULL + else + f_comm = c_mpi_comm_c2f(comm_c) + end if + end function handle_mpi_comm_c2f + integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info) use mpi_c_bindings, only: c_mpi_info_f2c, c_mpi_info_null integer, intent(in) :: info_f @@ -350,6 +370,51 @@ subroutine MPI_Group_free_proc(group, ierror) end if end subroutine MPI_Group_free_proc + subroutine MPI_Group_range_incl_proc(group, n, ranks, newgroup, ierror) + use mpi_c_bindings, only: c_mpi_group_range_incl, c_mpi_group_f2c, c_mpi_comm_c2f, c_mpi_group_c2f + use iso_c_binding, only: c_int, c_ptr + integer, intent(in) :: group + integer, intent(in) :: n + integer, dimension(:,:), intent(in) :: ranks + integer, intent(out) :: newgroup + integer, optional, intent(out) :: ierror + integer(kind=MPI_HANDLE_KIND) :: c_group, c_newgroup + integer(c_int) :: local_ierr + + c_group = c_mpi_group_f2c(group) + local_ierr = c_mpi_group_range_incl(c_group, n, ranks, c_newgroup) + newgroup = c_mpi_group_c2f(c_newgroup) + + if (present(ierror)) then + ierror = local_ierr + else if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Group_incl failed with error code: ", local_ierr + end if + end subroutine MPI_Group_range_incl_proc + + subroutine MPI_Comm_create_proc(comm, group, newcomm, ierror) + use mpi_c_bindings, only: c_mpi_comm_create, c_mpi_comm_f2c, c_mpi_comm_c2f, c_mpi_group_f2c, c_mpi_comm_null + use iso_c_binding, only: c_int, c_ptr + integer, intent(in) :: comm + integer, intent(in) :: group + integer, intent(out) :: newcomm + integer, optional, intent(out) :: ierror + integer(kind=MPI_HANDLE_KIND) :: c_comm, c_group, c_newcomm + integer(c_int) :: local_ierr + + c_comm = handle_mpi_comm_f2c(comm) + c_group = c_mpi_group_f2c(group) + local_ierr = c_mpi_comm_create(c_comm, c_group, c_newcomm) + + newcomm = handle_mpi_comm_c2f(c_newcomm) + + if (present(ierror)) then + ierror = local_ierr + else if (local_ierr /= MPI_SUCCESS) then + print *, "MPI_Comm_create failed with error code: ", local_ierr + end if + end subroutine MPI_Comm_create_proc + subroutine MPI_Comm_dup_proc(comm, newcomm, ierror) use mpi_c_bindings, only: c_mpi_comm_dup, c_mpi_comm_c2f integer, intent(in) :: comm diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 47ac067..2b98f6a 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -17,6 +17,7 @@ module mpi_c_bindings integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_REAL") :: c_mpi_real integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world + integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical @@ -338,5 +339,22 @@ function c_mpi_group_free(group) bind(C, name="MPI_Group_free") integer(c_int) :: c_mpi_group_free end function c_mpi_group_free + function c_mpi_group_range_incl(group, n, ranges, c_newgroup) bind(C, name="MPI_Group_range_incl") + use iso_c_binding, only: c_ptr, c_int + integer(kind=MPI_HANDLE_KIND), value :: group + integer(c_int), value :: n + integer(c_int), dimension(*) :: ranges + integer(kind=MPI_HANDLE_KIND) :: c_newgroup + integer(c_int) :: c_mpi_group_range_incl + end function c_mpi_group_range_incl + + function c_mpi_comm_create(comm, group, newcomm) bind(C, name="MPI_Comm_create") + use iso_c_binding, only: c_ptr, c_int + integer(kind=MPI_HANDLE_KIND), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: group + integer(kind=MPI_HANDLE_KIND), intent(out) :: newcomm + integer(c_int) :: c_mpi_comm_create + end function c_mpi_comm_create + end interface end module mpi_c_bindings diff --git a/src/mpi_constants.c b/src/mpi_constants.c index 48b8929..fde0534 100644 --- a/src/mpi_constants.c +++ b/src/mpi_constants.c @@ -4,7 +4,9 @@ MPI_Status* c_MPI_STATUSES_IGNORE = MPI_STATUSES_IGNORE; MPI_Info c_MPI_INFO_NULL = MPI_INFO_NULL; -MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD; +void* c_MPI_IN_PLACE = MPI_IN_PLACE; + +// DataType Declarations MPI_Datatype c_MPI_DOUBLE = MPI_DOUBLE; @@ -12,14 +14,20 @@ MPI_Datatype c_MPI_FLOAT = MPI_FLOAT; MPI_Datatype c_MPI_INT = MPI_INT; -void* c_MPI_IN_PLACE = MPI_IN_PLACE; +MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL; + +MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER; + +MPI_Datatype c_MPI_REAL = MPI_REAL; + +// Operation Declarations MPI_Op c_MPI_SUM = MPI_SUM; MPI_Op c_MPI_MAX = MPI_MAX; -MPI_Datatype c_MPI_LOGICAL = MPI_LOGICAL; +// Communicators Declarations -MPI_Datatype c_MPI_CHARACTER = MPI_CHARACTER; +MPI_Comm c_MPI_COMM_NULL = MPI_COMM_NULL; -MPI_Datatype c_MPI_REAL = MPI_REAL; +MPI_Comm c_MPI_COMM_WORLD = MPI_COMM_WORLD; \ No newline at end of file diff --git a/tests/comm_create_1.f90 b/tests/comm_create_1.f90 new file mode 100644 index 0000000..f08e781 --- /dev/null +++ b/tests/comm_create_1.f90 @@ -0,0 +1,43 @@ +program minimal_mre_range + use mpi + implicit none + + integer :: ierr, rank, new_rank, size + integer :: group_world, group_range, new_comm + integer, dimension(1,3) :: range ! 1D array to define a single range + integer :: i + + call MPI_INIT(ierr) + call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr) + call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr) + + ! Get the group of MPI_COMM_WORLD + call MPI_COMM_GROUP(MPI_COMM_WORLD, group_world, ierr) + + ! Define 1D range: start, end, stride + range(1,1) = 0 ! start + range(1,2) = size - 1 ! end + range(1,3) = 1 ! stride + + + ! Create a new group that includes all ranks + call MPI_GROUP_RANGE_INCL(group_world, 1, range, group_range, ierr) + + ! Create new communicator + call MPI_COMM_CREATE(MPI_COMM_WORLD, group_range, new_comm, ierr) + + ! Print participation + if (new_comm /= MPI_COMM_NULL) then + call MPI_COMM_RANK(new_comm, new_rank, ierr) + if (ierr /= MPI_SUCCESS) error stop "MPI_COMM_RANK on new_comm failed" + print *, 'Global rank', rank, 'is in new_comm with local rank', new_rank + else + print *, 'Rank', rank, 'is NOT in the new communicator.' + end if + + ! Free groups (no comm_free) + call MPI_GROUP_FREE(group_range, ierr) + call MPI_GROUP_FREE(group_world, ierr) + + call MPI_FINALIZE(ierr) +end program minimal_mre_range \ No newline at end of file