Skip to content

simplify MPI datatype evaluation #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/mpi.f90
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module mpi
implicit none
integer, parameter :: MPI_THREAD_FUNNELED = 1
! not sure if this is correct really
integer, parameter :: MPI_INTEGER = 2
integer, parameter :: MPI_REAL4 = 0
integer, parameter :: MPI_REAL8 = 1

integer, parameter :: MPI_INTEGER = -10002
integer, parameter :: MPI_REAL4 = -10013
integer, parameter :: MPI_REAL8 = -10014

integer, parameter :: MPI_COMM_TYPE_SHARED = 1
integer, parameter :: MPI_PROC_NULL = -1
integer, parameter :: MPI_SUCCESS = 0
Expand Down
180 changes: 36 additions & 144 deletions src/mpi_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,34 @@
#include <stdio.h>

#define MPI_STATUS_SIZE 5

#define FORTRAN_MPI_COMM_WORLD -1000
#define FORTRAN_MPI_INFO_NULL -2000
#define FORTRAN_MPI_IN_PLACE -1002

#define FORTRAN_MPI_SUM -2300

#define FORTRAN_MPI_INTEGER -10002
#define FORTRAN_MPI_REAL4 -10013
#define FORTRAN_MPI_REAL8 -10014


MPI_Datatype get_c_datatype_from_fortran(int datatype) {
MPI_Datatype c_datatype;
switch (datatype) {
case FORTRAN_MPI_REAL4:
c_datatype = MPI_FLOAT;
break;
case FORTRAN_MPI_REAL8:
c_datatype = MPI_DOUBLE;
break;
case FORTRAN_MPI_INTEGER:
c_datatype = MPI_INT;
break;
}
return c_datatype;
}

