Skip to content

Commit 4b15663

Browse files
authored
Fix: Add MPI synchronization for ekb band energy tensor, and now GPU CI/CD test case 07_NO_EDM_TDDFT_GPU is functional again (#6354)
* Fix ekb MPI sync error in RT-TDDFT
* Add back integration test 07_NO_EDM_TDDFT_GPU
* Delete useless myid
1 parent 918d11b commit 4b15663

File tree

4 files changed

+90
-96
lines changed

4 files changed

+90
-96
lines changed

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 74 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#include "esolver_ks_lcao_tddft.h"
22

3+
#include "source_estate/elecstate_tools.h"
34
#include "source_io/cal_r_overlap_R.h"
45
#include "source_io/dipole_io.h"
56
#include "source_io/td_current_io.h"
67
#include "source_io/write_HS.h"
78
#include "source_io/write_HS_R.h"
8-
#include "source_estate/elecstate_tools.h"
99

1010
//--------------temporary----------------------------
1111
#include "source_base/blas_connector.h"
@@ -17,17 +17,17 @@
1717
#include "source_estate/module_dm/cal_edm_tddft.h"
1818
#include "source_estate/module_dm/density_matrix.h"
1919
#include "source_estate/occupy.h"
20+
#include "source_io/print_info.h"
2021
#include "source_lcao/module_rt/evolve_elec.h"
2122
#include "source_lcao/module_rt/td_velocity.h"
2223
#include "source_pw/module_pwdft/global.h"
23-
#include "source_io/print_info.h"
2424

2525
//-----HSolver ElecState Hamilt--------
26+
#include "module_parameter/parameter.h"
2627
#include "source_estate/cal_ux.h"
2728
#include "source_estate/elecstate_lcao.h"
28-
#include "source_lcao/hamilt_lcaodft/hamilt_lcao.h"
2929
#include "source_hsolver/hsolver_lcao.h"
30-
#include "module_parameter/parameter.h"
30+
#include "source_lcao/hamilt_lcaodft/hamilt_lcao.h"
3131
#include "source_psi/psi.h"
3232

3333
//-----force& stress-------------------
@@ -87,52 +87,52 @@ void ESolver_KS_LCAO_TDDFT<Device>::before_all_runners(UnitCell& ucell, const In
8787

8888
template <typename Device>
8989
void ESolver_KS_LCAO_TDDFT<Device>::hamilt2rho_single(UnitCell& ucell,
90-
const int istep,
91-
const int iter,
92-
const double ethr)
90+
const int istep,
91+
const int iter,
92+
const double ethr)
9393
{
9494
if (PARAM.inp.init_wfc == "file")
9595
{
9696
if (istep >= 1)
9797
{
9898
module_rt::Evolve_elec<Device>::solve_psi(istep,
99-
PARAM.inp.nbands,
100-
PARAM.globalv.nlocal,
101-
kv.get_nks(),
102-
this->p_hamilt,
103-
this->pv,
104-
this->psi,
105-
this->psi_laststep,
106-
this->Hk_laststep,
107-
this->Sk_laststep,
108-
this->pelec->ekb,
109-
GlobalV::ofs_running,
110-
td_htype,
111-
PARAM.inp.propagator,
112-
use_tensor,
113-
use_lapack);
99+
PARAM.inp.nbands,
100+
PARAM.globalv.nlocal,
101+
kv.get_nks(),
102+
this->p_hamilt,
103+
this->pv,
104+
this->psi,
105+
this->psi_laststep,
106+
this->Hk_laststep,
107+
this->Sk_laststep,
108+
this->pelec->ekb,
109+
GlobalV::ofs_running,
110+
td_htype,
111+
PARAM.inp.propagator,
112+
use_tensor,
113+
use_lapack);
114114
this->weight_dm_rho();
115115
}
116116
this->weight_dm_rho();
117117
}
118118
else if (istep >= 2)
119119
{
120120
module_rt::Evolve_elec<Device>::solve_psi(istep,
121-
PARAM.inp.nbands,
122-
PARAM.globalv.nlocal,
123-
kv.get_nks(),
124-
this->p_hamilt,
125-
this->pv,
126-
this->psi,
127-
this->psi_laststep,
128-
this->Hk_laststep,
129-
this->Sk_laststep,
130-
this->pelec->ekb,
131-
GlobalV::ofs_running,
132-
td_htype,
133-
PARAM.inp.propagator,
134-
use_tensor,
135-
use_lapack);
121+
PARAM.inp.nbands,
122+
PARAM.globalv.nlocal,
123+
kv.get_nks(),
124+
this->p_hamilt,
125+
this->pv,
126+
this->psi,
127+
this->psi_laststep,
128+
this->Hk_laststep,
129+
this->Sk_laststep,
130+
this->pelec->ekb,
131+
GlobalV::ofs_running,
132+
td_htype,
133+
PARAM.inp.propagator,
134+
use_tensor,
135+
use_lapack);
136136
this->weight_dm_rho();
137137
}
138138
else
@@ -163,18 +163,14 @@ void ESolver_KS_LCAO_TDDFT<Device>::hamilt2rho_single(UnitCell& ucell,
163163
}
164164

