Skip to content

Remove C-wrapper for MPI_Waitall #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ subroutine MPI_Comm_split_type_proc(comm, split_type, key, info, newcomm, ierror
end if
end if

end subroutine MPI_Comm_split_type_proc
end subroutine MPI_Comm_split_type_proc

subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
Expand Down Expand Up @@ -636,14 +636,14 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st
end if

if (present(ierror)) then
ierror = local_ierr
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Recv failed with error code: ", local_ierr
print *, "MPI_Recv failed with error code: ", local_ierr
end if

end subroutine MPI_Recv_StatusArray_proc
end subroutine MPI_Recv_StatusArray_proc

subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror)
subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_recv, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_status_c2f
real(8), dimension(*), intent(inout), target :: buf
Expand Down Expand Up @@ -672,21 +672,64 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s
end if

if (present(ierror)) then
ierror = local_ierr
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Recv failed with error code: ", local_ierr
print *, "MPI_Recv failed with error code: ", local_ierr
end if

end subroutine MPI_Recv_StatusIgnore_proc
end subroutine MPI_Recv_StatusIgnore_proc

subroutine MPI_Waitall_proc(count, array_of_requests, array_of_statuses, ierror)
use mpi_c_bindings, only: c_mpi_waitall
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_waitall, c_mpi_request_f2c, c_mpi_request_c2f, c_mpi_status_c2f
integer, intent(in) :: count
integer, intent(inout) :: array_of_requests(count)
integer, intent(out) :: array_of_statuses(*)
integer, dimension(count), intent(inout) :: array_of_requests
integer, dimension(*), intent(out) :: array_of_statuses
integer, optional, intent(out) :: ierror
call c_mpi_waitall(count, array_of_requests, array_of_statuses, ierror)
end subroutine

integer(c_int) :: local_ierr, status_ierr
integer :: i

! Allocate temporary arrays for the C representations.
integer(kind=MPI_HANDLE_KIND), allocatable :: c_requests(:)
type(c_ptr), allocatable :: c_statuses(:)
allocate(c_requests(count))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lfortran treats arrays of type(c_ptr) as scalar

No specific test file given. Will compile/run all .f90 tests...
semantic error: Array reference is not allowed on scalar variable
   --> ../src/mpi.f90:688:13
    |
688 |             c_requests(i) = c_mpi_request_f2c(array_of_requests(i))
    |             ^^^^^^^^^^^^^ 


Note: Please report unclear, confusing or incorrect messages as bugs at
https://github.com/lfortran/lfortran/issues.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just make it a non-allocatable by doing type(c_ptr) :: c_requests(count), would that be ok?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope it too gives error

 aditya-trivedi   tests    wait_2 ≢  ~1    gfortran -c test.f90 
 aditya-trivedi   tests    wait_2 ≢  ?1 ~1    lfortran -c test.f90 
semantic error: Array reference is not allowed on scalar variable
 --> test.f90:6:5
  |
6 |     a(1) = c_loc(i)
  |     ^^^^ 


Note: Please report unclear, confusing or incorrect messages as bugs at
https://github.com/lfortran/lfortran/issues.
 aditya-trivedi   tests    wait_2 ≢  ?1 ~1    cat test.f90 
program main
    use iso_c_binding, only: c_ptr, c_loc
    type(c_ptr) :: a(10)
    integer ,target:: i=1
    ! allocate(a(10))
    a(1) = c_loc(i)
end program

Here is small mre which i tried

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood. We've an error with GFortran + MPICH as well for both standalone test and POT3D as well, so there is something wrong with the Fortran code either in mpi.f90 or mpi_c_bindings.f90, I think we should try to fix that first.

Then we would've a better understanding of whether something is actually needed to be fixed in LFortran or not.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem we discovered in LFortran, probably is fixed with the PR: lfortran/lfortran#6839.

Copy link
Collaborator

@gxyd gxyd Apr 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though, I still don't know exactly why this PR (to remove the C wrapper) doesn't work with MPICH but works with Open MPI.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, please fix LFortran for all bugs. That's why we are doing it also. :)

That PR is merged, so this might fix it --- can you try locally?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still get an error, I'm extracting the MRE for it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some other error albeit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is an LFortran error, then extract it and fix it. That's a very valuable thing to do.

allocate(c_statuses(count*MPI_STATUS_SIZE))

! Convert Fortran requests to C requests.
do i = 1, count
c_requests(i) = c_mpi_request_f2c(array_of_requests(i))
end do

