Skip to content

Commit 704c671

Browse files
authored
Refactor: RT-TDDFT ESolver ESolver_KS_LCAO_TDDFT (#6668)
* Remove redundant includes in RT-TDDFT * Refactor store_h_s_psi * Refactor Hk and Sk with Tensor * Refactor MPI utility functions * Refactor gather and distribute Psi function * Modify the output suffix of some text files from .dat to .txt * Change dipole file name from SPIN*_DIPOLE to dipole_s*.txt * Move the output functions in after_scf to ctrl_output_td * Only calculate EDM in RT-TDDFT when the electronic step ends * Fix MPI bug * Fix LCAO macro bug
1 parent ab87e38 commit 704c671

File tree

46 files changed

+725
-621
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+725
-621
lines changed

docs/advanced/input_files/input-main.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3185,14 +3185,14 @@ These variables are used to control molecular dynamics calculations. For more in
31853185

31863186
- **Type**: Boolean
31873187
- **Description**: Control whether to restart molecular dynamics calculations and time-dependent density functional theory calculations.
3188-
- True: ABACUS will read in `${read_file_dir}/Restart_md.dat` to determine the current step `${md_step}`, then read in the corresponding `STRU_MD_${md_step}` in the folder `OUT.$suffix/STRU/` automatically. For tddft, ABACUS will also read in `WFC_NAO_K${kpoint}` of the last step (You need to set out_wfc_lcao=1 and out_app_flag=0 to obtain this file).
3188+
- True: ABACUS will read in `${read_file_dir}/Restart_md.txt` to determine the current step `${md_step}`, then read in the corresponding `STRU_MD_${md_step}` in the folder `OUT.$suffix/STRU/` automatically. For tddft, ABACUS will also read in `WFC_NAO_K${kpoint}` of the last step (You need to set out_wfc_lcao=1 and out_app_flag=0 to obtain this file).
31893189
- False: ABACUS will start molecular dynamics calculations normally from the first step.
31903190
- **Default**: False
31913191

31923192
### md_restartfreq
31933193

31943194
- **Type**: Integer
3195-
- **Description**: The output frequency of `OUT.${suffix}/Restart_md.dat` and structural files in the directory `OUT.${suffix}/STRIU/`, which are used to restart molecular dynamics calculations, see [md_restart](#md_restart) in detail.
3195+
- **Description**: The output frequency of `OUT.${suffix}/Restart_md.txt` and structural files in the directory `OUT.${suffix}/STRIU/`, which are used to restart molecular dynamics calculations, see [md_restart](#md_restart) in detail.
31963196
- **Default**: 5
31973197

31983198
### md_dumpfreq

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ OBJS_IO=input_conv.o\
560560
ctrl_iter_lcao.o\
561561
ctrl_output_fp.o\
562562
ctrl_output_pw.o\
563+
ctrl_output_td.o\
563564
para_json.o\
564565
abacusjson.o\
565566
general_info.o\

source/source_esolver/esolver_ks_lcao_tddft.cpp

Lines changed: 231 additions & 311 deletions
Large diffs are not rendered by default.

source/source_esolver/esolver_ks_lcao_tddft.h

Lines changed: 22 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,52 +2,13 @@
22
#define ESOLVER_KS_LCAO_TDDFT_H
33
#include "esolver_ks.h"
44
#include "esolver_ks_lcao.h"
5-
#include "source_base/module_external/scalapack_connector.h" // Cpxgemr2d
6-
#include "source_lcao/record_adj.h"
7-
#include "source_psi/psi.h"
8-
#include "source_lcao/module_rt/velocity_op.h"
5+
#include "source_base/module_container/ATen/core/tensor.h" // ct::Tensor
6+
#include "source_lcao/module_rt/gather_mat.h" // MPI gathering and distributing functions
97
#include "source_lcao/module_rt/td_info.h"
8+
#include "source_lcao/module_rt/velocity_op.h"
109

1110
namespace ModuleESolver
1211
{
13-
//------------------------ MPI gathering and distributing functions ------------------------//
14-
// This struct is used for collecting matrices from all processes to root process
15-
template <typename T>
16-
struct Matrix_g
17-
{
18-
std::shared_ptr<T> p;
19-
size_t row;
20-
size_t col;
21-
std::shared_ptr<int> desc;
22-
};
23-
24-
// Collect matrices from all processes to root process
25-
template <typename T>
26-
void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock<T>& mat_l, Matrix_g<T>& mat_g)
27-
{
28-
const int* desca = mat_l.desc; // Obtain the descriptor of the local matrix
29-
int ctxt = desca[1]; // BLACS context
30-
int nrows = desca[2]; // Global matrix row number
31-
int ncols = desca[3]; // Global matrix column number
32-
33-
if (myid == root_proc)
34-
{
35-
mat_g.p.reset(new T[nrows * ncols]); // No need to delete[] since it is a shared_ptr
36-
}
37-
else
38-
{
39-
mat_g.p.reset(new T[nrows * ncols]); // Placeholder for non-root processes
40-
}
41-
42-
// Set the descriptor of the global matrix
43-
mat_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
44-
mat_g.row = nrows;
45-
mat_g.col = ncols;
46-
47-
// Call the Cpxgemr2d function in ScaLAPACK to collect the matrix data
48-
Cpxgemr2d(nrows, ncols, mat_l.p, 1, 1, const_cast<int*>(desca), mat_g.p.get(), 1, 1, mat_g.desc.get(), ctxt);
49-
}
50-
//------------------------ MPI gathering and distributing functions ------------------------//
5112

5213
template <typename TR, typename Device = base_device::DEVICE_CPU>
5314
class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, TR>
@@ -64,29 +25,36 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, TR>
6425

6526
virtual void hamilt2rho_single(UnitCell& ucell, const int istep, const int iter, const double ethr) override;
6627

67-
// mohan change update_pot to save2, 2025-10-17
68-
void save2(UnitCell& ucell, const int istep, const int iter, const bool conv_esolver);
28+
void store_h_s_psi(UnitCell& ucell, const int istep, const int iter, const bool conv_esolver);
6929

70-
virtual void iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver) override;
30+
void iter_finish(UnitCell& ucell,
31+
const int istep,
32+
const int estep,
33+
const int estep_max,
34+
int& iter,
35+
bool& conv_esolver);
7136

7237
virtual void after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) override;
7338

7439
void print_step();
75-
//! wave functions of last time step
76-
psi::Psi<std::complex<double>>* psi_laststep = nullptr;
7740

78-
//! Hamiltonian of last time step
79-
std::complex<double>** Hk_laststep = nullptr;
41+
//! Wave function for all k-points of last time step
42+
psi::Psi<std::complex<double>>* psi_laststep = nullptr;
8043

81-
//! Overlap matrix of last time step
82-
std::complex<double>** Sk_laststep = nullptr;
44+
//! Hamiltonian for all k-points of last time step
45+
ct::Tensor Hk_laststep = ct::Tensor(ct::DataType::DT_COMPLEX_DOUBLE);
8346

84-
const int td_htype = 1;
47+
//! Overlap matrix for all k-points of last time step
48+
ct::Tensor Sk_laststep = ct::Tensor(ct::DataType::DT_COMPLEX_DOUBLE);
8549

8650
//! Control heterogeneous computing of the TDDFT solver
8751
bool use_tensor = false;
8852
bool use_lapack = false;
8953

54+
// Control the device type for Hk_laststep and Sk_laststep
55+
// Set to CPU temporarily, should wait for further GPU development
56+
static constexpr ct::DeviceType ct_device_type_hs = ct::DeviceType::CpuDevice;
57+
9058
//! Total steps for evolving the wave function
9159
int totstep = -1;
9260

@@ -95,13 +63,12 @@ class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, TR>
9563

9664
TD_info* td_p = nullptr;
9765

98-
//! doubt
66+
//! Restart flag
9967
bool restart_done = false;
10068

10169
private:
10270
void weight_dm_rho(const UnitCell& ucell);
10371
};
10472