165165
template <typename Device>
166-
void ESolver_KS_LCAO_TDDFT<Device>::iter_finish(
167-
UnitCell& ucell,
168-
const int istep,
169-
int& iter,
170-
bool& conv_esolver)
166+
void ESolver_KS_LCAO_TDDFT<Device>::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver)
171167
{
172168
// print occupation of each band
173169
if (iter == 1 && istep <= 2)
174170
{
175-
GlobalV::ofs_running << " ---------------------------------------------------------"
176-
<< std::endl;
177-
GlobalV::ofs_running << " occupations of electrons" << std::endl;
171+
GlobalV::ofs_running << " ----------------------------------------------------------" << std::endl;
172+
GlobalV::ofs_running << " Occupations of electrons" << std::endl;
173+
GlobalV::ofs_running << " ----------------------------------------------------------" << std::endl;
178174
GlobalV::ofs_running << " k-point state occupation" << std::endl;
179175
GlobalV::ofs_running << std::setiosflags(std::ios::showpoint);
180176
GlobalV::ofs_running << std::left;
@@ -183,23 +179,21 @@ void ESolver_KS_LCAO_TDDFT<Device>::iter_finish(
183179
{
184180
for (int ib = 0; ib < PARAM.inp.nbands; ib++)
185181
{
186-
GlobalV::ofs_running << " " << std::setw(9)
187-
<< ik+1 << std::setw(8) << ib + 1
188-
<< std::setw(12) << this->pelec->wg(ik, ib) << std::endl;
182+
GlobalV::ofs_running << " " << std::setw(9) << ik + 1 << std::setw(8) << ib + 1 << std::setw(12)
183+
<< this->pelec->wg(ik, ib) << std::endl;
189184
}
190185
}
191-
GlobalV::ofs_running << " ---------------------------------------------------------"
192-
<< std::endl;
186+
GlobalV::ofs_running << " ----------------------------------------------------------" << std::endl;
193187
}
194188

195189
ESolver_KS_LCAO<std::complex<double>, double>::iter_finish(ucell, istep, iter, conv_esolver);
196190
}
197191

198192
template <typename Device>
199-
void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
200-
const int istep,
201-
const int iter,
202-
const bool conv_esolver)
193+
void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
194+
const int istep,
195+
const int iter,
196+
const bool conv_esolver)
203197
{
204198
// Calculate new potential according to new Charge Density
205199
if (!conv_esolver)
@@ -234,7 +228,6 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
234228
nrow_tmp = nlocal;
235229
#endif
236230
this->psi_laststep = new psi::Psi<std::complex<double>>(kv.get_nks(), ncol_tmp, nrow_tmp, kv.ngk, true);
237-
238231
}
239232

240233
// allocate memory for Hk_laststep and Sk_laststep
@@ -282,8 +275,8 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
282275
if (td_htype == 1)
283276
{
284277
this->p_hamilt->updateHk(ik);
285-
hamilt::MatrixBlock <std::complex<double>> h_mat;
286-
hamilt::MatrixBlock <std::complex<double>> s_mat;
278+
hamilt::MatrixBlock<std::complex<double>> h_mat;
279+
hamilt::MatrixBlock<std::complex<double>> s_mat;
287280
this->p_hamilt->matrix(h_mat, s_mat);
288281

289282
if (use_tensor && use_lapack)
@@ -323,31 +316,31 @@ void ESolver_KS_LCAO_TDDFT<Device>::update_pot(UnitCell& ucell,
323316
}
324317

325318
// print "eigen value" for tddft
326-
// it seems uncessary to print out E_ii because the band energies are printed
327-
/*
328-
if (conv_esolver)
329-
{
330-
GlobalV::ofs_running << "----------------------------------------------------------"
331-
<< std::endl;
332-
GlobalV::ofs_running << " Print E=<psi_i|H|psi_i> " << std::endl;
333-
GlobalV::ofs_running << " k-point state energy (eV)" << std::endl;
334-
GlobalV::ofs_running << "----------------------------------------------------------"
335-
<< std::endl;
336-
GlobalV::ofs_running << std::setprecision(6);
337-
GlobalV::ofs_running << std::setiosflags(std::ios::showpoint);
338-
339-
for (int ik = 0; ik < kv.get_nks(); ik++)
319+
// it seems unnecessary to print out E_ii because the band energies are printed
320+
/*
321+
if (conv_esolver)
340322
{
341-
for (int ib = 0; ib < PARAM.inp.nbands; ib++)
323+
GlobalV::ofs_running << "----------------------------------------------------------"
324+
<< std::endl;
325+
GlobalV::ofs_running << " Print E=<psi_i|H|psi_i> " << std::endl;
326+
GlobalV::ofs_running << " k-point state energy (eV)" << std::endl;
327+
GlobalV::ofs_running << "----------------------------------------------------------"
328+
<< std::endl;
329+
GlobalV::ofs_running << std::setprecision(6);
330+
GlobalV::ofs_running << std::setiosflags(std::ios::showpoint);
331+
332+
for (int ik = 0; ik < kv.get_nks(); ik++)
342333
{
343-
GlobalV::ofs_running << " " << std::setw(7) << ik + 1
344-
<< std::setw(7) << ib + 1
345-
<< std::setw(10) << this->pelec->ekb(ik, ib) * ModuleBase::Ry_to_eV
346-
<< std::endl;
334+
for (int ib = 0; ib < PARAM.inp.nbands; ib++)
335+
{
336+
GlobalV::ofs_running << " " << std::setw(7) << ik + 1
337+
<< std::setw(7) << ib + 1
338+
<< std::setw(10) << this->pelec->ekb(ik, ib) * ModuleBase::Ry_to_eV
339+
<< std::endl;
340+
}
347341
}
348342
}
349-
}
350-
*/
343+
*/
351344
}
352345

