Skip to content

Commit 42042b4

Browse files
linpeizePeizeLin
andauthored
add exx nscf file check (#6288)
Co-authored-by: linpz <linpz@mail.ustc.edu.cn>
1 parent 7155c3f commit 42042b4

File tree

4 files changed

+117
-80
lines changed

4 files changed

+117
-80
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.hpp

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ template <typename TK, typename TR>
8888
OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
8989
HContainer<TR>*hR_in,
9090
const UnitCell& ucell_in,
91-
const K_Vectors& kv_in,
92-
std::vector<std::map<int, std::map<TAC, RI::Tensor<double>>>>* Hexxd_in,
93-
std::vector<std::map<int, std::map<TAC, RI::Tensor<std::complex<double>>>>>* Hexxc_in,
91+
const K_Vectors& kv_in,
92+
std::vector<std::map<int, std::map<TAC, RI::Tensor<double>>>>* Hexxd_in,
93+
std::vector<std::map<int, std::map<TAC, RI::Tensor<std::complex<double>>>>>* Hexxc_in,
9494
Add_Hexx_Type add_hexx_type_in,
9595
const int istep,
9696
int* two_level_step_in,
97-
const bool restart_in)
97+
const bool restart_in)
9898
: OperatorLCAO<TK, TR>(hsk_in, kv_in.kvec_d, hR_in),
9999
ucell(ucell_in),
100100
kv(kv_in),
@@ -111,47 +111,75 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
111111