MPI_Info get_c_info_from_fortran(int info) {
if (info == FORTRAN_MPI_INFO_NULL) {
return MPI_INFO_NULL;
Expand All @@ -34,38 +57,13 @@ MPI_Comm get_c_comm_from_fortran(int comm_f) {

void mpi_bcast_int_wrapper(int *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype;
switch (*datatype_f) {
case 2:
datatype = MPI_INT;
break;
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);
*ierror = MPI_Bcast(buffer, *count, datatype, *root, comm);
}

void mpi_bcast_real_wrapper(double *buffer, int *count, int *datatype_f, int *root, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype;
switch (*datatype_f) {
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);
*ierror = MPI_Bcast(buffer, *count, datatype, *root, comm);
}

Expand All @@ -74,24 +72,8 @@ void mpi_allgather_int_wrapper(const int *sendbuf, int *sendcount, int *sendtype
int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);

MPI_Datatype sendtype, recvtype;
switch (*sendtype_f) {
case 2:
sendtype = MPI_INT;
break;
default:
*ierror = -1;
return;
}

switch (*recvtype_f) {
case 2:
recvtype = MPI_INT;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype sendtype = get_c_datatype_from_fortran(*sendtype_f);
MPI_Datatype recvtype = get_c_datatype_from_fortran(*recvtype_f);

*ierror = MPI_Allgather(sendbuf, *sendcount, sendtype,
recvbuf, *recvcount, recvtype, comm);
Expand All @@ -102,30 +84,8 @@ void mpi_allgather_real_wrapper(const double *sendbuf, int *sendcount, int *send
int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);

MPI_Datatype sendtype, recvtype;
switch (*sendtype_f) {
case 0:
sendtype = MPI_FLOAT;
break;
case 1:
sendtype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}

switch (*recvtype_f) {
case 0:
recvtype = MPI_FLOAT;
break;
case 1:
recvtype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype sendtype = get_c_datatype_from_fortran(*sendtype_f);
MPI_Datatype recvtype = get_c_datatype_from_fortran(*recvtype_f);

*ierror = MPI_Allgather(sendbuf, *sendcount, sendtype,
recvbuf, *recvcount, recvtype, comm);
Expand All @@ -135,18 +95,7 @@ void mpi_isend_wrapper(const double *buf, int *count, int *datatype_f,
int *dest, int *tag, int *comm_f, int *request_f,
int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype;
switch (*datatype_f) {
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*request_f = -1;
return;
}
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

MPI_Request request;
*ierror = MPI_Isend(buf, *count, datatype, *dest, *tag, comm, &request);
Expand All @@ -157,18 +106,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f,
int *source, int *tag, int *comm_f, int *request_f,
int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype;
switch (*datatype_f) {
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*request_f = -1;
return;
}
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

MPI_Request request;
*ierror = MPI_Irecv(buf, *count, datatype, *source, *tag, comm, &request);
Expand All @@ -178,18 +116,7 @@ void mpi_irecv_wrapper(double *buf, int *count, int *datatype_f,
void mpi_allreduce_wrapper_real(const double *sendbuf, double *recvbuf, int *count,
int *datatype_f, int *op_f, int *comm_f, int *ierror) {
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Datatype datatype;
switch (*datatype_f) {
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

// I'm a little doubtful: as how would it identify that this part
// is supposed to be for MPI_SUM?
Expand Down Expand Up @@ -244,18 +171,7 @@ void mpi_comm_split_type_wrapper(int *comm_f, int *split_type, int *key,

void mpi_recv_wrapper(double *buf, int *count, int *datatype_f, int *source,
int *tag, int *comm_f, int *status_f, int *ierror) {
MPI_Datatype datatype;
switch (*datatype_f) {
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
MPI_Status status;
Expand Down Expand Up @@ -294,18 +210,7 @@ void mpi_waitall_wrapper(int *count, int *array_of_requests_f,

void mpi_ssend_wrapper(double *buf, int *count, int *datatype_f, int *dest,
int *tag, int *comm_f, int *ierror) {
MPI_Datatype datatype;
switch (*datatype_f) {
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

- MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f);
+ MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im getting error

 aditya-trivedi   pot3d    main ≡    FC='gfortran -O3' ./build_and_run_gfortran.sh 
+++ uname
++ [[ Linux == \L\i\n\u\x ]]
++ CC=gcc
++ cd src
++ gcc -I/home/aditya-trivedi/conda_root/envs/mpi/include -c ../../../src/mpi_wrapper.c
../../../src/mpi_wrapper.c: In function 'mpi_ssend_wrapper':
../../../src/mpi_wrapper.c:203:29: error: initialization of 'MPI_Datatype' {aka 'struct ompi_datatype_t *'} from incompatible pointer type 'MPI_Comm' {aka 'struct ompi_communicator_t *'} [-Wincompatible-pointer-types]
  203 |     MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f);
      |                             ^~~~~~~~~~~~~~~~~~~~~~~
 aditya-trivedi   pot3d    main ≡   

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious how CI passed here with this small glitch

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you getting this error with OpenMPI or MPICH or both?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. That's a warning not an error:

../../../src/mpi_wrapper.c:203:18: warning: incompatible pointer types initializing 'MPI_Datatype' (aka 'struct ompi_datatype_t *') with an expression of type 'MPI_Comm' (aka 'struct ompi_communicator_t *') [-Wincompatible-pointer-types]
  203 |     MPI_Datatype datatype = get_c_comm_from_fortran(*datatype_f);
      |         

I got it with OpenMPI, I checked it now.

We should've more tests then definitely.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed in #60

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch @adit4443ya , thanks, I'll fix it in the next PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you get error on your local machine?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you get error on your local machine?

No, only a warning.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's surprising that we only get a warning and no error at all.


MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
*ierror = MPI_Ssend(buf, *count, datatype, *dest, *tag, comm);
Expand Down Expand Up @@ -345,21 +250,8 @@ void mpi_cart_sub_wrapper(int * comm_f, int * rmains_dims, int * newcomm_f, int
void mpi_reduce_wrapper(const int* sendbuf, int* recvbuf, int* count, int* datatype_f,
int* op_f, int* root, int* comm_f, int* ierror)
{
MPI_Datatype datatype;
switch (*datatype_f) {
case 2:
datatype = MPI_INT;
break;
case 0:
datatype = MPI_FLOAT;
break;
case 1:
datatype = MPI_DOUBLE;
break;
default:
*ierror = -1;
return;
}
MPI_Datatype datatype = get_c_datatype_from_fortran(*datatype_f);

MPI_Op op = get_c_op_from_fortran(*op_f);
MPI_Comm comm = get_c_comm_from_fortran(*comm_f);
*ierror = MPI_Reduce(sendbuf, recvbuf, *count, datatype, op, *root, comm);
Expand Down