Skip to content

use Fortran code for setting the datatype #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions src/mpi.f90
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,18 @@ integer(kind=MPI_HANDLE_KIND) function handle_mpi_info_f2c(info_f) result(c_info
end if
end function handle_mpi_info_f2c

integer(kind=MPI_HANDLE_KIND) function handle_mpi_datatype_f2c(datatype_f) result(c_datatype)
use mpi_c_bindings, only: c_mpi_float, c_mpi_double, c_mpi_int
integer, intent(in) :: datatype_f
if (datatype_f == MPI_REAL4) then
c_datatype = c_mpi_float()
else if (datatype_f == MPI_REAL8 .OR. datatype_f == MPI_DOUBLE_PRECISION) then
c_datatype = c_mpi_double()
else if (datatype_f == MPI_INTEGER) then
c_datatype = c_mpi_int()
end if
end function

subroutine MPI_Init_proc(ierr)
use mpi_c_bindings, only: c_mpi_init
use iso_c_binding, only : c_int, c_ptr, c_null_ptr
Expand Down Expand Up @@ -241,7 +253,7 @@ subroutine MPI_Comm_size_proc(comm, size, ierror)
end subroutine

subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror)
use mpi_c_bindings, only: c_mpi_bcast, c_mpi_datatype_f2c
use mpi_c_bindings, only: c_mpi_bcast
use iso_c_binding, only: c_int, c_ptr, c_loc
integer, target :: buffer
integer, intent(in) :: count, root
Expand All @@ -254,7 +266,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror)

c_comm = handle_mpi_comm_f2c(comm)

c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
buffer_ptr = c_loc(buffer)
local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm)

Expand All @@ -268,7 +280,7 @@ subroutine MPI_Bcast_int_scalar(buffer, count, datatype, root, comm, ierror)
end subroutine MPI_Bcast_int_scalar

subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror)
use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_datatype_f2c, c_mpi_comm_world
use mpi_c_bindings, only: c_mpi_bcast, c_mpi_comm_f2c, c_mpi_comm_world
use iso_c_binding, only: c_int, c_ptr, c_loc
real(8), dimension(:, :), target :: buffer
integer, intent(in) :: count, root
Expand All @@ -281,7 +293,7 @@ subroutine MPI_Bcast_real_2D(buffer, count, datatype, root, comm, ierror)

c_comm = handle_mpi_comm_f2c(comm)

c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
buffer_ptr = c_loc(buffer)
local_ierr = c_mpi_bcast(buffer_ptr, count, c_datatype, root, c_comm)

Expand All @@ -296,7 +308,7 @@ end subroutine MPI_Bcast_real_2D

subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allgather_int, c_mpi_datatype_f2c
use mpi_c_bindings, only: c_mpi_allgather_int
integer, dimension(:), intent(in), target :: sendbuf
integer, dimension(:, :), intent(out), target :: recvbuf
integer, intent(in) :: sendcount, recvcount
Expand All @@ -310,8 +322,8 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r

c_comm = handle_mpi_comm_f2c(comm)

c_sendtype = c_mpi_datatype_f2c(sendtype)
c_recvtype = c_mpi_datatype_f2c(recvtype)
c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
local_ierr = c_mpi_allgather_int(sendbuf_ptr, sendcount, c_sendtype, recvbuf_ptr, recvcount, c_recvtype, c_comm)
Expand All @@ -327,7 +339,7 @@ subroutine MPI_Allgather_int(sendbuf, sendcount, sendtype, recvbuf, recvcount, r

subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allgather_real, c_mpi_datatype_f2c
use mpi_c_bindings, only: c_mpi_allgather_real
real(8), dimension(:), intent(in), target :: sendbuf
real(8), dimension(:, :), intent(out), target :: recvbuf
integer, intent(in) :: sendcount, recvcount
Expand All @@ -341,8 +353,8 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount,

c_comm = handle_mpi_comm_f2c(comm)

c_sendtype = c_mpi_datatype_f2c(sendtype)
c_recvtype = c_mpi_datatype_f2c(recvtype)
c_sendtype = handle_mpi_datatype_f2c(sendtype)
c_recvtype = handle_mpi_datatype_f2c(recvtype)
sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
local_ierr = c_mpi_allgather_real(sendbuf_ptr, sendcount, c_sendtype, recvbuf_ptr, recvcount, c_recvtype, c_comm)
Expand All @@ -358,7 +370,7 @@ subroutine MPI_Allgather_real(sendbuf, sendcount, sendtype, recvbuf, recvcount,

subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_request_c2f
use mpi_c_bindings, only: c_mpi_isend, c_mpi_request_c2f
real(8), dimension(:, :), intent(in), target :: buf
integer, intent(in) :: count, dest, tag
integer, intent(in) :: datatype
Expand All @@ -370,7 +382,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror)
integer(c_int) :: local_ierr

buf_ptr = c_loc(buf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)

c_comm = handle_mpi_comm_f2c(comm)

Expand All @@ -389,7 +401,7 @@ subroutine MPI_Isend_2d(buf, count, datatype, dest, tag, comm, request, ierror)

subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_isend, c_mpi_datatype_f2c, c_mpi_request_c2f
use mpi_c_bindings, only: c_mpi_isend, c_mpi_request_c2f
real(8), dimension(:, :, :), intent(in), target :: buf
integer, intent(in) :: count, dest, tag
integer, intent(in) :: datatype
Expand All @@ -401,7 +413,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror)
integer(c_int) :: local_ierr