112112
if (PARAM.inp.calculation == "nscf" && GlobalC::exx_info.info_global.cal_exx)
113113
{ // if nscf, read HexxR first and reallocate hR according to the read-in HexxR
114-
const std::string file_name_exx = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(GlobalV::MY_RANK);
115-
bool all_exist = true;
116-
for (int is=0;is<PARAM.inp.nspin;++is)
114+
auto file_name_list_csr = []() -> std::vector<std::string>
117115
{
118-
std::ifstream ifs(file_name_exx + "_" + std::to_string(is) + ".csr");
119-
if (!ifs) { all_exist = false; break; }
120-
}
121-
if (all_exist)
116+
std::vector<std::string> file_name_list;
117+
for (int irank=0; irank<PARAM.globalv.nproc; ++irank) {
118+
for (int is=0;is<PARAM.inp.nspin;++is) {
119+
file_name_list.push_back( PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(irank) + "_" + std::to_string(is) + ".csr" );
120+
} }
121+
return file_name_list;
122+
};
123+
auto file_name_list_cereal = []() -> std::vector<std::string>
124+
{
125+
std::vector<std::string> file_name_list;
126+
for (int irank=0; irank<PARAM.globalv.nproc; ++irank)
127+
{ file_name_list.push_back( "HexxR_" + std::to_string(irank) ); }
128+
return file_name_list;
129+
};
130+
auto check_exist = [](const std::vector<std::string> &file_name_list) -> bool
131+
{
132+
for (const std::string &file_name : file_name_list)
133+
{
134+
std::ifstream ifs(file_name);
135+
if (!ifs.is_open())
136+
{ return false; }
137+
}
138+
return true;
139+
};
140+
141+
std::cout<<" Attention: The number of MPI processes must be strictly identical between SCF and NSCF when computing exact-exchange."<<std::endl;
142+
if (check_exist(file_name_list_csr()))
122143
{
144+
const std::string file_name_exx_csr = PARAM.globalv.global_readin_dir + "HexxR" + std::to_string(PARAM.globalv.myrank);
123145
// Read HexxR in CSR format
124146
if (GlobalC::exx_info.info_ri.real_number)
125147
{
126-
ModuleIO::read_Hexxs_csr(file_name_exx, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
127-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
148+
ModuleIO::read_Hexxs_csr(file_name_exx_csr, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxd);
149+
if (this->add_hexx_type == Add_Hexx_Type::R)
150+
{ reallocate_hcontainer(*Hexxd, this->hR); }
128151
}
129152
else
130153
{
131-
ModuleIO::read_Hexxs_csr(file_name_exx, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
132-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
154+
ModuleIO::read_Hexxs_csr(file_name_exx_csr, ucell, PARAM.inp.nspin, PARAM.globalv.nlocal, *Hexxc);
155+
if (this->add_hexx_type == Add_Hexx_Type::R)
156+
{ reallocate_hcontainer(*Hexxc, this->hR); }
133157
}
134158
}
135-
else
159+
else if (check_exist(file_name_list_cereal()))
136160
{
137161
// Read HexxR in binary format (old version)
138-
const std::string file_name_exx_cereal = PARAM.globalv.global_readin_dir + "HexxR_" + std::to_string(GlobalV::MY_RANK);
162+
const std::string file_name_exx_cereal = PARAM.globalv.global_readin_dir + "HexxR_" + std::to_string(PARAM.globalv.myrank);
139163
std::ifstream ifs(file_name_exx_cereal, std::ios::binary);
140164
if (!ifs)
141-
{
142-
ModuleBase::WARNING_QUIT("OperatorEXX", "Can't open EXX file < " + file_name_exx_cereal + " >.");
143-
}
165+
{ ModuleBase::WARNING_QUIT("OperatorEXX", "Can't open EXX file < " + file_name_exx_cereal + " >."); }
144166
if (GlobalC::exx_info.info_ri.real_number)
145167
{
146168
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxd);
147-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxd, this->hR); }
169+
if (this->add_hexx_type == Add_Hexx_Type::R)
170+
{ reallocate_hcontainer(*Hexxd, this->hR); }
148171
}
149172
else
150173
{
151174
ModuleIO::read_Hexxs_cereal(file_name_exx_cereal, *Hexxc);
152-
if (this->add_hexx_type == Add_Hexx_Type::R) { reallocate_hcontainer(*Hexxc, this->hR); }
175+
if (this->add_hexx_type == Add_Hexx_Type::R)
176+
{ reallocate_hcontainer(*Hexxc, this->hR); }
153177
}
154178
}
179+
else
180+
{
181+
ModuleBase::WARNING_QUIT("OperatorEXX", "Can't open EXX file in " + PARAM.globalv.global_readin_dir);
182+
}
155183
this->use_cell_nearest = false;
156184
}
157185
else
@@ -207,7 +235,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
207235
else if (this->add_hexx_type == Add_Hexx_Type::R)
208236
{
209237
// read in Hexx(R)
210-
const std::string restart_HR_path = GlobalC::restart.folder + "HexxR" + std::to_string(GlobalV::MY_RANK);
238+
const std::string restart_HR_path = GlobalC::restart.folder + "HexxR" + std::to_string(PARAM.globalv.myrank);
211239
int all_exist = 1;
212240
for (int is = 0; is < PARAM.inp.nspin; ++is)
213241
{
@@ -232,7 +260,7 @@ OperatorEXX<OperatorLCAO<TK, TR>>::OperatorEXX(HS_Matrix_K<TK>* hsk_in,
232260
else
233261
{
234262
// Read HexxR in binary format (old version)
235-
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(GlobalV::MY_RANK);
263+
const std::string restart_HR_path_cereal = GlobalC::restart.folder + "HexxR_" + std::to_string(PARAM.globalv.myrank);
236264
std::ifstream ifs(restart_HR_path_cereal, std::ios::binary);
237265
int all_exist_cereal = ifs ? 1 : 0;
238266
#ifdef __MPI

source/module_io/restart_exx_csr.hpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ namespace ModuleIO
2828
{
2929
const std::vector<int>& R = csr.getRCoordinate(iR);
3030
TC dR({ R[0], R[1], R[2] });
31-
Hexxs[is][iat1][{iat2, dR}] = RI::Tensor<Tdata>(
32-
{
33-
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw),
34-
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw)
35-
}
36-
);
31+
Hexxs[is][iat1][{iat2, dR}] = RI::Tensor<Tdata>(
32+
{
33+
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat1]].nw),
34+
static_cast<size_t>(ucell.atoms[ucell.iat2it[iat2]].nw)
35+
}
36+
);
3737
}
3838
}
3939
}
@@ -49,12 +49,12 @@ namespace ModuleIO
4949
const int& npol = ucell.get_npol();
5050
const int& i = ijv.first.first * npol;
5151
const int& j = ijv.first.second * npol;
52-
Hexxs.at(is).at(ucell.iwt2iat[i]).at(
53-
{
54-
ucell.iwt2iat[j],
55-
{ R[0], R[1], R[2] }
56-
}
57-
)(ucell.iwt2iw[i] / npol, ucell.iwt2iw[j] / npol) = ijv.second;
52+
Hexxs.at(is).at(ucell.iwt2iat[i]).at(
53+
{
54+
ucell.iwt2iat[j],
55+
{ R[0], R[1], R[2] }
56+
}
57+
)(ucell.iwt2iw[i] / npol, ucell.iwt2iw[j] / npol) = ijv.second;
5858
}
5959
}
6060
}
@@ -67,6 +67,8 @@ namespace ModuleIO
6767
ModuleBase::TITLE("ModuleIO", "read_Hexxs_cereal");
6868
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");
6969
std::ifstream ifs(file_name, std::ios::binary);
70+
if(!ifs.is_open())
71+
{ ModuleBase::WARNING_QUIT("read_Hexxs_cereal", file_name+" not found."); }
7072
cereal::BinaryInputArchive iar(ifs);
7173
iar(Hexxs);
7274
ModuleBase::timer::tick("Exx_LRI", "read_Hexxs_cereal");

