From 9529857ca9a10045c588728ef1c4822c90e88119 Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Thu, 10 Apr 2025 14:00:30 +0530 Subject: [PATCH 1/3] use integer handles --- src/mpi.f90 | 82 ++++++++++++++++++++++-------------------- src/mpi_c_bindings.f90 | 81 ++++++++++++++++++++--------------------- 2 files changed, 85 insertions(+), 78 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 4532477..6d27b27 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -189,7 +189,7 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) integer, intent(out) :: size integer, optional, intent(out) :: ierror integer :: local_ierr - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm c_comm = c_mpi_comm_f2c(comm) local_ierr = c_mpi_comm_size(c_comm, size) @@ -210,7 +210,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm, c_datatype + integer(kind=8) :: c_comm, c_datatype integer :: local_ierr type(c_ptr) :: buffer_ptr @@ -236,7 +236,7 @@ subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm, c_datatype + integer(kind=8) :: c_comm, c_datatype integer :: local_ierr type(c_ptr) :: buffer_ptr @@ -263,9 +263,9 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r integer, intent(in) :: sendtype, recvtype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm integer(c_int) :: local_ierr - type(c_ptr) :: c_sendtype, c_recvtype + integer(kind=8) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr c_comm = c_mpi_comm_f2c(comm) @@ -293,9 +293,9 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, integer, intent(in) :: sendtype, recvtype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm integer(c_int) :: local_ierr - type(c_ptr) :: c_sendtype, c_recvtype + integer(kind=8) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr c_comm = c_mpi_comm_f2c(comm) @@ -324,7 +324,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) integer, intent(out) :: request integer, optional, intent(out) :: ierror type(c_ptr) :: buf_ptr - type(c_ptr) :: c_datatype, c_comm, c_request + integer(kind=8) :: c_datatype, c_comm, c_request integer(c_int) :: local_ierr buf_ptr = c_loc(buf) @@ -353,7 +353,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) integer, intent(out) :: request integer, optional, intent(out) :: ierror type(c_ptr) :: buf_ptr - type(c_ptr) :: c_datatype, c_comm, c_request + integer(kind=8) :: c_datatype, c_comm, c_request integer(c_int) :: local_ierr buf_ptr = c_loc(buf) @@ -381,10 +381,10 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr integer, intent(in) :: comm integer, intent(out) :: request integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm integer(c_int) :: local_ierr - type(c_ptr) :: c_datatype - type(c_ptr) :: c_request + integer(kind=8) :: c_datatype + integer(kind=8) :: c_request c_comm = c_mpi_comm_f2c(comm) c_datatype = c_mpi_datatype_f2c(datatype) @@ -407,7 +407,8 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier real(8), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr + integer(kind=8) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then @@ -438,7 +439,8 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr + integer(kind=8) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then @@ -469,7 +471,8 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com real(8), dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr + integer(kind=8) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr sendbuf_ptr = c_loc(sendbuf) @@ -496,7 +499,8 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm integer, dimension(:), intent(out), target :: recvbuf integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror - type(c_ptr) :: sendbuf_ptr, recvbuf_ptr, c_datatype, c_op, c_comm + type(c_ptr) :: sendbuf_ptr, recvbuf_ptr + integer(kind=8) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr sendbuf_ptr = c_loc(sendbuf) @@ -527,7 +531,7 @@ subroutine MPI_Barrier_proc(comm, ierror) use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out), optional :: ierror - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm integer(c_int) :: local_ierr ! Convert Fortran handle to C handle @@ -549,7 +553,7 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror) integer, intent(in) :: comm integer, intent(out) :: rank integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm integer(c_int) :: local_ierr c_comm = c_mpi_comm_f2c(comm) @@ -574,7 +578,7 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror integer, optional, intent(out) :: ierror integer(c_int) :: local_ierr - type(c_ptr) :: c_comm, c_info, c_new_comm + integer(kind=8) :: c_comm, c_info, c_new_comm ! Convert Fortran communicator and info handles to C pointers. c_comm = c_mpi_comm_f2c(comm) @@ -599,13 +603,14 @@ end subroutine MPI_Comm_split_type_proc subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f - real(8), dimension(*), intent(inout) :: buf + real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status(MPI_STATUS_SIZE) integer, optional, intent(out) :: ierror integer(c_int) :: local_ierr, status_ierr - type(c_ptr) :: c_dtype, c_comm, c_status + integer(kind=8) :: c_dtype, c_comm + type(c_ptr) :: c_status integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status ! Convert Fortran handles to C handles. @@ -616,7 +621,7 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st c_status = c_loc(tmp_status) ! Call the native MPI_Recv. - local_ierr = c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, c_status) + local_ierr = c_mpi_recv(c_loc(buf), count, c_dtype, source, tag, c_comm, c_status) ! Convert the C MPI_Status to Fortran status. if (local_ierr == MPI_SUCCESS) then @@ -634,36 +639,37 @@ end subroutine MPI_Recv_StatusArray_proc subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror) use iso_c_binding, only: c_int, c_ptr, c_loc use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f - real(8), dimension(*), intent(inout) :: buf + real(8), dimension(*), intent(inout), target :: buf integer, intent(in) :: count, source, tag, datatype, comm integer, intent(out) :: status integer, optional, intent(out) :: ierror integer(c_int) :: local_ierr, status_ierr - type(c_ptr) :: c_dtype, c_comm, c_status + integer(kind=8) :: c_dtype, c_comm + type(c_ptr) :: c_status integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status - + ! Convert Fortran handles to C handles. c_dtype = c_mpi_datatype_f2c(datatype) c_comm = c_mpi_comm_f2c(comm) - + ! Use a local temporary MPI_Status (as an array of c_int) c_status = c_loc(tmp_status) - + ! Call the native MPI_Recv. - local_ierr = c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, c_status) - + local_ierr = c_mpi_recv(c_loc(buf), count, c_dtype, source, tag, c_comm, c_status) + ! Convert the C MPI_Status to Fortran status. if (local_ierr == MPI_SUCCESS) then ! status_ierr = c_mpi_status_c2f(c_status, status) end if - + if (present(ierror)) then ierror = local_ierr else if (local_ierr /= MPI_SUCCESS) then print *, "MPI_Recv failed with error code: ", local_ierr end if - + end subroutine MPI_Recv_StatusIgnore_proc subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror) @@ -683,7 +689,7 @@ subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - type(c_ptr) :: c_datatype, c_comm + integer(kind=8) :: c_datatype, c_comm integer :: local_ierr c_datatype = c_mpi_datatype_f2c(datatype) @@ -700,8 +706,8 @@ subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_ca integer, intent(out) :: comm_cart integer, optional, intent(out) :: ierror integer(c_int) :: ndims_c, reorder_c, dims_c(ndims), periods_c(ndims) - type(c_ptr) :: c_comm_old - type(c_ptr) :: c_comm_cart + integer(kind=8) :: c_comm_old + integer(kind=8) :: c_comm_cart integer(c_int) :: local_ierr c_comm_old = c_mpi_comm_f2c(comm_old) @@ -736,7 +742,7 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) integer, intent(in) :: rank, maxdims integer, intent(out) :: coords(maxdims) integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm integer(c_int) :: local_ierr c_comm = c_mpi_comm_f2c(comm) @@ -758,7 +764,7 @@ subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ie integer, intent(in) :: direction, disp integer, intent(out) :: rank_source, rank_dest integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm + integer(kind=8) :: c_comm integer(c_int) :: local_ierr c_comm = c_mpi_comm_f2c(comm) @@ -799,7 +805,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) integer, intent(out) :: newcomm integer, optional, intent(out) :: ierror integer, target :: remain_dims_i(size(remain_dims)) - type(c_ptr) :: c_comm, c_newcomm + integer(kind=8) :: c_comm, c_newcomm integer :: local_ierr type(c_ptr) :: remain_dims_i_ptr @@ -832,7 +838,7 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co integer, intent(in) :: count, datatype, op, root, comm integer, optional, intent(out) :: ierror - type(c_ptr) :: c_comm, c_dtype, c_op + integer(kind=8) :: c_comm, c_dtype, c_op type(c_ptr) :: c_sendbuf, c_recvbuf integer(c_int) :: local_ierr diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 1fc55c2..1f46383 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -5,31 +5,31 @@ module mpi_c_bindings function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: comm_f - type(c_ptr) :: c_mpi_comm_f2c ! MPI_Comm as pointer + integer(kind=8) :: c_mpi_comm_f2c end function c_mpi_comm_f2c function c_mpi_comm_c2f(comm_c) bind(C, name="MPI_Comm_c2f") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: comm_c + integer(kind=8), value :: comm_c integer(c_int) :: c_mpi_comm_c2f end function function c_mpi_request_c2f(request) bind(C, name="MPI_Request_c2f") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: request + integer(kind=8), value :: request integer(c_int) :: c_mpi_request_c2f end function function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: datatype - type(c_ptr) :: c_mpi_datatype_f2c + integer(kind=8) :: c_mpi_datatype_f2c end function c_mpi_datatype_f2c function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran") use iso_c_binding, only: c_ptr, c_int integer(c_int), value :: op_f - type(c_ptr) :: c_mpi_op_f2c + integer(kind=8) :: c_mpi_op_f2c end function c_mpi_op_f2c function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") @@ -38,10 +38,11 @@ function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") integer(c_int) :: f_status(*) ! assumed-size array integer(c_int) :: c_mpi_status_c2f end function c_mpi_status_c2f + function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: info_f - type(c_ptr) :: c_mpi_info_f2c + integer(kind=8) :: c_mpi_info_f2c end function c_mpi_info_f2c function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran") @@ -75,7 +76,7 @@ end function c_mpi_finalize function c_mpi_comm_size(comm, size) bind(C, name="MPI_Comm_size") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: comm + integer(kind=8), value :: comm integer(c_int), intent(out) :: size integer(c_int) :: c_mpi_comm_size end function c_mpi_comm_size @@ -84,9 +85,9 @@ function c_mpi_bcast(buffer, count, datatype, root, comm) bind(C, name="MPI_Bcas use iso_c_binding, only : c_ptr, c_int type(c_ptr), value :: buffer integer(c_int), value :: count - type(c_ptr), value :: datatype + integer(kind=8), value :: datatype integer(c_int), value :: root - type(c_ptr), value :: comm + integer(kind=8), value :: comm integer(c_int) :: c_mpi_bcast end function c_mpi_bcast @@ -96,8 +97,8 @@ function c_mpi_allgather_int(sendbuf, sendcount, sendtype, recvbuf, & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: sendcount, recvcount - type(c_ptr), value :: sendtype, recvtype - type(c_ptr), value :: comm + integer(kind=8), value :: sendtype, recvtype + integer(kind=8), value :: comm integer(c_int) :: c_mpi_allgather_int end function @@ -107,8 +108,8 @@ function c_mpi_allgather_real(sendbuf, sendcount, sendtype, recvbuf, & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: sendcount, recvcount - type(c_ptr), value :: sendtype, recvtype - type(c_ptr), value :: comm + integer(kind=8), value :: sendtype, recvtype + integer(kind=8), value :: comm integer(c_int) :: c_mpi_allgather_real end function @@ -116,9 +117,9 @@ function c_mpi_isend(buf, count, datatype, dest, tag, comm, request) bind(C, nam use iso_c_binding, only: c_int, c_double, c_ptr type(c_ptr), value :: buf integer(c_int), value :: count, dest, tag - type(c_ptr), value :: datatype - type(c_ptr), value :: comm - type(c_ptr), intent(out) :: request + integer(kind=8), value :: datatype + integer(kind=8), value :: comm + integer(kind=8), intent(out) :: request integer(c_int) :: c_mpi_isend end function @@ -126,9 +127,9 @@ function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, n use iso_c_binding, only: c_int, c_double, c_ptr real(c_double), dimension(*), intent(out) :: buf integer(c_int), value :: count, source, tag - type(c_ptr), value :: datatype - type(c_ptr), value :: comm - type(c_ptr), intent(out) :: request + integer(kind=8), value :: datatype + integer(kind=8), value :: comm + integer(kind=8), intent(out) :: request integer(c_int) :: c_mpi_irecv end function @@ -138,7 +139,7 @@ function c_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm) & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: count - type(c_ptr), value :: datatype, op, comm + integer(kind=8), value :: datatype, op, comm integer(c_int) :: c_mpi_allreduce end function @@ -149,36 +150,36 @@ function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime") function c_mpi_barrier(comm) bind(C, name="MPI_Barrier") use iso_c_binding, only: c_ptr, c_int - type(c_ptr), value :: comm ! MPI_Comm as pointer + integer(kind=8), value :: comm ! MPI_Comm as pointer integer(c_int) :: c_mpi_barrier end function c_mpi_barrier function c_mpi_comm_rank(comm, rank) bind(C, name="MPI_Comm_rank") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: comm + integer(kind=8), value :: comm integer(c_int), intent(out) :: rank integer(c_int) :: c_mpi_comm_rank end function c_mpi_comm_rank function c_mpi_comm_split_type(c_comm, split_type, key, c_info, new_comm) bind(C, name="MPI_Comm_split_type") use iso_c_binding, only: c_ptr, c_int - type(c_ptr), value :: c_comm + integer(kind=8), value :: c_comm integer(c_int), value :: split_type integer(c_int), value :: key - type(c_ptr), value :: c_info - type(c_ptr) :: new_comm + integer(kind=8), value :: c_info + integer(kind=8) :: new_comm integer(c_int) :: c_mpi_comm_split_type end function c_mpi_comm_split_type function c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, status) bind(C, name="MPI_Recv") use iso_c_binding, only: c_ptr, c_int, c_double - real(c_double), dimension(*), intent(out) :: buf + type(c_ptr), value :: buf integer(c_int), value :: count - type(c_ptr), value :: c_dtype + integer(kind=8), value :: c_dtype integer(c_int), value :: source integer(c_int), value :: tag - type(c_ptr), value :: c_comm - type(c_ptr) :: status + integer(kind=8), value :: c_comm + type(c_ptr), value :: status integer(c_int) :: c_mpi_recv end function c_mpi_recv @@ -194,23 +195,23 @@ function c_mpi_ssend(buf, count, datatype, dest, tag, comm) bind(C, name="MPI_Ss use iso_c_binding, only: c_int, c_double, c_ptr real(c_double), dimension(*), intent(in) :: buf integer(c_int), value :: count, dest, tag - type(c_ptr), value :: datatype - type(c_ptr), value :: comm + integer(kind=8), value :: datatype + integer(kind=8), value :: comm integer(c_int) :: c_mpi_ssend end function function c_mpi_cart_create(comm_old, ndims, dims, periods, reorder, comm_cart) bind(C, name="MPI_Cart_create") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: comm_old + integer(kind=8), value :: comm_old integer(c_int), value :: ndims, reorder integer(c_int), intent(in) :: dims(*), periods(*) - type(c_ptr), intent(out) :: comm_cart + integer(kind=8), intent(out) :: comm_cart integer(c_int) :: c_mpi_cart_create end function function c_mpi_cart_coords(comm, rank, maxdims, coords) bind(C, name="MPI_Cart_coords") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: comm + integer(kind=8), value :: comm integer(c_int), value :: rank, maxdims integer(c_int), intent(out) :: coords(*) integer(c_int) :: c_mpi_cart_coords @@ -218,7 +219,7 @@ function c_mpi_cart_coords(comm, rank, maxdims, coords) bind(C, name="MPI_Cart_c function c_mpi_cart_shift(comm, direction, disp, rank_source, rank_dest) bind(C, name="MPI_Cart_shift") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: comm + integer(kind=8), value :: comm integer(c_int), value :: direction, disp integer(c_int), intent(out) :: rank_source, rank_dest integer(c_int) :: c_mpi_cart_shift @@ -233,9 +234,9 @@ function c_mpi_dims_create(nnodes, ndims, dims) bind(C, name="MPI_Dims_create") function c_mpi_cart_sub(comm, remain_dims, newcomm) bind(C, name ="MPI_Cart_sub") use iso_c_binding, only: c_int, c_ptr - type(c_ptr), value :: comm + integer(kind=8), value :: comm type(c_ptr), value :: remain_dims - type(c_ptr), intent(out) :: newcomm + integer(kind=8), intent(out) :: newcomm integer(c_int) :: c_mpi_cart_sub end function @@ -246,10 +247,10 @@ function c_mpi_reduce(sendbuf, recvbuf, count, c_dtype, c_op, root, c_comm) & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: count - type(c_ptr), value :: c_dtype - type(c_ptr), value :: c_op + integer(kind=8), value :: c_dtype + integer(kind=8), value :: c_op integer(c_int), value :: root - type(c_ptr), value :: c_comm + integer(kind=8), value :: c_comm integer(c_int) :: c_mpi_reduce end function c_mpi_reduce From 168f2a959db4f132dae5d256daf949b64680eeaf Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Thu, 10 Apr 2025 14:46:57 +0530 Subject: [PATCH 2/3] use MPI_HANDLE_KIND as preprocessor directive for Open MPI and MPICH --- .github/workflows/CI.yml | 32 ++++++++-------- src/mpi.f90 | 63 ++++++++++++++++-------------- src/mpi_c_bindings.f90 | 82 +++++++++++++++++++++------------------- tests/run_tests.sh | 10 ++++- 4 files changed, 103 insertions(+), 84 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a271887..320b514 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -73,8 +73,8 @@ jobs: shell: bash -e -x -l {0} run: | cd tests - FC="gfortran" ./run_tests.sh - FC="gfortran -O3 -march=native" ./run_tests.sh + FC="gfortran -cpp" ./run_tests.sh + FC="gfortran -O3 -march=native -cpp" ./run_tests.sh Run_standalone_tests_with_GFortran_with_MPICH: name: "Run standalone tests with GFortran with MPICH" @@ -95,8 +95,8 @@ jobs: shell: bash -e -x -l {0} run: | cd tests - FC="gfortran" ./run_tests.sh - FC="gfortran -O3 -march=native" ./run_tests.sh + FC="gfortran -cpp" ./run_tests.sh + FC="gfortran -O3 -march=native -cpp" ./run_tests.sh Run_standalone_tests_with_LFortran_with_OpenMPI: name: "Run standalone tests with LFortran with Open MPI" @@ -117,8 +117,8 @@ jobs: shell: bash -e -x -l {0} run: | cd tests - FC="lfortran" ./run_tests.sh - FC="lfortran --fast" ./run_tests.sh + FC="lfortran --cpp" ./run_tests.sh + FC="lfortran --fast --cpp" ./run_tests.sh Run_standalone_tests_with_LFortran_with_MPICH: name: "Run standalone tests with LFortran with MPICH" @@ -139,8 +139,8 @@ jobs: shell: bash -e -x -l {0} run: | cd tests - FC="lfortran" ./run_tests.sh - FC="lfortran --fast" ./run_tests.sh + FC="lfortran --cpp" ./run_tests.sh + FC="lfortran --fast --cpp" ./run_tests.sh Compile_POT3D_with_GFortran_with_OpenMPI: name: "Build POT3D and validate with GFortran with Open MPI" @@ -162,14 +162,14 @@ jobs: shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="gfortran -O3 -march=native" ./build_and_run_gfortran.sh + FC="gfortran -O3 -march=native -cpp" ./build_and_run_gfortran.sh # build and validation without GFortran's optimization - name: POT3D Build and validation with GFortran without optimization using Open MPI (MPI only) shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="gfortran" ./build_and_run_gfortran.sh + FC="gfortran -cpp" ./build_and_run_gfortran.sh Compile_POT3D_with_LFortran_with_OpenMPI: name: "Build POT3D and validate with LFortran with Open MPI" @@ -191,14 +191,14 @@ jobs: shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="lfortran --fast" ./build_and_run_lfortran.sh + FC="lfortran --fast --cpp" ./build_and_run_lfortran.sh # build and validation without LFortran's optimization - name: POT3D Build and validation with LFortran without optimization using Open MPI (MPI only) shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="lfortran" ./build_and_run_lfortran.sh + FC="lfortran --cpp" ./build_and_run_lfortran.sh Compile_POT3D_with_GFortran_with_MPICH: name: "Build POT3D and validate with GFortran with MPICH" @@ -220,14 +220,14 @@ jobs: shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="gfortran -O3 -march=native" ./build_and_run_gfortran.sh + FC="gfortran -O3 -march=native -cpp" ./build_and_run_gfortran.sh # build and validation without GFortran's optimization - name: POT3D Build and validation with GFortran without optimization (MPI only) shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="gfortran" ./build_and_run_gfortran.sh + FC="gfortran -cpp" ./build_and_run_gfortran.sh Compile_POT3D_with_LFortran_with_MPICH: name: "Build POT3D and validate with LFortran with MPICH" @@ -249,11 +249,11 @@ jobs: shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="lfortran --fast" ./build_and_run_lfortran.sh + FC="lfortran --fast --cpp" ./build_and_run_lfortran.sh # build and validation without LFortran's optimization - name: POT3D Build and validation with LFortran without optimization using MPICH (MPI only) shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="lfortran" ./build_and_run_lfortran.sh + FC="lfortran --cpp" ./build_and_run_lfortran.sh diff --git a/src/mpi.f90 b/src/mpi.f90 index 6d27b27..7fdcf54 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -1,5 +1,12 @@ module mpi implicit none + +#ifdef OPEN_MPI +#define MPI_HANDLE_KIND 8 +#else +#define MPI_HANDLE_KIND 4 +#endif + integer, parameter :: MPI_THREAD_FUNNELED = 1 integer, parameter :: MPI_INTEGER = -10002 @@ -189,7 +196,7 @@ subroutine MPI_Comm_size_proc(comm, size, ierror) integer, intent(out) :: size integer, optional, intent(out) :: ierror integer :: local_ierr - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm c_comm = c_mpi_comm_f2c(comm) local_ierr = c_mpi_comm_size(c_comm, size) @@ -210,7 +217,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror) integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm, c_datatype + integer(kind=MPI_HANDLE_KIND) :: c_comm, c_datatype integer :: local_ierr type(c_ptr) :: buffer_ptr @@ -236,7 +243,7 @@ subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror) integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm, c_datatype + integer(kind=MPI_HANDLE_KIND) :: c_comm, c_datatype integer :: local_ierr type(c_ptr) :: buffer_ptr @@ -263,9 +270,9 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r integer, intent(in) :: sendtype, recvtype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - integer(kind=8) :: c_sendtype, c_recvtype + integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr c_comm = c_mpi_comm_f2c(comm) @@ -293,9 +300,9 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, integer, intent(in) :: sendtype, recvtype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - integer(kind=8) :: c_sendtype, c_recvtype + integer(kind=MPI_HANDLE_KIND) :: c_sendtype, c_recvtype type(c_ptr) :: sendbuf_ptr, recvbuf_ptr c_comm = c_mpi_comm_f2c(comm) @@ -324,7 +331,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror) integer, intent(out) :: request integer, optional, intent(out) :: ierror type(c_ptr) :: buf_ptr - integer(kind=8) :: c_datatype, c_comm, c_request + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_comm, c_request integer(c_int) :: local_ierr buf_ptr = c_loc(buf) @@ -353,7 +360,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror) integer, intent(out) :: request integer, optional, intent(out) :: ierror type(c_ptr) :: buf_ptr - integer(kind=8) :: c_datatype, c_comm, c_request + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_comm, c_request integer(c_int) :: local_ierr buf_ptr = c_loc(buf) @@ -381,10 +388,10 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr integer, intent(in) :: comm integer, intent(out) :: request integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr - integer(kind=8) :: c_datatype - integer(kind=8) :: c_request + integer(kind=MPI_HANDLE_KIND) :: c_datatype + integer(kind=MPI_HANDLE_KIND) :: c_request c_comm = c_mpi_comm_f2c(comm) c_datatype = c_mpi_datatype_f2c(datatype) @@ -408,7 +415,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - integer(kind=8) :: c_datatype, c_op, c_comm + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then @@ -440,7 +447,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - integer(kind=8) :: c_datatype, c_op, c_comm + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr if (sendbuf == MPI_IN_PLACE) then @@ -472,7 +479,7 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - integer(kind=8) :: c_datatype, c_op, c_comm + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr sendbuf_ptr = c_loc(sendbuf) @@ -500,7 +507,7 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm integer, intent(in) :: count, datatype, op, comm integer, intent(out), optional :: ierror type(c_ptr) :: sendbuf_ptr, recvbuf_ptr - integer(kind=8) :: c_datatype, c_op, c_comm + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm integer(c_int) :: local_ierr sendbuf_ptr = c_loc(sendbuf) @@ -531,7 +538,7 @@ subroutine MPI_Barrier_proc(comm, ierror) use iso_c_binding, only: c_int, c_ptr integer, intent(in) :: comm integer, intent(out), optional :: ierror - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr ! Convert Fortran handle to C handle @@ -553,7 +560,7 @@ subroutine MPI_Comm_rank_proc(comm, rank, ierror) integer, intent(in) :: comm integer, intent(out) :: rank integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr c_comm = c_mpi_comm_f2c(comm) @@ -578,7 +585,7 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror integer, optional, intent(out) :: ierror integer(c_int) :: local_ierr - integer(kind=8) :: c_comm, c_info, c_new_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm, c_info, c_new_comm ! Convert Fortran communicator and info handles to C pointers. c_comm = c_mpi_comm_f2c(comm) @@ -609,7 +616,7 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st integer, optional, intent(out) :: ierror integer(c_int) :: local_ierr, status_ierr - integer(kind=8) :: c_dtype, c_comm + integer(kind=MPI_HANDLE_KIND) :: c_dtype, c_comm type(c_ptr) :: c_status integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status @@ -645,7 +652,7 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s integer, optional, intent(out) :: ierror integer(c_int) :: local_ierr, status_ierr - integer(kind=8) :: c_dtype, c_comm + integer(kind=MPI_HANDLE_KIND) :: c_dtype, c_comm type(c_ptr) :: c_status integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status @@ -689,7 +696,7 @@ subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror) integer, intent(in) :: datatype integer, intent(in) :: comm integer, optional, intent(out) :: ierror - integer(kind=8) :: c_datatype, c_comm + integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_comm integer :: local_ierr c_datatype = c_mpi_datatype_f2c(datatype) @@ -706,8 +713,8 @@ subroutine MPI_Cart_create_proc(comm_old, ndims, dims, periods, reorder, comm_ca integer, intent(out) :: comm_cart integer, optional, intent(out) :: ierror integer(c_int) :: ndims_c, reorder_c, dims_c(ndims), periods_c(ndims) - integer(kind=8) :: c_comm_old - integer(kind=8) :: c_comm_cart + integer(kind=MPI_HANDLE_KIND) :: c_comm_old + integer(kind=MPI_HANDLE_KIND) :: c_comm_cart integer(c_int) :: local_ierr c_comm_old = c_mpi_comm_f2c(comm_old) @@ -742,7 +749,7 @@ subroutine MPI_Cart_coords_proc(comm, rank, maxdims, coords, ierror) integer, intent(in) :: rank, maxdims integer, intent(out) :: coords(maxdims) integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr c_comm = c_mpi_comm_f2c(comm) @@ -764,7 +771,7 @@ subroutine MPI_Cart_shift_proc(comm, direction, disp, rank_source, rank_dest, ie integer, intent(in) :: direction, disp integer, intent(out) :: rank_source, rank_dest integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm + integer(kind=MPI_HANDLE_KIND) :: c_comm integer(c_int) :: local_ierr c_comm = c_mpi_comm_f2c(comm) @@ -805,7 +812,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror) integer, intent(out) :: newcomm integer, optional, intent(out) :: ierror integer, target :: remain_dims_i(size(remain_dims)) - integer(kind=8) :: c_comm, c_newcomm + integer(kind=MPI_HANDLE_KIND) :: c_comm, c_newcomm integer :: local_ierr type(c_ptr) :: remain_dims_i_ptr @@ -838,7 +845,7 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co integer, intent(in) :: count, datatype, op, root, comm integer, optional, intent(out) :: ierror - integer(kind=8) :: c_comm, c_dtype, c_op + integer(kind=MPI_HANDLE_KIND) :: c_comm, c_dtype, c_op type(c_ptr) :: c_sendbuf, c_recvbuf integer(c_int) :: local_ierr diff --git a/src/mpi_c_bindings.f90 b/src/mpi_c_bindings.f90 index 1f46383..c81e85f 100644 --- a/src/mpi_c_bindings.f90 +++ b/src/mpi_c_bindings.f90 @@ -1,35 +1,41 @@ module mpi_c_bindings implicit none +#ifdef OPEN_MPI +#define MPI_HANDLE_KIND 8 +#else +#define MPI_HANDLE_KIND 4 +#endif + interface function c_mpi_comm_f2c(comm_f) bind(C, name="get_c_comm_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: comm_f - integer(kind=8) :: c_mpi_comm_f2c + integer(kind=MPI_HANDLE_KIND) :: c_mpi_comm_f2c end function c_mpi_comm_f2c function c_mpi_comm_c2f(comm_c) bind(C, name="MPI_Comm_c2f") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: comm_c + integer(kind=MPI_HANDLE_KIND), value :: comm_c integer(c_int) :: c_mpi_comm_c2f end function function c_mpi_request_c2f(request) bind(C, name="MPI_Request_c2f") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: request + integer(kind=MPI_HANDLE_KIND), value :: request integer(c_int) :: c_mpi_request_c2f end function function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: datatype - integer(kind=8) :: c_mpi_datatype_f2c + integer(kind=MPI_HANDLE_KIND) :: c_mpi_datatype_f2c end function c_mpi_datatype_f2c function c_mpi_op_f2c(op_f) bind(C, name="get_c_op_from_fortran") use iso_c_binding, only: c_ptr, c_int integer(c_int), value :: op_f - integer(kind=8) :: c_mpi_op_f2c + integer(kind=MPI_HANDLE_KIND) :: c_mpi_op_f2c end function c_mpi_op_f2c function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f") @@ -42,7 +48,7 @@ end function c_mpi_status_c2f function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran") use iso_c_binding, only: c_int, c_ptr integer(c_int), value :: info_f - integer(kind=8) :: c_mpi_info_f2c + integer(kind=MPI_HANDLE_KIND) :: c_mpi_info_f2c end function c_mpi_info_f2c function c_mpi_in_place_f2c(in_place_f) bind(C,name="get_c_mpi_inplace_from_fortran") @@ -76,7 +82,7 @@ end function c_mpi_finalize function c_mpi_comm_size(comm, size) bind(C, name="MPI_Comm_size") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int), intent(out) :: size integer(c_int) :: c_mpi_comm_size end function c_mpi_comm_size @@ -85,9 +91,9 @@ function c_mpi_bcast(buffer, count, datatype, root, comm) bind(C, name="MPI_Bcas use iso_c_binding, only : c_ptr, c_int type(c_ptr), value :: buffer integer(c_int), value :: count - integer(kind=8), value :: datatype + integer(kind=MPI_HANDLE_KIND), value :: datatype integer(c_int), value :: root - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int) :: c_mpi_bcast end function c_mpi_bcast @@ -97,8 +103,8 @@ function c_mpi_allgather_int(sendbuf, sendcount, sendtype, recvbuf, & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: sendcount, recvcount - integer(kind=8), value :: sendtype, recvtype - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: sendtype, recvtype + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int) :: c_mpi_allgather_int end function @@ -108,8 +114,8 @@ function c_mpi_allgather_real(sendbuf, sendcount, sendtype, recvbuf, & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: sendcount, recvcount - integer(kind=8), value :: sendtype, recvtype - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: sendtype, recvtype + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int) :: c_mpi_allgather_real end function @@ -117,9 +123,9 @@ function c_mpi_isend(buf, count, datatype, dest, tag, comm, request) bind(C, nam use iso_c_binding, only: c_int, c_double, c_ptr type(c_ptr), value :: buf integer(c_int), value :: count, dest, tag - integer(kind=8), value :: datatype - integer(kind=8), value :: comm - integer(kind=8), intent(out) :: request + integer(kind=MPI_HANDLE_KIND), value :: datatype + integer(kind=MPI_HANDLE_KIND), value :: comm + integer(kind=MPI_HANDLE_KIND), intent(out) :: request integer(c_int) :: c_mpi_isend end function @@ -127,9 +133,9 @@ function c_mpi_irecv(buf, count, datatype, source, tag, comm, request) bind(C, n use iso_c_binding, only: c_int, c_double, c_ptr real(c_double), dimension(*), intent(out) :: buf integer(c_int), value :: count, source, tag - integer(kind=8), value :: datatype - integer(kind=8), value :: comm - integer(kind=8), intent(out) :: request + integer(kind=MPI_HANDLE_KIND), value :: datatype + integer(kind=MPI_HANDLE_KIND), value :: comm + integer(kind=MPI_HANDLE_KIND), intent(out) :: request integer(c_int) :: c_mpi_irecv end function @@ -139,7 +145,7 @@ function c_mpi_allreduce(sendbuf, recvbuf, count, datatype, op, comm) & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: count - integer(kind=8), value :: datatype, op, comm + integer(kind=MPI_HANDLE_KIND), value :: datatype, op, comm integer(c_int) :: c_mpi_allreduce end function @@ -150,24 +156,24 @@ function c_mpi_wtime() result(time) bind(C, name="MPI_Wtime") function c_mpi_barrier(comm) bind(C, name="MPI_Barrier") use iso_c_binding, only: c_ptr, c_int - integer(kind=8), value :: comm ! MPI_Comm as pointer + integer(kind=MPI_HANDLE_KIND), value :: comm ! MPI_Comm as pointer integer(c_int) :: c_mpi_barrier end function c_mpi_barrier function c_mpi_comm_rank(comm, rank) bind(C, name="MPI_Comm_rank") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int), intent(out) :: rank integer(c_int) :: c_mpi_comm_rank end function c_mpi_comm_rank function c_mpi_comm_split_type(c_comm, split_type, key, c_info, new_comm) bind(C, name="MPI_Comm_split_type") use iso_c_binding, only: c_ptr, c_int - integer(kind=8), value :: c_comm + integer(kind=MPI_HANDLE_KIND), value :: c_comm integer(c_int), value :: split_type integer(c_int), value :: key - integer(kind=8), value :: c_info - integer(kind=8) :: new_comm + integer(kind=MPI_HANDLE_KIND), value :: c_info + integer(kind=MPI_HANDLE_KIND) :: new_comm integer(c_int) :: c_mpi_comm_split_type end function c_mpi_comm_split_type @@ -175,10 +181,10 @@ function c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, status) bind(C, na use iso_c_binding, only: c_ptr, c_int, c_double type(c_ptr), value :: buf integer(c_int), value :: count - integer(kind=8), value :: c_dtype + integer(kind=MPI_HANDLE_KIND), value :: c_dtype integer(c_int), value :: source integer(c_int), value :: tag - integer(kind=8), value :: c_comm + integer(kind=MPI_HANDLE_KIND), value :: c_comm type(c_ptr), value :: status integer(c_int) :: c_mpi_recv end function c_mpi_recv @@ -195,23 +201,23 @@ function c_mpi_ssend(buf, count, datatype, dest, tag, comm) bind(C, name="MPI_Ss use iso_c_binding, only: c_int, c_double, c_ptr real(c_double), dimension(*), intent(in) :: buf integer(c_int), value :: count, dest, tag - integer(kind=8), value :: datatype - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: datatype + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int) :: c_mpi_ssend end function function c_mpi_cart_create(comm_old, ndims, dims, periods, reorder, comm_cart) bind(C, name="MPI_Cart_create") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: comm_old + integer(kind=MPI_HANDLE_KIND), value :: comm_old integer(c_int), value :: ndims, reorder integer(c_int), intent(in) :: dims(*), periods(*) - integer(kind=8), intent(out) :: comm_cart + integer(kind=MPI_HANDLE_KIND), intent(out) :: comm_cart integer(c_int) :: c_mpi_cart_create end function function c_mpi_cart_coords(comm, rank, maxdims, coords) bind(C, name="MPI_Cart_coords") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int), value :: rank, maxdims integer(c_int), intent(out) :: coords(*) integer(c_int) :: c_mpi_cart_coords @@ -219,7 +225,7 @@ function c_mpi_cart_coords(comm, rank, maxdims, coords) bind(C, name="MPI_Cart_c function c_mpi_cart_shift(comm, direction, disp, rank_source, rank_dest) bind(C, name="MPI_Cart_shift") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: comm integer(c_int), value :: direction, disp integer(c_int), intent(out) :: rank_source, rank_dest integer(c_int) :: c_mpi_cart_shift @@ -234,9 +240,9 @@ function c_mpi_dims_create(nnodes, ndims, dims) bind(C, name="MPI_Dims_create") function c_mpi_cart_sub(comm, remain_dims, newcomm) bind(C, name ="MPI_Cart_sub") use iso_c_binding, only: c_int, c_ptr - integer(kind=8), value :: comm + integer(kind=MPI_HANDLE_KIND), value :: comm type(c_ptr), value :: remain_dims - integer(kind=8), intent(out) :: newcomm + integer(kind=MPI_HANDLE_KIND), intent(out) :: newcomm integer(c_int) :: c_mpi_cart_sub end function @@ -247,10 +253,10 @@ function c_mpi_reduce(sendbuf, recvbuf, count, c_dtype, c_op, root, c_comm) & type(c_ptr), value :: sendbuf type(c_ptr), value :: recvbuf integer(c_int), value :: count - integer(kind=8), value :: c_dtype - integer(kind=8), value :: c_op + integer(kind=MPI_HANDLE_KIND), value :: c_dtype + integer(kind=MPI_HANDLE_KIND), value :: c_op integer(c_int), value :: root - integer(kind=8), value :: c_comm + integer(kind=MPI_HANDLE_KIND), value :: c_comm integer(c_int) :: c_mpi_reduce end function c_mpi_reduce diff --git a/tests/run_tests.sh b/tests/run_tests.sh index caf5d07..7e7feec 100755 --- a/tests/run_tests.sh +++ b/tests/run_tests.sh @@ -48,6 +48,12 @@ else fi fi +# Define preprocessor flags based on MPI type +FC_FLAGS="" +if [[ "$MPI_TYPE" == "openmpi" ]]; then + FC_FLAGS="-DOPEN_MPI" +fi + echo -e "${RED}Removing all untracked files${NC}" git clean -dfx echo -e "#################################" @@ -70,8 +76,8 @@ fi if [ $USE_WRAPPERS -eq 1 ]; then $CC -I"$CONDA_PREFIX/include" -c ../src/mpi_wrapper.c - $FC -c ../src/mpi_c_bindings.f90 - $FC -c ../src/mpi.f90 + $FC $FC_FLAGS -c ../src/mpi_c_bindings.f90 + $FC $FC_FLAGS -c ../src/mpi.f90 fi start_time=$(date +%s) From ebcd827b492b5ddb5c6af6af41578168846686f5 Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Thu, 10 Apr 2025 14:59:39 +0530 Subject: [PATCH 3/3] use -DOPEN_MPI=yes when running with Open MPI --- .github/workflows/CI.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 320b514..30bac03 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -73,8 +73,8 @@ jobs: shell: bash -e -x -l {0} run: | cd tests - FC="gfortran -cpp" ./run_tests.sh - FC="gfortran -O3 -march=native -cpp" ./run_tests.sh + FC="gfortran -cpp -DOPEN_MPI=yes" ./run_tests.sh + FC="gfortran -O3 -march=native -cpp -DOPEN_MPI=yes" ./run_tests.sh Run_standalone_tests_with_GFortran_with_MPICH: name: "Run standalone tests with GFortran with MPICH" @@ -117,8 +117,8 @@ jobs: shell: bash -e -x -l {0} run: | cd tests - FC="lfortran --cpp" ./run_tests.sh - FC="lfortran --fast --cpp" ./run_tests.sh + FC="lfortran --cpp -DOPEN_MPI=yes" ./run_tests.sh + FC="lfortran --fast --cpp -DOPEN_MPI=yes" ./run_tests.sh Run_standalone_tests_with_LFortran_with_MPICH: name: "Run standalone tests with LFortran with MPICH" @@ -162,14 +162,14 @@ jobs: shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="gfortran -O3 -march=native -cpp" ./build_and_run_gfortran.sh + FC="gfortran -O3 -march=native -cpp -DOPEN_MPI=yes" ./build_and_run_gfortran.sh # build and validation without GFortran's optimization - name: POT3D Build and validation with GFortran without optimization using Open MPI (MPI only) shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="gfortran -cpp" ./build_and_run_gfortran.sh + FC="gfortran -cpp -DOPEN_MPI=yes" ./build_and_run_gfortran.sh Compile_POT3D_with_LFortran_with_OpenMPI: name: "Build POT3D and validate with LFortran with Open MPI" @@ -191,14 +191,14 @@ jobs: shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="lfortran --fast --cpp" ./build_and_run_lfortran.sh + FC="lfortran --fast --cpp -DOPEN_MPI=yes" ./build_and_run_lfortran.sh # build and validation without LFortran's optimization - name: POT3D Build and validation with LFortran without optimization using Open MPI (MPI only) shell: bash -e -x -l {0} run: | cd tests/pot3d - FC="lfortran --cpp" ./build_and_run_lfortran.sh + FC="lfortran --cpp -DOPEN_MPI=yes" ./build_and_run_lfortran.sh Compile_POT3D_with_GFortran_with_MPICH: name: "Build POT3D and validate with GFortran with MPICH"