buf_ptr = c_loc(buf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)

c_comm = handle_mpi_comm_f2c(comm)

Expand All @@ -420,7 +432,7 @@ subroutine MPI_Isend_3d(buf, count, datatype, dest, tag, comm, request, ierror)

subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierror)
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_irecv, c_mpi_datatype_f2c, c_mpi_request_c2f
use mpi_c_bindings, only: c_mpi_irecv, c_mpi_request_c2f
real(8), dimension(:,:) :: buf
integer, intent(in) :: count, source, tag
integer, intent(in) :: datatype
Expand All @@ -434,7 +446,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr

c_comm = handle_mpi_comm_f2c(comm)

c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
local_ierr = c_mpi_irecv(buf, count, c_datatype, source, tag, c_comm, c_request)
request = c_mpi_request_c2f(c_request)

Expand All @@ -449,7 +461,7 @@ subroutine MPI_Irecv_proc(buf, count, datatype, source, tag, comm, request, ierr

subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c
real(8), intent(in), target :: sendbuf
real(8), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -464,7 +476,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier
sendbuf_ptr = c_loc(sendbuf)
end if
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)
Expand All @@ -482,7 +494,7 @@ subroutine MPI_Allreduce_scalar(sendbuf, recvbuf, count, datatype, op, comm, ier

subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c
real(8), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -498,7 +510,7 @@ subroutine MPI_Allreduce_1D_recv_proc(sendbuf, recvbuf, count, datatype, op, com
end if

recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)
Expand All @@ -516,7 +528,7 @@ end subroutine MPI_Allreduce_1D_recv_proc

subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, c_mpi_in_place_f2c
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_in_place_f2c
real(8), dimension(:), intent(in), target :: sendbuf
real(8), dimension(:), intent(out), target :: recvbuf
integer, intent(in) :: count, datatype, op, comm
Expand All @@ -527,7 +539,7 @@ subroutine MPI_Allreduce_1D_real_proc(sendbuf, recvbuf, count, datatype, op, com

sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)
Expand All @@ -545,7 +557,7 @@ end subroutine MPI_Allreduce_1D_real_proc

subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_datatype_f2c, &
use mpi_c_bindings, only: c_mpi_allreduce, &
c_mpi_comm_f2c, c_mpi_in_place_f2c, c_mpi_comm_world
integer, dimension(:), intent(in), target :: sendbuf
integer, dimension(:), intent(out), target :: recvbuf
Expand All @@ -557,7 +569,7 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm

sendbuf_ptr = c_loc(sendbuf)
recvbuf_ptr = c_loc(recvbuf)
c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
c_op = handle_mpi_op_f2c(op)

c_comm = handle_mpi_comm_f2c(comm)
Expand Down Expand Up @@ -653,7 +665,7 @@ end subroutine MPI_Comm_split_type_proc

subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, status, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_recv, c_mpi_datatype_f2c, c_mpi_status_c2f
use mpi_c_bindings, only: c_mpi_recv, c_mpi_status_c2f
real(8), dimension(*), intent(inout), target :: buf
integer, intent(in) :: count, source, tag, datatype, comm
integer, intent(out) :: status(MPI_STATUS_SIZE)
Expand All @@ -665,7 +677,7 @@ subroutine MPI_Recv_StatusArray_proc(buf, count, datatype, source, tag, comm, st
integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status

! Convert Fortran handles to C handles.
c_dtype = c_mpi_datatype_f2c(datatype)
c_dtype = handle_mpi_datatype_f2c(datatype)

c_comm = handle_mpi_comm_f2c(comm)

Expand All @@ -690,7 +702,7 @@ end subroutine MPI_Recv_StatusArray_proc

subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, status, ierror)
use iso_c_binding, only: c_int, c_ptr, c_loc
use mpi_c_bindings, only: c_mpi_recv, c_mpi_datatype_f2c, c_mpi_status_c2f
use mpi_c_bindings, only: c_mpi_recv, c_mpi_status_c2f
real(8), dimension(*), intent(inout), target :: buf
integer, intent(in) :: count, source, tag, datatype, comm
integer, intent(out) :: status
Expand All @@ -702,7 +714,7 @@ subroutine MPI_Recv_StatusIgnore_proc(buf, count, datatype, source, tag, comm, s
integer(c_int), dimension(MPI_STATUS_SIZE), target :: tmp_status

! Convert Fortran handles to C handles.
c_dtype = c_mpi_datatype_f2c(datatype)
c_dtype = handle_mpi_datatype_f2c(datatype)

c_comm = handle_mpi_comm_f2c(comm)

Expand Down Expand Up @@ -769,7 +781,7 @@ end subroutine MPI_Waitall_proc

subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror)
use iso_c_binding, only: c_int, c_ptr
use mpi_c_bindings, only: c_mpi_ssend, c_mpi_datatype_f2c
use mpi_c_bindings, only: c_mpi_ssend
real(8), dimension(*), intent(in) :: buf
integer, intent(in) :: count, dest, tag
integer, intent(in) :: datatype
Expand All @@ -778,7 +790,7 @@ subroutine MPI_Ssend_proc(buf, count, datatype, dest, tag, comm, ierror)
integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_comm
integer :: local_ierr

c_datatype = c_mpi_datatype_f2c(datatype)
c_datatype = handle_mpi_datatype_f2c(datatype)
c_comm = handle_mpi_comm_f2c(comm)
local_ierr = c_mpi_ssend(buf, count, c_datatype, dest, tag, c_comm)
end subroutine
Expand Down Expand Up @@ -924,7 +936,7 @@ subroutine MPI_Cart_sub_proc (comm, remain_dims, newcomm, ierror)
end subroutine

subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, comm, ierror)
use mpi_c_bindings, only: c_mpi_reduce, c_mpi_datatype_f2c
use mpi_c_bindings, only: c_mpi_reduce
use iso_c_binding, only: c_int, c_ptr, c_loc
integer, target, intent(in) :: sendbuf
integer, target, intent(out) :: recvbuf
Expand All @@ -937,7 +949,7 @@ subroutine MPI_Reduce_scalar_int(sendbuf, recvbuf, count, datatype, op, root, co

c_comm = handle_mpi_comm_f2c(comm)

c_dtype = c_mpi_datatype_f2c(datatype)
c_dtype = handle_mpi_datatype_f2c(datatype)
c_op = handle_mpi_op_f2c(op)

! Pass pointer to the actual data
Expand Down
16 changes: 11 additions & 5 deletions src/mpi_c_bindings.f90
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,17 @@ function c_mpi_sum() bind(C, name="get_c_MPI_SUM")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_sum
end function c_mpi_sum

function c_mpi_datatype_f2c(datatype) bind(C, name="get_c_datatype_from_fortran")
use iso_c_binding, only: c_int, c_ptr
integer(c_int), value :: datatype
integer(kind=MPI_HANDLE_KIND) :: c_mpi_datatype_f2c
end function c_mpi_datatype_f2c
function c_mpi_float() bind(C, name="get_c_MPI_FLOAT")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_float
end function

function c_mpi_double() bind(C, name="get_c_MPI_DOUBLE")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_double
end function

function c_mpi_int() bind(C, name="get_c_MPI_INT")
integer(kind=MPI_HANDLE_KIND) :: c_mpi_int
end function

function c_mpi_op_f2c(op_f) bind(C, name="MPI_Op_f2c")
use iso_c_binding, only: c_ptr, c_int
Expand Down
43 changes: 11 additions & 32 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
#include <mpi.h>
#include <stdlib.h>
#include <stdio.h>

#define MPI_STATUS_SIZE 5

#define FORTRAN_MPI_COMM_WORLD -1000
#define FORTRAN_MPI_INFO_NULL -2000
#define FORTRAN_MPI_IN_PLACE -1002

#define FORTRAN_MPI_SUM -2300

#define FORTRAN_MPI_INTEGER -10002
#define FORTRAN_MPI_DOUBLE_PRECISION -10004
#define FORTRAN_MPI_REAL4 -10013
#define FORTRAN_MPI_REAL8 -10014


MPI_Datatype get_c_datatype_from_fortran(int datatype) {
MPI_Datatype c_datatype;
switch (datatype) {
case FORTRAN_MPI_REAL4:
c_datatype = MPI_FLOAT;
break;
case FORTRAN_MPI_REAL8:
case FORTRAN_MPI_DOUBLE_PRECISION:
c_datatype = MPI_DOUBLE;
break;
case FORTRAN_MPI_INTEGER:
c_datatype = MPI_INT;
break;
}
return c_datatype;

MPI_Datatype get_c_MPI_DOUBLE() {
return MPI_DOUBLE;
}

MPI_Datatype get_c_MPI_FLOAT() {
return MPI_FLOAT;
}

MPI_Datatype get_c_MPI_INT() {
return MPI_INT;
}

MPI_Info get_c_MPI_INFO_NULL() {
Expand Down