source/module_ri/Exx_LRI_interface.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ class Exx_LRI_Interface
4242
}
4343
Exx_LRI_Interface() = delete;
4444

45-
/// read and write Hexxs using cereal
46-
void write_Hexxs_cereal(const std::string& file_name) const;
47-
void read_Hexxs_cereal(const std::string& file_name);
45+
///// read and write Hexxs using cereal
46+
//void write_Hexxs_cereal(const std::string& file_name) const;
47+
//void read_Hexxs_cereal(const std::string& file_name);
4848

4949
std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& get_Hexxs() const { return this->exx_ptr->Hexxs; }
5050
double &get_Eexx() const { return this->exx_ptr->Eexx; }

source/module_ri/Exx_LRI_interface.hpp

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <stdexcept>
1919
#include <string>
2020

21+
/*
2122
template<typename T, typename Tdata>
2223
void Exx_LRI_Interface<T, Tdata>::write_Hexxs_cereal(const std::string& file_name) const
2324
{
@@ -34,11 +35,15 @@ void Exx_LRI_Interface<T, Tdata>::read_Hexxs_cereal(const std::string& file_name
3435
{
3536
ModuleBase::TITLE("Exx_LRI_Interface", "read_Hexxs_cereal");
3637
ModuleBase::timer::tick("Exx_LRI_Interface", "read_Hexxs_cereal");
37-
std::ifstream ifs(file_name + "_" + std::to_string(GlobalV::MY_RANK), std::ofstream::binary);
38+
const std::string file_name_rank = file_name + "_" + std::to_string(GlobalV::MY_RANK);
39+
std::ifstream ifs(file_name_rank, std::ofstream::binary);
40+
if(!ifs.is_open())
41+
{ ModuleBase::WARNING_QUIT("Exx_LRI_Interface", file_name_rank+" not found."); }
3842
cereal::BinaryInputArchive iar(ifs);
3943
iar(this->exx_ptr->Hexxs);
4044
ModuleBase::timer::tick("Exx_LRI_Interface", "read_Hexxs_cereal");
4145
}
46+
*/
4247

