diff --git a/src/mpi.f90 b/src/mpi.f90 index 7df44db..8d15df9 100644 --- a/src/mpi.f90 +++ b/src/mpi.f90 @@ -1,10 +1,11 @@ module mpi implicit none integer, parameter :: MPI_THREAD_FUNNELED = 1 - ! not sure if this is correct really - integer, parameter :: MPI_INTEGER = 2 - integer, parameter :: MPI_REAL4 = 0 - integer, parameter :: MPI_REAL8 = 1 + + integer, parameter :: MPI_INTEGER = -10002 + integer, parameter :: MPI_REAL4 = -10013 + integer, parameter :: MPI_REAL8 = -10014 + integer, parameter :: MPI_COMM_TYPE_SHARED = 1 integer, parameter :: MPI_PROC_NULL = -1 integer, parameter :: MPI_SUCCESS = 0 diff --git a/src/mpi_wrapper.c b/src/mpi_wrapper.c index 15084eb..98a9d2a 100644 --- a/src/mpi_wrapper.c +++ b/src/mpi_wrapper.c @@ -3,11 +3,34 @@ #include #define MPI_STATUS_SIZE 5 + #define FORTRAN_MPI_COMM_WORLD -1000 #define FORTRAN_MPI_INFO_NULL -2000 #define FORTRAN_MPI_IN_PLACE -1002 + #define FORTRAN_MPI_SUM -2300 +#define FORTRAN_MPI_INTEGER -10002 +#define FORTRAN_MPI_REAL4 -10013 +#define FORTRAN_MPI_REAL8 -10014 + + +MPI_Datatype get_c_datatype_from_fortran(int datatype) { + MPI_Datatype c_datatype; + switch (datatype) { + case FORTRAN_MPI_REAL4: + c_datatype = MPI_FLOAT; + break; + case FORTRAN_MPI_REAL8: + c_datatype = MPI_DOUBLE; + break; + case FORTRAN_MPI_INTEGER: + c_datatype = MPI_INT; + break; + } + return c_datatype; +} + MPI_Info get_c_info_from_fortran(int info) { if (info == FORTRAN_MPI_INFO_NULL) { return MPI_INFO_NULL; @@ -34,38 +57,13 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) { void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 2: - datatype = MPI_INT; - break; - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); } void mpi_bcast_real_wrapper(double *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); *ierror = MPI_Bcast(buffer, *count, datatype, *root, comm); } @@ -74,24 +72,8 @@ void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype sendtype, recvtype; - switch (*sendtype_f) { - case 2: - sendtype = MPI_INT; - break; - default: - *ierror = -1; - return; - } - - switch (*recvtype_f) { - case 2: - recvtype = MPI_INT; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype sendtype = get_c_datatype_from_fortran(*sendtype_f); + MPI_Datatype recvtype = get_c_datatype_from_fortran(*recvtype_f); *ierror = MPI_Allgather(sendbuf, *sendcount, sendtype, recvbuf, *recvcount, recvtype, comm); @@ -102,30 +84,8 @@ void mpi_allgather_real_wrapper(const double *sendbuf, int *sendcount, int *send int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype sendtype, recvtype; - switch (*sendtype_f) { - case 0: - sendtype = MPI_FLOAT; - break; - case 1: - sendtype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } - - switch (*recvtype_f) { - case 0: - recvtype = MPI_FLOAT; - break; - case 1: - recvtype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype sendtype = get_c_datatype_from_fortran(*sendtype_f); + MPI_Datatype recvtype = get_c_datatype_from_fortran(*recvtype_f); *ierror = MPI_Allgather(sendbuf, *sendcount, sendtype, recvbuf, *recvcount, recvtype, comm); @@ -135,18 +95,7 @@ void mpi_isend_wrapper(const double *buf, int *count, int *datatype_f, int *dest, int *tag, int *comm_f, int *request_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *request_f = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Request request; *ierror = MPI_Isend(buf, *count, datatype, *dest, *tag, comm, &request); @@ -157,18 +106,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f, int *source, int *tag, int *comm_f, int *request_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *request_f = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Request request; *ierror = MPI_Irecv(buf, *count, datatype, *source, *tag, comm, &request); @@ -178,18 +116,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f, void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count, int *datatype_f, int *op_f, int *comm_f, int *ierror) { MPI_Comm comm = get_c_comm_from_fortran(*comm_f); - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); // I'm a little doubtful: as how would it identify that this part // is supposed to be for MPI_SUM? @@ -244,18 +171,7 @@ void mpi_comm_split_type_wrapper(int *comm_f, int *split_type, int *key, void mpi_recv_wrapper(double *buf, int *count, int *datatype_f, int *source, int *tag, int *comm_f, int *status_f, int *ierror) { - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); MPI_Status status; @@ -294,18 +210,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f, void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest, int *tag, int *comm_f, int *ierror) { - MPI_Datatype datatype; - switch (*datatype_f) { - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm); @@ -345,21 +250,8 @@ void mpi_cart_sub_wrapper(int * comm_f, int * rmains_dims, int * newcomm_f, int void mpi_reduce_wrapper(const int* sendbuf, int* recvbuf, int* count, int* datatype_f, int* op_f, int* root, int* comm_f, int* ierror) { - MPI_Datatype datatype; - switch (*datatype_f) { - case 2: - datatype = MPI_INT; - break; - case 0: - datatype = MPI_FLOAT; - break; - case 1: - datatype = MPI_DOUBLE; - break; - default: - *ierror = -1; - return; - } + MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f); + MPI_Op op = get_c_op_from_fortran(*op_f); MPI_Comm comm = get_c_comm_from_fortran(*comm_f); *ierror = MPI_Reduce(sendbuf, recvbuf, *count, datatype, op, *root, comm);