|
| 1 | +#ifndef MPIMCI |
| 2 | +#define MPIMCI |
| 3 | + |
| 4 | + |
| 5 | +#include <mpi.h> |
| 6 | +#include "MCIntegrator.hpp" |
| 7 | + |
| 8 | +namespace MPIMCI |
| 9 | +{ |
| 10 | + int init() // return mpi rank of process |
| 11 | + { |
| 12 | + if (MPI::Is_initialized()) throw std::runtime_error("MPI already initialized!"); |
| 13 | + MPI::Init(); |
| 14 | + MPI::COMM_WORLD.Set_errhandler(MPI::ERRORS_THROW_EXCEPTIONS); |
| 15 | + return MPI::COMM_WORLD.Get_rank(); |
| 16 | + } |
| 17 | + |
| 18 | + void integrate(MCI * const mci, const long &Nmc, double * average, double * error, bool findMRT2step=true, bool initialdecorrelation=true, bool use_mpi=true) // by setting use_mpi to false you can use this without requiring MPI |
| 19 | + { |
| 20 | + if (use_mpi) { |
| 21 | + // make sure the user has MPI in the correct state |
| 22 | + if (!MPI::Is_initialized()) throw std::runtime_error("MPI not initialized!"); |
| 23 | + if (MPI::Is_finalized()) throw std::runtime_error("MPI already finalized!"); |
| 24 | + |
| 25 | + const int myrank = MPI::COMM_WORLD.Get_rank(); |
| 26 | + const int nranks = MPI::COMM_WORLD.Get_size(); |
| 27 | + |
| 28 | + // the results are stored in myAverage/Error and then reduced into average/error for root process |
| 29 | + double myAverage[mci->getNObsDim()]; |
| 30 | + double myError[mci->getNObsDim()]; |
| 31 | + |
| 32 | + mci->integrate(Nmc, myAverage, myError, findMRT2step, initialdecorrelation); |
| 33 | + |
| 34 | + for (int i=0; i<mci->getNObsDim(); ++i) { |
| 35 | + MPI::COMM_WORLD.Reduce(&myAverage[i], &average[i], 1, MPI::DOUBLE, MPI::SUM, 0); |
| 36 | + |
| 37 | + myError[i] *= myError[i]; |
| 38 | + MPI::COMM_WORLD.Reduce(&myError[i], &error[i], 1, MPI::DOUBLE, MPI::SUM, 0); |
| 39 | + |
| 40 | + if (myrank == 0) { |
| 41 | + average[i] /= nranks; |
| 42 | + error[i] = sqrt(error[i]) / nranks; |
| 43 | + } |
| 44 | + } |
| 45 | + } |
| 46 | + else { |
| 47 | + mci->integrate(Nmc, average, error, findMRT2step, initialdecorrelation); // regular single thread call |
| 48 | + } |
| 49 | + } |
| 50 | + |
| 51 | + void finalize(MCI * mci) |
| 52 | + { |
| 53 | + // make sure the user has MPI in the correct state |
| 54 | + if (!MPI::Is_initialized()) throw std::runtime_error("MPI not initialized!"); |
| 55 | + if (MPI::Is_finalized()) throw std::runtime_error("MPI already finalized!"); |
| 56 | + |
| 57 | + delete mci; |
| 58 | + MPI::Finalize(); |
| 59 | + } |
| 60 | +}; |
| 61 | + |
| 62 | +#endif |
0 commit comments