From e34996e2aa65a4039c89d5621bc4704b379febfa Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Mon, 31 Mar 2025 11:06:00 +0530 Subject: [PATCH 1/3] simplify MPI datatype evaluation --- src/mpi_wrapper.c | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 15084eb..3802624 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -3,11 +3,34 @@ #include #define MPI_STATUS_SIZE 5 + #define FORTRAN_MPI_COMM_WORLD -1000 #define FORTRAN_MPI_INFO_NULL -2000 #define FORTRAN_MPI_IN_PLACE -1002 + #define FORTRAN_MPI_SUM -2300 +#define FORTRAN_MPI_INTEGER 2 +#define FORTRAN_MPI_REAL4 0 +#define FORTRAN_MPI_REAL8 1 + + +MPI_Datatype get_c_datatype_from_fortran(int datatype) { + MPI_Datatype c_datatype; + switch (datatype) { + case FORTRAN_MPI_REAL4: + c_datatype = MPI_FLOAT; + break; + case FORTRAN_MPI_REAL8: + c_datatype = MPI_DOUBLE; + break; + case FORTRAN_MPI_INTEGER: + c_datatype = MPI_INT; + break; + } + return c_datatype; +} + MPI_Info get_c_info_from_fortran(int info) { if (info == FORTRAN_MPI_INFO_NULL) { return MPI_INFO_NULL; From 3db4c35a0a9e5f61b53acd3253254bf81117b23f Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Mon, 31 Mar 2025 11:14:09 +0530 Subject: [PATCH 2/3] use custom function to get MPI_Datatype in C wrapper --- src/mpi_wrapper.c | 157 ++++------------------------------------------ 1 file changed, 13 insertions(+), 144 deletions(-) diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 3802624..16c0344 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -57,38 +57,13 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 2: - datatype = MPI_INT; - break; - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); } void mpi_bcast_real_wrapper(double *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); } @@ -97,24 +72,8 @@ void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype sendtype, recvtype; - switch (*sendtype_f) { - case 2: - sendtype = MPI_INT; - break; - default: - *ierror = -1; - return; - } - - switch (*recvtype_f) { - case 2: - recvtype = MPI_INT; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype sendtype = get_c_datatype_from_fortran(*sendtype_f); + MPI_Datatype recvtype = get_c_datatype_from_fortran(*recvtype_f); *ierror = MPI_Allgather(sendbuf, *sendcount, sendtype, recvbuf, *recvcount, recvtype, comm); @@ -125,30 +84,8 @@ void mpi_allgather_real_wrapper(const double *sendbuf, int *sendcount, int *send int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype sendtype, recvtype; - switch (*sendtype_f) { - case 0: - sendtype = MPI_FLOAT; - break; - case 1: - sendtype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } - - switch (*recvtype_f) { - case 0: - recvtype = MPI_FLOAT; - break; - case 1: - recvtype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype sendtype = get_c_datatype_from_fortran(*sendtype_f); + MPI_Datatype recvtype = get_c_datatype_from_fortran(*recvtype_f); *ierror = MPI_Allgather(sendbuf, *sendcount, sendtype, recvbuf, *recvcount, recvtype, comm); @@ -158,18 +95,7 @@ void mpi_isend_wrapper(const double *buf, int *count, int *datatype_f, int *dest, int *tag, int *comm_f, int *request_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *request_f = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Request request; *ierror = MPI_Isend(buf, *count, datatype, *dest, *tag, comm, &request); @@ -180,18 +106,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f, int *source, int *tag, int *comm_f, int *request_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *request_f = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Request request; *ierror = MPI_Irecv(buf, *count, datatype, *source, *tag, comm, &request); @@ -201,18 +116,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f, void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); // I'm a little doubtful: as how would it identify that this part // is supposed to be for MPI_SUM? @@ -267,18 +171,7 @@ void mpi_comm_split_type_wrapper(int *comm_f, int *split_type, int *key, void mpi_recv_wrapper(double *buf, int *count, int *datatype_f, int *source, int *tag, int *comm_f, int *status_f, int *ierror) { - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); MPI_Status status; @@ -317,18 +210,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f, void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest, int *tag, int *comm_f, int *ierror) { - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm); @@ -368,21 +250,8 @@ void mpi_cart_sub_wrapper(int * comm_f, int * rmains_dims, int * newcomm_f, int void mpi_reduce_wrapper(const int* sendbuf, int* recvbuf, int* count, int* datatype_f, int* op_f, int* root, int* comm_f, int* ierror) { - MPI_Datatype datatype; - switch (*datatype_f) { - case 2: - datatype = MPI_INT; - break; - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); + MPI_Op op = get_c_op_from_fortran(*op_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Reduce(sendbuf, recvbuf, *count, datatype, op, *root, comm); From 5ca184f32cc3c6eff322e937159475b111b39717 Mon Sep 17 00:00:00 2001 From: Gaurav Dhingra Date: Mon, 31 Mar 2025 11:17:44 +0530 Subject: [PATCH 3/3] use non-trivial values for MPI_REAL4, MPI_INTEGER, MPI_REAL8 etc. --- src/mpi.f90 | 9 +++++---- src/mpi_wrapper.c | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/mpi.f90 b/src/mpi.f90 index 7df44db..8d15df9 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -1,10 +1,11 @@ module mpi implicit none integer, parameter :: MPI_THREAD_FUNNELED = 1 - ! not sure if this is correct really - integer, parameter :: MPI_INTEGER = 2 - integer, parameter :: MPI_REAL4 = 0 - integer, parameter :: MPI_REAL8 = 1 + + integer, parameter :: MPI_INTEGER = -10002 + integer, parameter :: MPI_REAL4 = -10013 + integer, parameter :: MPI_REAL8 = -10014 + integer, parameter :: MPI_COMM_TYPE_SHARED = 1 integer, parameter :: MPI_PROC_NULL = -1 integer, parameter :: MPI_SUCCESS = 0 diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 16c0344..98a9d2a 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -10,9 +10,9 @@ #define FORTRAN_MPI_SUM -2300 -#define FORTRAN_MPI_INTEGER 2 -#define FORTRAN_MPI_REAL4 0 -#define FORTRAN_MPI_REAL8 1 +#define FORTRAN_MPI_INTEGER -10002 +#define FORTRAN_MPI_REAL4 -10013 +#define FORTRAN_MPI_REAL8 -10014 MPI_Datatype get_c_datatype_from_fortran(int datatype) {