353346
template <typename Device>
@@ -365,16 +358,11 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
365358
{
366359
std::stringstream ss_dipole;
367360
ss_dipole << PARAM.globalv.global_out_dir << "SPIN" << is + 1 << "_DIPOLE";
368-
ModuleIO::write_dipole(ucell,
369-
this->chr.rho_save[is],
370-
this->chr.rhopw,
371-
is,
372-
istep,
373-
ss_dipole.str());
361+
ModuleIO::write_dipole(ucell, this->chr.rho_save[is], this->chr.rhopw, is, istep, ss_dipole.str());
374362
}
375363
}
376364

377-
// (2) write current information
365+
// (2) write current information
378366
if (TD_Velocity::out_current == true)
379367
{
380368
elecstate::DensityMatrix<std::complex<double>, double>* tmp_DM
@@ -392,7 +380,6 @@ void ESolver_KS_LCAO_TDDFT<Device>::after_scf(UnitCell& ucell, const int istep,
392380
this->RA);
393381
}
394382

395-
396383
ModuleBase::timer::tick("ESolver_LCAO_TDDFT", "after_scf");
397384
}
398385

@@ -410,7 +397,7 @@ void ESolver_KS_LCAO_TDDFT<Device>::weight_dm_rho()
410397
}
411398

412399
// calculate Eband energy
413-
elecstate::calEBand(this->pelec->ekb,this->pelec->wg,this->pelec->f_en);
400+
elecstate::calEBand(this->pelec->ekb, this->pelec->wg, this->pelec->f_en);
414401

415402
// calculate the density matrix
416403
ModuleBase::GlobalFunc::NOTE("Calculate the density matrix.");

source/source_estate/module_dm/cal_edm_tddft.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "source_base/scalapack_connector.h"
55
namespace elecstate
66
{
7-
87
// use the original formula (Hamiltonian matrix) to calculate energy density matrix
98
void cal_edm_tddft(Parallel_Orbitals& pv,
109
elecstate::ElecState* pelec,
@@ -254,5 +253,5 @@ void cal_edm_tddft(Parallel_Orbitals& pv,
254253
#endif
255254
}
256255
return;
257-
}
258-
} // namespace ModuleESolver
256+
} // cal_edm_tddft
257+
} // namespace elecstate

source/source_lcao/module_rt/evolve_elec.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
8383
propagator,
8484
ofs_running,
8585
print_matrix);
86-
// std::cout << "Print ekb: " << std::endl;
87-
// ekb.print(std::cout);
86+
// GlobalV::ofs_running << "Print ekb: " << std::endl;
87+
// ekb.print(GlobalV::ofs_running);
8888
}
8989
else
9090
{
@@ -118,7 +118,7 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
118118
#ifdef __MPI
119119
// Access the rank of the calling process in the communicator
120120
int myid = 0;
121-
int root_proc = 0;
121+
const int root_proc = 0;
122122
MPI_Comm_rank(MPI_COMM_WORLD, &myid);
123123

124124
// Gather psi to the root process
@@ -203,8 +203,16 @@ void Evolve_elec<Device>::solve_psi(const int& istep,
203203
len_HS_laststep);
204204
syncmem_double_d2h_op()(&(ekb(ik, 0)), ekb_tensor.data<double>(), nband);
205205

206-
// std::cout << "Print ekb tensor: " << std::endl;
207-
// ekb.print(std::cout);
206+
#ifdef __MPI
207+
const int root_proc = 0;
208+
if (use_lapack)
209+
{
210+
// Synchronize ekb to all MPI processes
211+
MPI_Bcast(&(ekb(ik, 0)), nband, MPI_DOUBLE, root_proc, MPI_COMM_WORLD);
212+
}
213+
#endif
214+
// GlobalV::ofs_running << "Print ekb: " << std::endl;
215+
// ekb.print(GlobalV::ofs_running);
208216
}
209217
}
210218
else

tests/15_rtTDDFT_GPU/CASES_GPU.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
04_NO_CO_ocp_TDDFT_GPU
55
05_NO_cur_TDDFT_GPU
66
06_NO_dir_TDDFT_GPU
7-
# 07_NO_EDM_TDDFT_GPU
7+
07_NO_EDM_TDDFT_GPU
88
09_NO_HEAV_TDDFT_GPU
99
10_NO_HHG_TDDFT_GPU
1010
11_NO_O3_TDDFT_GPU

0 commit comments

Comments
 (0)