Skip to content

Add new MPI_Op -> MPI_LOR and Add wrappers for Logical datatype for MPI_Allreduce #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ module mpi
real(8), parameter :: MPI_IN_PLACE = -1002
integer, parameter :: MPI_SUM = -2300
integer, parameter :: MPI_MAX = -2301
integer, parameter :: MPI_LOR = -2302
integer, parameter :: MPI_INFO_NULL = -2000
integer, parameter :: MPI_STATUS_SIZE = 5
integer :: MPI_STATUS_IGNORE = 0
Expand Down Expand Up @@ -99,6 +100,7 @@ module mpi
module procedure MPI_Allreduce_1D_recv_proc
module procedure MPI_Allreduce_1D_real_proc
module procedure MPI_Allreduce_1D_int_proc
module procedure MPI_Allreduce_scalar_logical_proc
end interface

interface MPI_Gatherv
Expand Down Expand Up @@ -168,14 +170,16 @@ module mpi
contains

integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op)
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max, c_mpi_lor
integer, intent(in) :: op_f
if (op_f == MPI_SUM) then
c_op = c_mpi_sum
else if (op_f == MPI_MAX) then
c_op = c_MPI_MAX
else if (op_f == MPI_LOR) then
c_op = c_mpi_lor
else
c_op = c_mpi_op_f2c(op_f)
c_op = c_mpi_op_f2c(op_f) ! For other operations, use the C binding
end if
end function

Expand Down Expand Up @@ -795,6 +799,35 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm
end if
end subroutine MPI_Allreduce_1D_int_proc

subroutine MPI_Allreduce_scalar_logical_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_comm_f2c
logical, intent(in), target :: sendbuf
logical, intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
integer, intent(out), optional :: ierror
type(c_ptr) :: sendbuf_ptr, recvbuf_ptr
integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm
integer(c_int) :: local_ierr

sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
c_datatype = handle_mpi_datatype_f2c(datatype)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)

local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm)

if (present(ierror)) then
ierror = local_ierr
else
if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr
end if
end if
end subroutine MPI_Allreduce_scalar_logical_proc

function MPI_Wtime_proc() result(time)
use mpi_c_bindings, only: c_mpi_wtime
real(8) :: time
Expand Down
12 changes: 8 additions & 4 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,21 @@ module mpi_c_bindings
type(c_ptr), bind(C, name="c_MPI_STATUSES_IGNORE") :: c_mpi_statuses_ignore
type(c_ptr), bind(C, name="c_MPI_IN_PLACE") :: c_mpi_in_place
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INFO_NULL") :: c_mpi_info_null

integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_DOUBLE") :: c_mpi_double
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_FLOAT") :: c_mpi_float
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_REAL") :: c_mpi_real
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_CHARACTER") :: c_mpi_character

integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOR") :: c_mpi_lor

integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null

interface

function c_mpi_comm_f2c(comm_f) bind(C, name="MPI_Comm_f2c")
Expand Down
2 changes: 2 additions & 0 deletions src/mpi_constants.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ MPI_Op c_MPI_SUM = MPI_SUM;

MPI_Op c_MPI_MAX = MPI_MAX;

MPI_Op c_MPI_LOR = MPI_LOR;

// Communicators Declarations

MPI_Comm c_MPI_COMM_NULL = MPI_COMM_NULL;
Expand Down
26 changes: 26 additions & 0 deletions tests/allreduce_lor.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
program mre_mpi_lor_allreduce
use mpi
implicit none

integer :: ierr, rank, size
logical :: local_flag, global_flag

call MPI_INIT(ierr)
if (ierr /= MPI_SUCCESS) error stop "MPI_INIT failed"

call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr)
call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr)

! Initialize the local flag: True if this is the 0th rank, False otherwise
local_flag = (rank == 0)

! Perform logical OR reduction across all processes
call MPI_ALLREDUCE(local_flag, global_flag, 1, MPI_LOGICAL, MPI_LOR, MPI_COMM_WORLD, ierr)
if (global_flag .neqv. .true.) error stop "MPI_ALLREDUCE failed"

print *, 'Rank', rank, ': global_flag =', global_flag

call MPI_FINALIZE(ierr)
if (ierr /= MPI_SUCCESS) error stop "MPI_FINALIZE failed"

end program mre_mpi_lor_allreduce