From be59f87b37dbf1d338ffea7b6241ed667883d071 Mon Sep 17 00:00:00 2001 From: Aditya Trivedi Date: Tue, 27 May 2025 11:13:27 +0530 Subject: [PATCH] FEAT: Implement Wrappers of MPI_COMM_GROUP, MPI_GROUP_SIZE and MPI_GROUP_FREE --- src/mpi.f90 | 76 ++++++++++++++++++++++++++++++++++++++++++ src/mpi_c_bindings.f90 | 32 ++++++++++++++++++ tests/comm_group_1.f90 | 34 +++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 tests/comm_group_1.f90 diff --git a/src/mpi.f90 b/src/mpi.f90 index a2e37ef..d5896d5 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -45,6 +45,18 @@ module mpi module procedure MPI_Comm_size_proc end interface MPI_Comm_size + interface MPI_Comm_Group + module procedure MPI_Comm_Group_proc + end interface MPI_Comm_Group + + interface MPI_Group_free + module procedure MPI_Group_free_proc + end interface MPI_Group_free + + interface MPI_Group_size + module procedure MPI_Group_size_proc + end interface MPI_Group_size + interface MPI_Comm_dup module procedure MPI_Comm_dup_proc end interface MPI_Comm_dup @@ -274,6 +286,70 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) end if end subroutine + subroutine MPI_Comm_Group_proc(comm, group, ierror) + use mpi_c_bindings, only: c_mpi_comm_group, c_mpi_group_f2c, c_mpi_group_c2f + use iso_c_binding, only: c_int, c_ptr + integer, intent(in) :: comm + integer, intent(out) :: group + integer, optional, intent(out) :: ierror + integer(kind=MPI_HANDLE_KIND) :: c_comm, c_group + integer :: local_ierr + + c_comm = handle_mpi_comm_f2c(comm) + c_group = c_mpi_group_f2c(group) + local_ierr = c_mpi_comm_group(c_comm, c_group) + group = c_mpi_group_c2f(c_group) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= 0) then + print *, "MPI_Comm_Group failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Comm_Group_proc + + subroutine MPI_Group_size_proc(group, size, ierror) + use mpi_c_bindings, only: c_mpi_group_size, c_mpi_group_f2c + use iso_c_binding, only: c_int, c_ptr + integer, intent(in) :: group + integer, intent(out) :: size + integer, optional, intent(out) :: ierror + integer(kind=MPI_HANDLE_KIND) :: c_group + integer :: local_ierr + + c_group = c_mpi_group_f2c(group) + local_ierr = c_mpi_group_size(c_group, size) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= 0) then + print *, "MPI_Group_size failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Group_size_proc + + subroutine MPI_Group_free_proc(group, ierror) + use mpi_c_bindings, only: c_mpi_group_free, c_mpi_group_f2c + use iso_c_binding, only: c_int, c_ptr + integer, intent(in) :: group + integer, optional, intent(out) :: ierror + integer(kind=MPI_HANDLE_KIND) :: c_group + integer :: local_ierr + + c_group = c_mpi_group_f2c(group) + local_ierr = c_mpi_group_free(c_group) + + if (present(ierror)) then + ierror = local_ierr + else + if (local_ierr /= 0) then + print *, "MPI_Group_free failed with error code: ", local_ierr + end if + end if + end subroutine MPI_Group_free_proc + subroutine MPI_Comm_dup_proc(comm, newcomm, ierror) use mpi_c_bindings, only: c_mpi_comm_dup, c_mpi_comm_c2f integer, intent(in) :: comm diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index ac350c1..47ac067 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -67,6 +67,18 @@ function c_mpi_info_f2c(info_f) bind(C, name="MPI_Info_f2c") integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_f2c end function c_mpi_info_f2c + function c_mpi_group_f2c(group_f) bind(C, name="MPI_Group_f2c") + use iso_c_binding, only: c_int, c_ptr + integer(c_int), value :: group_f + integer(kind=MPI_HANDLE_KIND) :: c_mpi_group_f2c + end function c_mpi_group_f2c + + function c_mpi_group_c2f(group_c) bind(C, name="MPI_Group_c2f") + use iso_c_binding, only: c_int, c_ptr + integer(kind=MPI_HANDLE_KIND), value :: group_c + integer(c_int) :: c_mpi_group_c2f + end function c_mpi_group_c2f + function c_mpi_init(argc, argv) bind(C, name="MPI_Init") use iso_c_binding, only : c_int, c_ptr !> TODO: is the intent need to be explicitly specified @@ -306,5 +318,25 @@ function c_mpi_reduce(sendbuf, recvbuf, count, c_dtype, c_op, root, c_comm) & integer(c_int) :: c_mpi_reduce end function c_mpi_reduce + function c_mpi_comm_group(comm, group) bind(C, name="MPI_Comm_group") + use iso_c_binding, only: c_ptr, c_int + integer(kind=MPI_HANDLE_KIND), value :: comm + integer(kind=MPI_HANDLE_KIND), intent(out) :: group + integer(c_int) :: c_mpi_comm_group + end function c_mpi_comm_group + + function c_mpi_group_size(group, size) bind(C, name="MPI_Group_size") + use iso_c_binding, only: c_ptr, c_int + integer(kind=MPI_HANDLE_KIND), value :: group + integer(c_int), intent(out) :: size + integer(c_int) :: c_mpi_group_size + end function c_mpi_group_size + + function c_mpi_group_free(group) bind(C, name="MPI_Group_free") + use iso_c_binding, only: c_ptr, c_int + integer(kind=MPI_HANDLE_KIND), intent(in) :: group + integer(c_int) :: c_mpi_group_free + end function c_mpi_group_free + end interface end module mpi_c_bindings diff --git a/tests/comm_group_1.f90 b/tests/comm_group_1.f90 new file mode 100644 index 0000000..0306c87 --- /dev/null +++ b/tests/comm_group_1.f90 @@ -0,0 +1,34 @@ +program comm_group_1 + use mpi + implicit none + integer :: ierr, rank, size, group, group_size + logical :: error + + ! Initialize MPI + call MPI_Init(ierr) + call MPI_Comm_rank(MPI_COMM_WORLD, rank, ierr) + call MPI_Comm_size(MPI_COMM_WORLD, size, ierr) + + ! Get the group of MPI_COMM_WORLD + call MPI_Comm_group(MPI_COMM_WORLD, group, ierr) + + ! Check group size + call MPI_Group_size(group, group_size, ierr) + + ! Verify result + error = .false. + if (group_size /= size) then + print *, "Rank ", rank, ": Error: Expected group size ", size, ", got ", group_size + error = .true. + else if (rank == 0) then + print *, "MPI_Comm_group test passed: group size = ", group_size + end if + + ! Free the group + call MPI_Group_free(group, ierr) + + ! Clean up + call MPI_Finalize(ierr) + + if (error) stop 1 +end program comm_group_1 \ No newline at end of file