! Call the native MPI_Waitall.
local_ierr = c_mpi_waitall(count, c_requests, c_statuses)

! Convert the C requests back to Fortran handles.
do i = 1, count
array_of_requests(i) = c_mpi_request_c2f(c_requests(i))
end do

! For each status, convert the C status to Fortran status.
if(array_of_statuses(1) == MPI_STATUS_IGNORE) then
! If the status is ignored, we don't need to convert.
array_of_statuses(1) = 0
! print *, "Status is ignored, no conversion needed."
else
! Convert the C status to Fortran status.
do i = 1, count
status_ierr = c_mpi_status_c2f(c_statuses(i), array_of_statuses((i-1)*MPI_STATUS_SIZE+1))
end do
end if

if (present(ierror)) then
ierror = local_ierr
else if (local_ierr /= MPI_SUCCESS) then
print *, "MPI_Waitall failed with error code: ", local_ierr
end if

deallocate(c_requests)
deallocate(c_statuses)
end subroutine MPI_Waitall_proc

subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror)
use iso_c_binding, only: c_int, c_ptr
Expand Down
22 changes: 14 additions & 8 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ function c_mpi_request_c2f(request) bind(C, name="MPI_Request_c2f")
integer(c_int) :: c_mpi_request_c2f
end function

function c_mpi_request_f2c(request) bind(C, name="MPI_Request_f2c")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: request
integer(kind=MPI_HANDLE_KIND) :: c_mpi_request_f2c
end function c_mpi_request_f2c

function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: datatype
Expand All @@ -44,7 +50,7 @@ function c_mpi_status_c2f(c_status, f_status) bind(C, name="MPI_Status_c2f")
integer(c_int) :: f_status(*) ! assumed-size array
integer(c_int) :: c_mpi_status_c2f
end function c_mpi_status_c2f

function c_mpi_info_f2c(info_f) bind(C, name="get_c_info_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: info_f
Expand Down Expand Up @@ -189,13 +195,13 @@ function c_mpi_recv(buf, count, c_dtype, source, tag, c_comm, status) bind(C, na
integer(c_int) :: c_mpi_recv
end function c_mpi_recv

subroutine c_mpi_waitall(count, array_of_requests, array_of_statuses, ierror) bind(C, name="mpi_waitall_wrapper")
use iso_c_binding, only: c_int
integer(c_int), intent(in) :: count
integer(c_int), intent(inout) :: array_of_requests(count)
integer(c_int) :: array_of_statuses(*)
integer(c_int), optional, intent(out) :: ierror
end subroutine
function c_mpi_waitall(count, requests, statuses) bind(C, name="MPI_Waitall")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: count
integer(kind=MPI_HANDLE_KIND), dimension(*), intent(inout) :: requests
type(c_ptr), dimension(*), intent(out) :: statuses
integer(c_int) :: c_mpi_waitall
end function c_mpi_waitall

function c_mpi_ssend(buf, count, datatype, dest, tag, comm) bind(C, name="MPI_Ssend")
use iso_c_binding, only: c_int, c_double, c_ptr
Expand Down
27 changes: 0 additions & 27 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,30 +60,3 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) {
void* get_c_mpi_inplace_from_fortran(double sendbuf) {
return MPI_IN_PLACE;
}

void mpi_waitall_wrapper(int *count, int *array_of_requests_f,
int *array_of_statuses_f, int *ierror) {
MPI_Request *array_of_requests;
MPI_Status *array_of_statuses;
array_of_requests = (MPI_Request *)malloc((*count) * sizeof(MPI_Request));
array_of_statuses = (MPI_Status *)malloc((*count) * sizeof(MPI_Status));
if (array_of_requests == NULL || array_of_statuses == NULL) {
*ierror = MPI_ERR_NO_MEM;
return;
}
for (int i = 0; i < *count; i++) {
array_of_requests[i] = MPI_Request_f2c(array_of_requests_f[i]);
}

*ierror = MPI_Waitall(*count, array_of_requests, array_of_statuses);
for (int i = 0; i < *count; i++) {
array_of_requests_f[i] = MPI_Request_c2f(array_of_requests[i]);
}

for (int i = 0; i < *count; i++) {
MPI_Status_c2f(&array_of_statuses[i], &array_of_statuses_f[i * MPI_STATUS_SIZE]);
}

free(array_of_requests);
free(array_of_statuses);
}
Loading