10573
} // namespace ModuleESolver
106-
#endif
107-
74+
#endif // ESOLVER_KS_LCAO_TDDFT_H

source/source_io/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ list(APPEND objects
22
input_conv.cpp
33
ctrl_output_fp.cpp
44
ctrl_output_pw.cpp
5+
ctrl_output_td.cpp
56
bessel_basis.cpp
67
cal_test.cpp
78
cal_dos.cpp
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#include "ctrl_output_td.h"
2+
3+
#include "source_base/parallel_global.h"
4+
#include "source_io/dipole_io.h"
5+
#include "source_io/module_parameter/parameter.h"
6+
#include "source_io/td_current_io.h"
7+
8+
namespace ModuleIO
9+
{
10+
11+
template <typename TR>
12+
void ctrl_output_td(const UnitCell& ucell,
13+
double** rho_save,
14+
const ModulePW::PW_Basis* rhopw,
15+
const int istep,
16+
const psi::Psi<std::complex<double>>* psi,
17+
const elecstate::ElecState* pelec,
18+
const K_Vectors& kv,
19+
const TwoCenterIntegrator* intor,
20+
const Parallel_Orbitals* pv,
21+
const LCAO_Orbitals& orb,
22+
const Velocity_op<TR>* velocity_mat,
23+
Record_adj& RA,
24+
TD_info* td_p)
25+
{
26+
ModuleBase::TITLE("ModuleIO", "ctrl_output_td");
27+
28+
// Original code commented out, might need reference later
29+
30+
// // (1) Write dipole information
31+
// for (int is = 0; is < PARAM.inp.nspin; is++)
32+
// {
33+
// if (PARAM.inp.out_dipole == 1)
34+
// {
35+
// std::stringstream ss_dipole;
36+
// ss_dipole << PARAM.globalv.global_out_dir << "dipole_s" << is + 1 << ".txt";
37+
// ModuleIO::write_dipole(ucell, this->chr.rho_save[is], this->chr.rhopw, is, istep, ss_dipole.str());
38+
// }
39+
// }
40+
41+
// // (2) Write current information
42+
// elecstate::DensityMatrix<std::complex<double>, double>* tmp_DM
43+
// = dynamic_cast<elecstate::ElecStateLCAO<std::complex<double>>*>(this->pelec)->get_DM();
44+
// if (TD_info::out_current)
45+
// {
46+
// if (TD_info::out_current_k)
47+
// {
48+
// ModuleIO::write_current_eachk(ucell,
49+
// istep,
50+
// this->psi,
51+
// this->pelec,
52+
// this->kv,
53+
// this->two_center_bundle_.overlap_orb.get(),
54+
// tmp_DM->get_paraV_pointer(),
55+
// this->orb_,
56+
// this->velocity_mat,
57+
// this->RA);
58+
// }
59+
// else
60+
// {
61+
// ModuleIO::write_current(ucell,
62+
// istep,
63+
// this->psi,
64+
// this->pelec,
65+
// this->kv,
66+
// this->two_center_bundle_.overlap_orb.get(),
67+
// tmp_DM->get_paraV_pointer(),
68+
// this->orb_,
69+
// this->velocity_mat,
70+
// this->RA);
71+
// }
72+
// }
73+
74+
// // (3) Output file for restart
75+
// if (PARAM.inp.out_freq_ion > 0) // default value of out_freq_ion is 0
76+
// {
77+
// if (istep % PARAM.inp.out_freq_ion == 0)
78+
// {
79+
// td_p->out_restart_info(istep, elecstate::H_TDDFT_pw::At, elecstate::H_TDDFT_pw::At_laststep);
80+
// }
81+
// }
82+
83+
#ifdef __LCAO
84+
// (1) Write dipole information
85+
for (int is = 0; is < PARAM.inp.nspin; ++is)
86+
{
87+
if (PARAM.inp.out_dipole == 1)
88+
{
89+
std::stringstream ss_dipole;
90+
ss_dipole << PARAM.globalv.global_out_dir << "dipole_s" << is + 1 << ".txt";
91+
ModuleIO::write_dipole(ucell, rho_save[is], rhopw, is, istep, ss_dipole.str());
92+
}
93+
}
94+
95+
// (2) Write current information
96+
const elecstate::ElecStateLCAO<std::complex<double>>* pelec_lcao
97+
= dynamic_cast<const elecstate::ElecStateLCAO<std::complex<double>>*>(pelec);
98+
99+
if (!pelec_lcao)
100+
{
101+
ModuleBase::WARNING_QUIT("ModuleIO::ctrl_output_td", "Failed to cast ElecState to ElecStateLCAO");
102+
}
103+
104+
if (TD_info::out_current)
105+
{
106+
if (TD_info::out_current_k)
107+
{
108+
ModuleIO::write_current_eachk<TR>(ucell, istep, psi, pelec, kv, intor, pv, orb, velocity_mat, RA);
109+
}
110+
else
111+
{
112+
ModuleIO::write_current<TR>(ucell, istep, psi, pelec, kv, intor, pv, orb, velocity_mat, RA);
113+
}
114+
}
115+
116+
// (3) Output file for restart
117+
if (PARAM.inp.out_freq_ion > 0) // default value of out_freq_ion is 0
118+
{
119+
if (istep % PARAM.inp.out_freq_ion == 0)
120+
{
121+
if (td_p != nullptr)
122+
{
123+
td_p->out_restart_info(istep, elecstate::H_TDDFT_pw::At, elecstate::H_TDDFT_pw::At_laststep);
124+
}
125+
else
126+
{
127+
ModuleBase::WARNING_QUIT("ModuleIO::ctrl_output_td",
128+
"TD_info pointer is null, cannot output restart info.");
129+
}
130+
}
131+
}
132+
#endif // __LCAO
133+
}
134+
135+
template void ctrl_output_td<double>(const UnitCell&,
136+
double**,
137+
const ModulePW::PW_Basis*,
138+
const int,
139+
const psi::Psi<std::complex<double>>*,
140+
const elecstate::ElecState*,
141+
const K_Vectors&,
142+
const TwoCenterIntegrator*,
143+
const Parallel_Orbitals*,
144+
const LCAO_Orbitals&,
145+
const Velocity_op<double>*,
146+
Record_adj&,
147+
TD_info*);
148+
149+
template void ctrl_output_td<std::complex<double>>(const UnitCell&,
150+
double**,
151+
const ModulePW::PW_Basis*,
152+
const int,
153+
const psi::Psi<std::complex<double>>*,
154+
const elecstate::ElecState*,
155+
const K_Vectors&,
156+
const TwoCenterIntegrator*,
157+
const Parallel_Orbitals*,
158+
const LCAO_Orbitals&,
159+
const Velocity_op<std::complex<double>>*,
160+
Record_adj&,
161+
TD_info*);
162+
163+
} // namespace ModuleIO

source/source_io/ctrl_output_td.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#ifndef CTRL_OUTPUT_TD_H
2+
#define CTRL_OUTPUT_TD_H
3+
4+
#include "source_basis/module_ao/ORB_read.h"
5+
#include "source_basis/module_ao/parallel_orbitals.h"
6+
#include "source_basis/module_nao/two_center_bundle.h"
7+
#include "source_cell/unitcell.h"
8+
#include "source_estate/elecstate_lcao.h"
9+
#include "source_estate/module_pot/H_TDDFT_pw.h"
10+
#include "source_lcao/module_rt/td_info.h"
11+
#include "source_lcao/module_rt/velocity_op.h"
12+
#include "source_lcao/record_adj.h"
13+
#include "source_psi/psi.h"
14+
15+
namespace ModuleIO
16+
{
17+
18+
template <typename TR>
19+
void ctrl_output_td(const UnitCell& ucell,
20+
double** rho_save,
21+
const ModulePW::PW_Basis* rhopw,
22+
const int istep,
23+
const psi::Psi<std::complex<double>>* psi,
24+
const elecstate::ElecState* pelec,
25+
const K_Vectors& kv,
26+
const TwoCenterIntegrator* intor,
27+
const Parallel_Orbitals* pv,
28+
const LCAO_Orbitals& orb,
29+
const Velocity_op<TR>* velocity_mat,
30+
Record_adj& RA,
31+
TD_info* td_p);
32+
33+
} // namespace ModuleIO
34+
35+
#endif // CTRL_OUTPUT_TD_H

source/source_io/read_input.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,12 +488,12 @@ void ReadInput::check_ntype(const std::string& fn, int& param_ntype)
488488
int ReadInput::current_md_step(const std::string& file_dir)
489489
{
490490
std::stringstream ssc;
491-
ssc << file_dir << "Restart_md.dat";
491+
ssc << file_dir << "Restart_md.txt";
492492
std::ifstream file(ssc.str().c_str());
493493

494494
if (!file)
495495
{
496-
ModuleBase::WARNING_QUIT("current_md_step", "no Restart_md.dat");
496+
ModuleBase::WARNING_QUIT("current_md_step", "no Restart_md.txt");
497497
}
498498

499499
int md_step;

source/source_io/read_input.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class ReadInput
6868
/**
6969
* @brief determine the md step in restart case
7070
*
71-
* @param file_dir directory of Restart_md.dat
71+
* @param file_dir directory of Restart_md.txt
7272
* @return md step
7373
*/
7474
int current_md_step(const std::string& file_dir);

0 commit comments

Comments
 (0)