Skip to content

Commit f85bb8e

Browse files
committed
Make MPI wrapper usable in serial
1 parent 54909e7 commit f85bb8e

File tree

4 files changed

+64
-53
lines changed

4 files changed

+64
-53
lines changed

examples/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99

1010
## Example 2
1111

12-
`ex2/`: like ex2, but using the MPI wrapper for parallel integration.
12+
`ex2/`: like ex2, but using the MPI wrapper for parallel integration. Pass the wanted number of threads to run.sh.

examples/ex2/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <fstream>
66

77
#include "MCIntegrator.hpp"
8-
#include "MCIWrapperMPI.hpp"
8+
#include "MPIMCI.hpp"
99

1010

1111
// Observable functions

src/MCIWrapperMPI.hpp

Lines changed: 0 additions & 51 deletions
This file was deleted.

src/MPIMCI.hpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#ifndef MPIMCI
2+
#define MPIMCI
3+
4+
5+
#include <mpi.h>
6+
#include "MCIntegrator.hpp"
7+
8+
namespace MPIMCI
9+
{
10+
int init() // return mpi rank of process
11+
{
12+
if (MPI::Is_initialized()) throw std::runtime_error("MPI already initialized!");
13+
MPI::Init();
14+
MPI::COMM_WORLD.Set_errhandler(MPI::ERRORS_THROW_EXCEPTIONS);
15+
return MPI::COMM_WORLD.Get_rank();
16+
}
17+
18+
void integrate(MCI * const mci, const long &Nmc, double * average, double * error, bool findMRT2step=true, bool initialdecorrelation=true, bool use_mpi=true) // by setting use_mpi to false you can use this without requiring MPI
19+
{
20+
if (use_mpi) {
21+
// make sure the user has MPI in the correct state
22+
if (!MPI::Is_initialized()) throw std::runtime_error("MPI not initialized!");
23+
if (MPI::Is_finalized()) throw std::runtime_error("MPI already finalized!");
24+
25+
const int myrank = MPI::COMM_WORLD.Get_rank();
26+
const int nranks = MPI::COMM_WORLD.Get_size();
27+
28+
// the results are stored in myAverage/Error and then reduced into average/error for root process
29+
double myAverage[mci->getNObsDim()];
30+
double myError[mci->getNObsDim()];
31+
32+
mci->integrate(Nmc, myAverage, myError, findMRT2step, initialdecorrelation);
33+
34+
for (int i=0; i<mci->getNObsDim(); ++i) {
35+
MPI::COMM_WORLD.Reduce(&myAverage[i], &average[i], 1, MPI::DOUBLE, MPI::SUM, 0);
36+
37+
myError[i] *= myError[i];
38+
MPI::COMM_WORLD.Reduce(&myError[i], &error[i], 1, MPI::DOUBLE, MPI::SUM, 0);
39+
40+
if (myrank == 0) {
41+
average[i] /= nranks;
42+
error[i] = sqrt(error[i]) / nranks;
43+
}
44+
}
45+
}
46+
else {
47+
mci->integrate(Nmc, average, error, findMRT2step, initialdecorrelation); // regular single thread call
48+
}
49+
}
50+
51+
void finalize(MCI * mci)
52+
{
53+
// make sure the user has MPI in the correct state
54+
if (!MPI::Is_initialized()) throw std::runtime_error("MPI not initialized!");
55+
if (MPI::Is_finalized()) throw std::runtime_error("MPI already finalized!");
56+
57+
delete mci;
58+
MPI::Finalize();
59+
}
60+
};
61+
62+
#endif

0 commit comments

Comments
 (0)