@@ -15,187 +15,39 @@ using complex = std::complex<double>;
1515// Namespace for the diagonalization solver
1616namespace hsolver
1717{
18- // this struct is used for collecting matrices from all processes to root process
19- template <typename T>
20- struct Matrix_g
21- {
22- std::shared_ptr<T> p;
23- size_t row;
24- size_t col;
25- std::shared_ptr<int > desc;
26- };
27-
2818// Initialize the DecomposedState variable for real and complex numbers
2919template <typename T>
3020int DiagoCusolver<T>::DecomposedState = 0 ;
3121
3222template <typename T>
33- DiagoCusolver<T>::DiagoCusolver(const Parallel_Orbitals* ParaV )
23+ DiagoCusolver<T>::DiagoCusolver()
3424{
35- this ->ParaV = ParaV;
3625}
3726
3827template <typename T>
3928DiagoCusolver<T>::~DiagoCusolver ()
4029{
4130}
4231
43- // Wrapper for pdgemr2d and pzgemr2d
44- // static inline void Cpxgemr2d(
45- // const int M, const int N,
46- // double *a, const int ia, const int ja, const int *desca,
47- // double *b, const int ib, const int jb, const int *descb,
48- // const int blacs_ctxt)
49- // {
50- // pdgemr2d_(&M, &N,
51- // a, &ia, &ja, desca,
52- // b, &ib, &jb, descb,
53- // &blacs_ctxt);
54- // }
55- //
56- // static inline void Cpxgemr2d(
57- // const int M, const int N,
58- // complex *a, const int ia, const int ja, const int *desca,
59- // complex *b, const int ib, const int jb, const int *descb,
60- // const int blacs_ctxt)
61- // {
62- // pzgemr2d_(&M, &N,
63- // a, &ia, &ja, desca,
64- // b, &ib, &jb, descb,
65- // &blacs_ctxt);
66- // }
67-
68- // Use Cpxgemr2d to collect matrices from all processes to root process
69- template <typename mat, typename matg>
70- static void gatherMatrix (const int myid, const int root_proc, const mat& mat_l, matg& mat_g)
71- {
72- auto a = mat_l.p ;
73- const int * desca = mat_l.desc ;
74- int ctxt = desca[1 ];
75- int nrows = desca[2 ];
76- int ncols = desca[3 ];
77-
78- if (myid == root_proc)
79- {
80- mat_g.p .reset (new typename std::remove_reference<decltype (*a)>::type[nrows * ncols]);
81- }
82- else
83- {
84- mat_g.p .reset (new typename std::remove_reference<decltype (*a)>::type[1 ]);
85- }
86-
87- // Set descb, which has all elements in the only block in the root process
88- mat_g.desc .reset (new int [9 ]{1 , ctxt, nrows, ncols, nrows, ncols, 0 , 0 , nrows});
89-
90- mat_g.row = nrows;
91- mat_g.col = ncols;
92-
93- Cpxgemr2d (nrows, ncols, a, 1 , 1 , const_cast <int *>(desca), mat_g.p .get (), 1 , 1 , mat_g.desc .get (), ctxt);
94- }
95-
96- // Convert the Psi to a 2D block storage format
97- template <typename T>
98- static void distributePsi (const int * desc_psi, T* psi, T* psi_g)
99- {
100- int ctxt = desc_psi[1 ];
101- int nrows = desc_psi[2 ];
102- int ncols = desc_psi[3 ];
103- int rsrc = desc_psi[6 ];
104- int csrc = desc_psi[7 ];
105-
106- int descg[9 ] = {1 , ctxt, nrows, ncols, nrows, ncols, rsrc, csrc, nrows};
107- int descl[9 ];
108-
109- std::copy (desc_psi, desc_psi + 9 , descl);
110-
111- Cpxgemr2d (nrows, ncols, psi_g, 1 , 1 , descg, psi, 1 , 1 , descl, ctxt);
112- }
113-
11432// Diagonalization function
11533template <typename T>
116- void DiagoCusolver<T>::diag(hamilt::Hamilt<T>* phm_in, psi::Psi<T>& psi, Real* eigenvalue_in)
34+ void DiagoCusolver<T>::diag(
35+ hamilt::MatrixBlock<T>& h_mat,
36+ hamilt::MatrixBlock<T>& s_mat,
37+ psi::Psi<T>& psi,
38+ Real* eigenvalue_in)
11739{
118- // Output the title for the current operation
11940 ModuleBase::TITLE (" DiagoCusolver" , " diag" );
120-
121- // Create matrices for the Hamiltonian and overlap
122- hamilt::MatrixBlock<T> h_mat;
123- hamilt::MatrixBlock<T> s_mat;
124- phm_in->matrix (h_mat, s_mat);
125-
126- #ifdef __MPI
127- // global matrix
128- Matrix_g<T> h_mat_g;
129- Matrix_g<T> s_mat_g;
130-
131- // get the context and process information
132- int ctxt = ParaV->blacs_ctxt ;
133- int nprows = 0 ;
134- int npcols = 0 ;
135- int myprow = 0 ;
136- int mypcol = 0 ;
137- Cblacs_gridinfo (ctxt, &nprows, &npcols, &myprow, &mypcol);
138- const int num_procs = nprows * npcols;
139- const int myid = Cblacs_pnum (ctxt, myprow, mypcol);
140- const int root_proc = Cblacs_pnum (ctxt, ParaV->desc [6 ], ParaV->desc [7 ]);
141- #endif
142-
41+ ModuleBase::timer::tick (" DiagoCusolver" , " cusolver" );
14342 // Allocate memory for eigenvalues
14443 std::vector<double > eigen (PARAM.globalv .nlocal , 0.0 );
145-
146- // Start the timer for the cusolver operation
147- ModuleBase::timer::tick (" DiagoCusolver" , " cusolver" );
148-
149- #ifdef __MPI
150- if (num_procs > 1 )
151- {
152- // gather matrices from processes to root process
153- gatherMatrix (myid, root_proc, h_mat, h_mat_g);
154- gatherMatrix (myid, root_proc, s_mat, s_mat_g);
155- }
156- #endif
157-
158- // Call the dense diagonalization routine
159- #ifdef __MPI
160- if (num_procs > 1 )
161- {
162- MPI_Barrier (MPI_COMM_WORLD);
163- // global psi for distribute
164- int psi_len = myid == root_proc ? h_mat_g.row * h_mat_g.col : 1 ;
165- std::vector<T> psi_g (psi_len);
166- if (myid == root_proc)
167- {
168- this ->dc .Dngvd (h_mat_g.col , h_mat_g.row , h_mat_g.p .get (), s_mat_g.p .get (), eigen.data (), psi_g.data ());
169- }
170-
171- MPI_Barrier (MPI_COMM_WORLD);
172-
173- // broadcast eigenvalues to all processes
174- MPI_Bcast (eigen.data (), PARAM.inp .nbands , MPI_DOUBLE, root_proc, MPI_COMM_WORLD);
175-
176- // distribute psi to all processes
177- distributePsi (this ->ParaV ->desc_wfc , psi.get_pointer (), psi_g.data ());
178- }
179- else
180- {
181- // Be careful that h_mat.row * h_mat.col != psi.get_nbands() * psi.get_nbasis() under multi-k situation
182- std::vector<T> eigenvectors (h_mat.row * h_mat.col );
183- this ->dc .Dngvd (h_mat.row , h_mat.col , h_mat.p , s_mat.p , eigen.data (), eigenvectors.data ());
184- const int size = psi.get_nbands () * psi.get_nbasis ();
185- BlasConnector::copy (size, eigenvectors.data (), 1 , psi.get_pointer (), 1 );
186- }
187- #else
18844 std::vector<T> eigenvectors (h_mat.row * h_mat.col );
18945 this ->dc .Dngvd (h_mat.row , h_mat.col , h_mat.p , s_mat.p , eigen.data (), eigenvectors.data ());
19046 const int size = psi.get_nbands () * psi.get_nbasis ();
19147 BlasConnector::copy (size, eigenvectors.data (), 1 , psi.get_pointer (), 1 );
192- #endif
193- // Stop the timer for the cusolver operation
194- ModuleBase::timer::tick (" DiagoCusolver" , " cusolver" );
195-
196- // Copy the eigenvalues to the output arrays
19748 const int inc = 1 ;
19849 BlasConnector::copy (PARAM.inp .nbands , eigen.data (), inc, eigenvalue_in, inc);
50+ ModuleBase::timer::tick (" DiagoCusolver" , " cusolver" );
19951}
20052
20153// Explicit instantiation of the DiagoCusolver class for real and complex numbers
0 commit comments