
Commit 8627dc5

Perf: cuSOLVER supports parallel solving of multiple k-point matrices (useful for understanding how different GPUs can diagonalize matrices from different k-points) (#6464)

* Improve the algorithm for solving eigenvalues using cusolver
* Add a non-MPI version
* Modify some comments
* Improve the efficiency for the npro == 1 case
* Fix a compilation error
1 parent 37bdf71 commit 8627dc5
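
What the change enables, as a hedged sketch: after this commit the cuSOLVER backend takes the already-gathered H and S matrices of a single k-point through its diag(h_mat, s_mat, psi, eigenvalue_in) interface, so a caller can hand different k-points to different GPUs and diagonalize them concurrently. Only DiagoCusolver, hamilt::MatrixBlock, psi::Psi, and the new diag() signature come from this diff; the round-robin assignment and the solve_all_kpoints() helper below are illustrative assumptions, not code from this commit.

// Illustrative sketch (not part of this commit): each MPI rank is assumed to
// drive one GPU and to already hold the gathered dense H and S matrices of
// the k-points assigned to it. Project headers are omitted.
template <typename T>
void solve_all_kpoints(std::vector<hamilt::MatrixBlock<T>>& h_mats,   // one H per k-point
                       std::vector<hamilt::MatrixBlock<T>>& s_mats,   // one S per k-point
                       std::vector<psi::Psi<T>>& psi_k,               // eigenvectors per k-point
                       std::vector<std::vector<double>>& eig_k,       // eigenvalues per k-point
                       const int my_rank,
                       const int n_ranks)
{
    hsolver::DiagoCusolver<T> solver; // default constructor introduced by this commit

    // Round-robin assignment: k-point ik is solved by rank ik % n_ranks, i.e. on that
    // rank's GPU, so matrices from different k-points are diagonalized in parallel.
    for (int ik = 0; ik < static_cast<int>(h_mats.size()); ++ik)
    {
        if (ik % n_ranks != my_rank)
        {
            continue;
        }
        solver.diag(h_mats[ik], s_mats[ik], psi_k[ik], eig_k[ik].data());
    }
    // Results for k-points solved on other ranks would then be exchanged via MPI,
    // which is omitted here.
}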

File tree

4 files changed: +219 −165 lines changed

source/source_hsolver/diago_cusolver.cpp

Lines changed: 8 additions & 156 deletions
@@ -15,187 +15,39 @@ using complex = std::complex<double>;
 // Namespace for the diagonalization solver
 namespace hsolver
 {
-// this struct is used for collecting matrices from all processes to root process
-template <typename T>
-struct Matrix_g
-{
-    std::shared_ptr<T> p;
-    size_t row;
-    size_t col;
-    std::shared_ptr<int> desc;
-};
-
 // Initialize the DecomposedState variable for real and complex numbers
 template <typename T>
 int DiagoCusolver<T>::DecomposedState = 0;
 
 template <typename T>
-DiagoCusolver<T>::DiagoCusolver(const Parallel_Orbitals* ParaV)
+DiagoCusolver<T>::DiagoCusolver()
 {
-    this->ParaV = ParaV;
 }
 
 template <typename T>
 DiagoCusolver<T>::~DiagoCusolver()
 {
 }
 
-// Wrapper for pdgemr2d and pzgemr2d
-// static inline void Cpxgemr2d(
-//     const int M, const int N,
-//     double *a, const int ia, const int ja, const int *desca,
-//     double *b, const int ib, const int jb, const int *descb,
-//     const int blacs_ctxt)
-//{
-//    pdgemr2d_(&M, &N,
-//              a, &ia, &ja, desca,
-//              b, &ib, &jb, descb,
-//              &blacs_ctxt);
-//}
-//
-// static inline void Cpxgemr2d(
-//     const int M, const int N,
-//     complex *a, const int ia, const int ja, const int *desca,
-//     complex *b, const int ib, const int jb, const int *descb,
-//     const int blacs_ctxt)
-//{
-//    pzgemr2d_(&M, &N,
-//              a, &ia, &ja, desca,
-//              b, &ib, &jb, descb,
-//              &blacs_ctxt);
-//}
-
-// Use Cpxgemr2d to collect matrices from all processes to root process
-template <typename mat, typename matg>
-static void gatherMatrix(const int myid, const int root_proc, const mat& mat_l, matg& mat_g)
-{
-    auto a = mat_l.p;
-    const int* desca = mat_l.desc;
-    int ctxt = desca[1];
-    int nrows = desca[2];
-    int ncols = desca[3];
-
-    if (myid == root_proc)
-    {
-        mat_g.p.reset(new typename std::remove_reference<decltype(*a)>::type[nrows * ncols]);
-    }
-    else
-    {
-        mat_g.p.reset(new typename std::remove_reference<decltype(*a)>::type[1]);
-    }
-
-    // Set descb, which has all elements in the only block in the root process
-    mat_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
-
-    mat_g.row = nrows;
-    mat_g.col = ncols;
-
-    Cpxgemr2d(nrows, ncols, a, 1, 1, const_cast<int*>(desca), mat_g.p.get(), 1, 1, mat_g.desc.get(), ctxt);
-}
-
-// Convert the Psi to a 2D block storage format
-template <typename T>
-static void distributePsi(const int* desc_psi, T* psi, T* psi_g)
-{
-    int ctxt = desc_psi[1];
-    int nrows = desc_psi[2];
-    int ncols = desc_psi[3];
-    int rsrc = desc_psi[6];
-    int csrc = desc_psi[7];
-
-    int descg[9] = {1, ctxt, nrows, ncols, nrows, ncols, rsrc, csrc, nrows};
-    int descl[9];
-
-    std::copy(desc_psi, desc_psi + 9, descl);
-
-    Cpxgemr2d(nrows, ncols, psi_g, 1, 1, descg, psi, 1, 1, descl, ctxt);
-}
-
 // Diagonalization function
 template <typename T>
-void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in)
+void DiagoCusolver<T>::diag(
+    hamilt::MatrixBlock<T>& h_mat,
+    hamilt::MatrixBlock<T>& s_mat,
+    psi::Psi<T>& psi,
+    Real* eigenvalue_in)
 {
-    // Output the title for the current operation
     ModuleBase::TITLE("DiagoCusolver", "diag");
-
-    // Create matrices for the Hamiltonian and overlap
-    hamilt::MatrixBlock<T> h_mat;
-    hamilt::MatrixBlock<T> s_mat;
-    phm_in->matrix(h_mat, s_mat);
-
-#ifdef __MPI
-    // global matrix
-    Matrix_g<T> h_mat_g;
-    Matrix_g<T> s_mat_g;
-
-    // get the context and process information
-    int ctxt = ParaV->blacs_ctxt;
-    int nprows = 0;
-    int npcols = 0;
-    int myprow = 0;
-    int mypcol = 0;
-    Cblacs_gridinfo(ctxt, &nprows, &npcols, &myprow, &mypcol);
-    const int num_procs = nprows * npcols;
-    const int myid = Cblacs_pnum(ctxt, myprow, mypcol);
-    const int root_proc = Cblacs_pnum(ctxt, ParaV->desc[6], ParaV->desc[7]);
-#endif
-
+    ModuleBase::timer::tick("DiagoCusolver", "cusolver");
     // Allocate memory for eigenvalues
     std::vector<double> eigen(PARAM.globalv.nlocal, 0.0);
-
-    // Start the timer for the cusolver operation
-    ModuleBase::timer::tick("DiagoCusolver", "cusolver");
-
-#ifdef __MPI
-    if (num_procs > 1)
-    {
-        // gather matrices from processes to root process
-        gatherMatrix(myid, root_proc, h_mat, h_mat_g);
-        gatherMatrix(myid, root_proc, s_mat, s_mat_g);
-    }
-#endif
-
-    // Call the dense diagonalization routine
-#ifdef __MPI
-    if (num_procs > 1)
-    {
-        MPI_Barrier(MPI_COMM_WORLD);
-        // global psi for distribute
-        int psi_len = myid == root_proc ? h_mat_g.row * h_mat_g.col : 1;
-        std::vector<T> psi_g(psi_len);
-        if (myid == root_proc)
-        {
-            this->dc.Dngvd(h_mat_g.col, h_mat_g.row, h_mat_g.p.get(), s_mat_g.p.get(), eigen.data(), psi_g.data());
-        }
-
-        MPI_Barrier(MPI_COMM_WORLD);
-
-        // broadcast eigenvalues to all processes
-        MPI_Bcast(eigen.data(), PARAM.inp.nbands, MPI_DOUBLE, root_proc, MPI_COMM_WORLD);
-
-        // distribute psi to all processes
-        distributePsi(this->ParaV->desc_wfc, psi.get_pointer(), psi_g.data());
-    }
-    else
-    {
-        // Be careful that h_mat.row * h_mat.col != psi.get_nbands() * psi.get_nbasis() under multi-k situation
-        std::vector<T> eigenvectors(h_mat.row * h_mat.col);
-        this->dc.Dngvd(h_mat.row, h_mat.col, h_mat.p, s_mat.p, eigen.data(), eigenvectors.data());
-        const int size = psi.get_nbands() * psi.get_nbasis();
-        BlasConnector::copy(size, eigenvectors.data(), 1, psi.get_pointer(), 1);
-    }
-#else
     std::vector<T> eigenvectors(h_mat.row * h_mat.col);
     this->dc.Dngvd(h_mat.row, h_mat.col, h_mat.p, s_mat.p, eigen.data(), eigenvectors.data());
     const int size = psi.get_nbands() * psi.get_nbasis();
     BlasConnector::copy(size, eigenvectors.data(), 1, psi.get_pointer(), 1);
-#endif
-    // Stop the timer for the cusolver operation
-    ModuleBase::timer::tick("DiagoCusolver", "cusolver");
-
-    // Copy the eigenvalues to the output arrays
     const int inc = 1;
     BlasConnector::copy(PARAM.inp.nbands, eigen.data(), inc, eigenvalue_in, inc);
+    ModuleBase::timer::tick("DiagoCusolver", "cusolver");
 }
 
 // Explicit instantiation of the DiagoCusolver class for real and complex numbers
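
Because the Matrix_g struct, gatherMatrix(), and distributePsi() helpers were removed, collecting the distributed H and S blocks (and redistributing psi) is now the caller's responsibility; DiagoCusolver only runs the dense generalized eigensolver on whatever MatrixBlock it is handed. A minimal usage sketch of the new calling convention follows; phm_in, psi_k, and eig_k are illustrative names, while MatrixBlock, matrix(), and diag() appear in this diff.

// Hedged sketch, assuming a Hamiltonian pointer phm_in and per-k-point
// psi_k / eig_k buffers similar to those used by the removed code.
hamilt::MatrixBlock<std::complex<double>> h_mat;
hamilt::MatrixBlock<std::complex<double>> s_mat;
phm_in->matrix(h_mat, s_mat);               // obtain H and S outside of DiagoCusolver

hsolver::DiagoCusolver<std::complex<double>> cusolver;
cusolver.diag(h_mat, s_mat, psi_k, eig_k);  // dense cuSOLVER path shown in this diff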

source/source_hsolver/diago_cusolver.h

Lines changed: 6 additions & 3 deletions
@@ -17,15 +17,18 @@ class DiagoCusolver
   private:
     // Real is the real part of the complex type T
    using Real = typename GetTypeReal<T>::type;
-    Parallel_Orbitals const * ParaV;
 
   public:
 
-    DiagoCusolver(const Parallel_Orbitals* ParaV = nullptr);
+    DiagoCusolver();
     ~DiagoCusolver();
 
     // Override the diag function for CUSOLVER diagonalization
-    void diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in);
+    void diag(
+        hamilt::MatrixBlock<T>& h_mat,
+        hamilt::MatrixBlock<T>& s_mat,
+        psi::Psi<T>& psi,
+        Real* eigenvalue_in);
 
     // Static variable to keep track of the decomposition state
     static int DecomposedState;