4348
template<typename T, typename Tdata>
4449
void Exx_LRI_Interface<T, Tdata>::init(const MPI_Comm &mpi_comm,
@@ -70,11 +75,11 @@ void Exx_LRI_Interface<T, Tdata>::cal_exx_elec(const std::vector<std::map<TA, st
7075
const ModuleSymmetry::Symmetry_rotation* p_symrot)
7176
{
7277
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_elec");
73-
if(!this->flag_finish.init || !this->flag_finish.ions)
74-
{
75-
throw std::runtime_error("Exx init unfinished when "
78+
if(!this->flag_finish.init || !this->flag_finish.ions)
79+
{
80+
throw std::runtime_error("Exx init unfinished when "
7681
+std::string(__FILE__)+" line "+std::to_string(__LINE__));
77-
}
82+
}
7883

7984
this->exx_ptr->cal_exx_elec(Ds, ucell, pv, p_symrot);
8085

@@ -85,15 +90,15 @@ template<typename T, typename Tdata>
8590
void Exx_LRI_Interface<T, Tdata>::cal_exx_force(const int& nat)
8691
{
8792
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_force");
88-
if(!this->flag_finish.init || !this->flag_finish.ions)
89-
{
90-
throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
91-
}
92-
if(!this->flag_finish.elec)
93-
{
94-
throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)
93+
if(!this->flag_finish.init || !this->flag_finish.ions)
94+
{
95+
throw std::runtime_error("Exx init unfinished when "+std::string(__FILE__)+" line "+std::to_string(__LINE__));
96+
}
97+
if(!this->flag_finish.elec)
98+
{
99+
throw std::runtime_error("Exx Hamiltonian unfinished when "+std::string(__FILE__)
95100
+" line "+std::to_string(__LINE__));
96-
}
101+
}
97102

98103
this->exx_ptr->cal_exx_force(nat);
99104

@@ -104,33 +109,36 @@ template<typename T, typename Tdata>
104109
void Exx_LRI_Interface<T, Tdata>::cal_exx_stress(const double& omega, const double& lat0)
105110
{
106111
ModuleBase::TITLE("Exx_LRI_Interface","cal_exx_stress");
107-
if(!this->flag_finish.init || !this->flag_finish.ions)
108-
{
109-
throw std::runtime_error("Exx init unfinished when "
110-
+std::string(__FILE__)+" line "+std::to_string(__LINE__));
111-
}
112-
if(!this->flag_finish.elec)
113-
{
114-
throw std::runtime_error("Exx Hamiltonian unfinished when "
115-
+std::string(__FILE__)+" line "+std::to_string(__LINE__));
116-
}
112+
if(!this->flag_finish.init || !this->flag_finish.ions)
113+
{
114+
throw std::runtime_error("Exx init unfinished when "
115+
+std::string(__FILE__)+" line "+std::to_string(__LINE__));
116+
}
117+
if(!this->flag_finish.elec)
118+
{
119+
throw std::runtime_error("Exx Hamiltonian unfinished when "
120+
+std::string(__FILE__)+" line "+std::to_string(__LINE__));
121+
}
117122

118123
this->exx_ptr->cal_exx_stress(omega, lat0);
119124

120125
this->flag_finish.stress = true;
121126
}
122127

123128
template<typename T, typename Tdata>
124-
void Exx_LRI_Interface<T, Tdata>::exx_before_all_runners(const K_Vectors& kv,
125-
const UnitCell& ucell, const Parallel_2D& pv)
129+
void Exx_LRI_Interface<T, Tdata>::exx_before_all_runners(
130+
const K_Vectors& kv,
131+
const UnitCell& ucell,
132+
const Parallel_2D& pv)
126133
{
127134
ModuleBase::TITLE("Exx_LRI_Interface","exx_before_all_runners");
128135
// initialize the rotation matrix in AO representation
129136
this->exx_spacegroup_symmetry = (PARAM.inp.nspin < 4 && ModuleSymmetry::Symmetry::symm_flag == 1);
130137
if (this->exx_spacegroup_symmetry)
131138
{
132139
const std::array<int, 3>& period = RI_Util::get_Born_vonKarmen_period(kv);
133-
this->symrot_.find_irreducible_sector(ucell.symm, ucell.atoms, ucell.st,
140+
this->symrot_.find_irreducible_sector(
141+
ucell.symm, ucell.atoms, ucell.st,
134142
RI_Util::get_Born_von_Karmen_cells(period), period, ucell.lat);
135143
// this->symrot_.set_Cs_rotation(this->exx_ptr->get_abfs_nchis());
136144
this->symrot_.cal_Ms(kv, ucell, pv);
@@ -233,18 +241,18 @@ void Exx_LRI_Interface<T, Tdata>::exx_eachiterinit(const int istep,
233241
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>
234242
Ds = PARAM.globalv.gamma_only_local
235243
? RI_2D_Comm::split_m2D_ktoR<Tdata>(
236-
ucell,
237-
*this->exx_ptr->p_kv,
238-
this->mix_DMk_2D.get_DMk_gamma_out(),
239-
*dm_in.get_paraV_pointer(),
240-
PARAM.inp.nspin)
244+
ucell,
245+
*this->exx_ptr->p_kv,
246+
this->mix_DMk_2D.get_DMk_gamma_out(),
247+
*dm_in.get_paraV_pointer(),
248+
PARAM.inp.nspin)
241249
: RI_2D_Comm::split_m2D_ktoR<Tdata>(
242-
ucell,
243-
*this->exx_ptr->p_kv,
244-
this->mix_DMk_2D.get_DMk_k_out(),
245-
*dm_in.get_paraV_pointer(),
246-
PARAM.inp.nspin,
247-
this->exx_spacegroup_symmetry);
250+
ucell,
251+
*this->exx_ptr->p_kv,
252+
this->mix_DMk_2D.get_DMk_k_out(),
253+
*dm_in.get_paraV_pointer(),
254+
PARAM.inp.nspin,
255+
this->exx_spacegroup_symmetry);
248256

249257
if (this->exx_spacegroup_symmetry && GlobalC::exx_info.info_global.exx_symmetry_realspace)
250258
{ this->cal_exx_elec(Ds, ucell,*dm_in.get_paraV_pointer(), &this->symrot_); }
@@ -274,11 +282,10 @@ void Exx_LRI_Interface<T, Tdata>::exx_hamilt2rho(elecstate::ElecState& elec, con
274282
{
275283
if (GlobalV::MY_RANK == 0)
276284
{
277-
try { GlobalC::restart.load_disk("Eexx", 0, 1, &this->exx_ptr->Eexx); }
285+
try
286+
{ GlobalC::restart.load_disk("Eexx", 0, 1, &this->exx_ptr->Eexx); }
278287
catch (const std::exception& e)
279-
{
280-
std::cout << "WARNING: Cannot read Eexx from disk, the energy of the 1st loop will be wrong, sbut it does not influence the subsequent loops." << std::endl;
281-
}
288+
{ std::cout << "WARNING: Cannot read Eexx from disk, the energy of the 1st loop will be wrong, sbut it does not influence the subsequent loops." << std::endl; }
282289
}
283290
Parallel_Common::bcast_double(this->exx_ptr->Eexx);
284291
this->exx_ptr->Eexx /= GlobalC::exx_info.info_global.hybrid_alpha;

0 commit comments

Comments
 (0)