Skip to content

Commit 9d78d90

Browse files
linpeizePeizeLin
andauthored
add check and update code format in exx (#6244) (#6255)
* fix bug and update code format in exx * Fix bug in Exx_LRI_Interface. Change && to || * update exx in ESolver_KS_LCAO and FORCE_STRESS * update runtime check in Exx_LRI_Interface * move exx_lri_double from ESolver_KS_LCAO to Exx_LRI_Interface --------- Conflicts: source/module_ri/Exx_LRI_interface.hpp Co-authored-by: linpz <linpz@mail.ustc.edu.cn>
1 parent 9f3e6e7 commit 9d78d90

File tree

12 files changed

+345
-287
lines changed

12 files changed

+345
-287
lines changed

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,11 @@ ESolver_KS_LCAO<TK, TR>::ESolver_KS_LCAO()
8383
// because some members like two_level_step are used outside if(cal_exx)
8484
if (GlobalC::exx_info.info_ri.real_number)
8585
{
86-
this->exx_lri_double = std::make_shared<Exx_LRI<double>>(GlobalC::exx_info.info_ri);
87-
this->exd = std::make_shared<Exx_LRI_Interface<TK, double>>(exx_lri_double);
86+
this->exd = std::make_shared<Exx_LRI_Interface<TK, double>>(GlobalC::exx_info.info_ri);
8887
}
8988
else
9089
{
91-
this->exx_lri_complex = std::make_shared<Exx_LRI<std::complex<double>>>(GlobalC::exx_info.info_ri);
92-
this->exc = std::make_shared<Exx_LRI_Interface<TK, std::complex<double>>>(exx_lri_complex);
90+
this->exc = std::make_shared<Exx_LRI_Interface<TK, std::complex<double>>>(GlobalC::exx_info.info_ri);
9391
}
9492
#endif
9593
}
@@ -198,12 +196,12 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
198196
// initialize 2-center radial tables for EXX-LRI
199197
if (GlobalC::exx_info.info_ri.real_number)
200198
{
201-
this->exx_lri_double->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
199+
this->exd->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
202200
this->exd->exx_before_all_runners(this->kv, ucell, this->pv);
203201
}
204202
else
205203
{
206-
this->exx_lri_complex->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
204+
this->exc->init(MPI_COMM_WORLD, ucell, this->kv, orb_);
207205
this->exc->exx_before_all_runners(this->kv, ucell, this->pv);
208206
}
209207
}
@@ -351,8 +349,8 @@ void ESolver_KS_LCAO<TK, TR>::cal_force(UnitCell& ucell, ModuleBase::matrix& for
351349
this->ld,
352350
#endif
353351
#ifdef __EXX
354-
*this->exx_lri_double,
355-
*this->exx_lri_complex,
352+
*this->exd,
353+
*this->exc,
356354
#endif
357355
&ucell.symm);
358356

@@ -461,8 +459,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
461459
this->gd
462460
#ifdef __EXX
463461
,
464-
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
465-
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
462+
this->exd ? &this->exd->get_Hexxs() : nullptr,
463+
this->exc ? &this->exc->get_Hexxs() : nullptr
466464
#endif
467465
);
468466
}
@@ -484,8 +482,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
484482
this->gd
485483
#ifdef __EXX
486484
,
487-
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
488-
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
485+
this->exd ? &this->exd->get_Hexxs() : nullptr,
486+
this->exc ? &this->exc->get_Hexxs() : nullptr
489487
#endif
490488
);
491489
}
@@ -514,8 +512,8 @@ void ESolver_KS_LCAO<TK, TR>::after_all_runners(UnitCell& ucell)
514512
this->two_center_bundle_
515513
#ifdef __EXX
516514
,
517-
this->exx_lri_double ? &this->exx_lri_double->Hexxs : nullptr,
518-
this->exx_lri_complex ? &this->exx_lri_complex->Hexxs : nullptr
515+
this->exd ? &this->exd->get_Hexxs() : nullptr,
516+
this->exc ? &this->exc->get_Hexxs() : nullptr
519517
#endif
520518
);
521519
}

source/module_esolver/esolver_ks_lcao.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
116116
#ifdef __EXX
117117
std::shared_ptr<Exx_LRI_Interface<TK, double>> exd = nullptr;
118118
std::shared_ptr<Exx_LRI_Interface<TK, std::complex<double>>> exc = nullptr;
119-
std::shared_ptr<Exx_LRI<double>> exx_lri_double = nullptr;
120-
std::shared_ptr<Exx_LRI<std::complex<double>>> exx_lri_complex = nullptr;
121119
#endif
122120

123121
friend class LR::ESolver_LR<double, double>;

source/module_esolver/lcao_before_scf.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
159159
,
160160
istep,
161161
GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step,
162-
GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr,
163-
GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs
162+
GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr,
163+
GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs()
164164
#endif
165165
);
166166
}

source/module_esolver/lcao_others.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
217217
,
218218
istep,
219219
GlobalC::exx_info.info_ri.real_number ? &this->exd->two_level_step : &this->exc->two_level_step,
220-
GlobalC::exx_info.info_ri.real_number ? &exx_lri_double->Hexxs : nullptr,
221-
GlobalC::exx_info.info_ri.real_number ? nullptr : &exx_lri_complex->Hexxs
220+
GlobalC::exx_info.info_ri.real_number ? &this->exd->get_Hexxs() : nullptr,
221+
GlobalC::exx_info.info_ri.real_number ? nullptr : &this->exc->get_Hexxs()
222222
#endif
223223
);
224224
}

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
5555
LCAO_Deepks<T>& ld,
5656
#endif
5757
#ifdef __EXX
58-
Exx_LRI<double>& exx_lri_double,
59-
Exx_LRI<std::complex<double>>& exx_lri_complex,
58+
Exx_LRI_Interface<T, double>& exd,
59+
Exx_LRI_Interface<T, std::complex<double>>& exc,
6060
#endif
6161
ModuleSymmetry::Symmetry* symm)
6262
{
@@ -377,26 +377,26 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
377377
{
378378
if (GlobalC::exx_info.info_ri.real_number)
379379
{
380-
exx_lri_double.cal_exx_force(ucell.nat);
381-
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.force_exx;
380+
exd.cal_exx_force(ucell.nat);
381+
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_force();
382382
}
383383
else
384384
{
385-
exx_lri_complex.cal_exx_force(ucell.nat);
386-
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.force_exx;
385+
exc.cal_exx_force(ucell.nat);
386+
force_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_force();
387387
}
388388
}
389389
if (isstress)
390390
{
391391
if (GlobalC::exx_info.info_ri.real_number)
392392
{
393-
exx_lri_double.cal_exx_stress(ucell.omega, ucell.lat0);
394-
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_double.stress_exx;
393+
exd.cal_exx_stress(ucell.omega, ucell.lat0);
394+
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exd.get_stress();
395395
}
396396
else
397397
{
398-
exx_lri_complex.cal_exx_stress(ucell.omega, ucell.lat0);
399-
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exx_lri_complex.stress_exx;
398+
exc.cal_exx_stress(ucell.omega, ucell.lat0);
399+
stress_exx = GlobalC::exx_info.info_global.hybrid_alpha * exc.get_stress();
400400
}
401401
}
402402
}

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "module_io/input_conv.h"
1212
#include "module_psi/psi.h"
1313
#ifdef __EXX
14-
#include "module_ri/Exx_LRI.h"
14+
#include "module_ri/Exx_LRI_interface.h"
1515
#endif
1616
#include "force_stress_arrays.h"
1717
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
@@ -53,8 +53,8 @@ class Force_Stress_LCAO
5353
LCAO_Deepks<T>& ld,
5454
#endif
5555
#ifdef __EXX
56-
Exx_LRI<double>& exx_lri_double,
57-
Exx_LRI<std::complex<double>>& exx_lri_complex,
56+
Exx_LRI_Interface<T, double>& exd,
57+
Exx_LRI_Interface<T, std::complex<double>>& exc,
5858
#endif
5959
ModuleSymmetry::Symmetry* symm);
6060

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,10 +272,10 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
272272
{
273273
// if the same kernel is calculated in the esolver_ks, move it
274274
std::string dft_functional = LR_Util::tolower(input.dft_functional);
275-
if (ks_sol.exx_lri_double && std::is_same<T, double>::value && xc_kernel == dft_functional) {
276-
this->move_exx_lri(ks_sol.exx_lri_double);
277-
} else if (ks_sol.exx_lri_complex && std::is_same<T, std::complex<double>>::value && xc_kernel == dft_functional) {
278-
this->move_exx_lri(ks_sol.exx_lri_complex);
275+
if (ks_sol.exd && std::is_same<T, double>::value && xc_kernel == dft_functional) {
276+
this->move_exx_lri(ks_sol.exd->exx_ptr);
277+
} else if (ks_sol.exc && std::is_same<T, std::complex<double>>::value && xc_kernel == dft_functional) {
278+
this->move_exx_lri(ks_sol.exc->exx_ptr);
279279
} else // construct C, V from scratch
280280
{
281281
// set ccp_type according to the xc_kernel

source/module_ri/Exx_LRI.h

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,21 @@
2121
#include "module_exx_symmetry/symmetry_rotation.h"
2222

2323
class Parallel_Orbitals;
24-
24+
2525
template<typename T, typename Tdata>
2626
class RPA_LRI;
2727

2828
template<typename T, typename Tdata>
2929
class Exx_LRI_Interface;
3030

31-
namespace LR
32-
{
33-
template<typename T, typename TR>
34-
class ESolver_LR;
31+
namespace LR
32+
{
33+
template<typename T, typename TR>
34+
class ESolver_LR;
3535

36-
template<typename T>
37-
class OperatorLREXX;
38-
}
36+
template<typename T>
37+
class OperatorLREXX;
38+
}
3939

4040
template<typename Tdata>
4141
class Exx_LRI
@@ -49,37 +49,39 @@ class Exx_LRI
4949
using TatomR = std::array<double,Ndim>; // tmp
5050

5151
public:
52-
Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {}
53-
Exx_LRI operator=(const Exx_LRI&) = delete;
54-
Exx_LRI operator=(Exx_LRI&&);
55-
56-
void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
57-
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
58-
59-
void init(const MPI_Comm &mpi_comm_in,
60-
const UnitCell &ucell,
61-
const K_Vectors &kv_in,
62-
const LCAO_Orbitals& orb);
63-
void cal_exx_force(const int& nat);
64-
void cal_exx_stress(const double& omega, const double& lat0);
52+
Exx_LRI(const Exx_Info::Exx_Info_RI& info_in) :info(info_in) {}
53+
Exx_LRI operator=(const Exx_LRI&) = delete;
54+
Exx_LRI operator=(Exx_LRI&&);
55+
56+
void init(
57+
const MPI_Comm &mpi_comm_in,
58+
const UnitCell &ucell,
59+
const K_Vectors &kv_in,
60+
const LCAO_Orbitals& orb);
6561
void cal_exx_ions(const UnitCell& ucell, const bool write_cv = false);
66-
void cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
62+
void cal_exx_elec(
63+
const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
6764
const UnitCell& ucell,
68-
const Parallel_Orbitals& pv,
69-
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
70-
std::vector<std::vector<int>> get_abfs_nchis() const;
65+
const Parallel_Orbitals& pv,
66+
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
67+
void cal_exx_force(const int& nat);
68+
void cal_exx_stress(const double& omega, const double& lat0);
69+
70+
void reset_Cs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Cs_in) { this->exx_lri.set_Cs(Cs_in, this->info.C_threshold); }
71+
void reset_Vs(const std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>& Vs_in) { this->exx_lri.set_Vs(Vs_in, this->info.V_threshold); }
72+
//std::vector<std::vector<int>> get_abfs_nchis() const;
7173

7274
std::vector< std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>> Hexxs;
73-
double Eexx;
75+
double Eexx;
7476
ModuleBase::matrix force_exx;
7577
ModuleBase::matrix stress_exx;
76-
78+
7779

7880
private:
7981
const Exx_Info::Exx_Info_RI &info;
8082
MPI_Comm mpi_comm;
8183
const K_Vectors *p_kv = nullptr;
82-
std::vector<double> orb_cutoff_;
84+
std::vector<double> orb_cutoff_;
8385

8486
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> lcaos;
8587
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> abfs;
@@ -89,16 +91,16 @@ class Exx_LRI
8991
RI::Exx<TA,Tcell,Ndim,Tdata> exx_lri;
9092

9193
void post_process_Hexx( std::map<TA, std::map<TAC, RI::Tensor<Tdata>>> &Hexxs_io ) const;
92-
double post_process_Eexx(const double& Eexx_in) const;
94+
double post_process_Eexx(const double& Eexx_in) const;
9395

9496
friend class RPA_LRI<double, Tdata>;
9597
friend class RPA_LRI<std::complex<double>, Tdata>;
9698
friend class Exx_LRI_Interface<double, Tdata>;
9799
friend class Exx_LRI_Interface<std::complex<double>, Tdata>;
98-
friend class LR::ESolver_LR<double, double>;
99-
friend class LR::ESolver_LR<std::complex<double>, double>;
100-
friend class LR::OperatorLREXX<double>;
101-
friend class LR::OperatorLREXX<std::complex<double>>;
100+
friend class LR::ESolver_LR<double, double>;
101+
friend class LR::ESolver_LR<std::complex<double>, double>;
102+
friend class LR::OperatorLREXX<double>;
103+
friend class LR::OperatorLREXX<std::complex<double>>;
102104
};
103105

104106
#include "Exx_LRI.hpp"

source/module_ri/Exx_LRI.hpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
#include <string>
2727

2828
template<typename Tdata>
29-
void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in,
29+
void Exx_LRI<Tdata>::init(const MPI_Comm &mpi_comm_in,
3030
const UnitCell &ucell,
31-
const K_Vectors &kv_in,
31+
const K_Vectors &kv_in,
3232
const LCAO_Orbitals& orb)
3333
{
3434
ModuleBase::TITLE("Exx_LRI","init");
@@ -130,7 +130,7 @@ void Exx_LRI<Tdata>::cal_exx_ions(const UnitCell& ucell,
130130
this->exx_lri.set_parallel(this->mpi_comm, atoms_pos, latvec, period);
131131

132132
// std::max(3) for gamma_only, list_A2 should contain cell {-1,0,1}. In the future distribute will be neighbour.
133-
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_);
133+
const std::array<Tcell,Ndim> period_Vs = LRI_CV_Tools::cal_latvec_range<Tcell>(1+this->info.ccp_rmesh_times, ucell, orb_cutoff_);
134134
const std::pair<std::vector<TA>, std::vector<std::vector<std::pair<TA,std::array<Tcell,Ndim>>>>>
135135
list_As_Vs = RI::Distribute_Equally::distribute_atoms_periods(this->mpi_comm, atoms, period_Vs, 2, false);
136136

@@ -237,7 +237,7 @@ void Exx_LRI<Tdata>::cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, R
237237
}
238238
this->Eexx = post_process_Eexx(this->Eexx);
239239
this->exx_lri.set_symmetry(false, {});
240-
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");
240+
ModuleBase::timer::tick("Exx_LRI", "cal_exx_elec");
241241
}
242242

243243
template<typename Tdata>
@@ -283,11 +283,6 @@ void Exx_LRI<Tdata>::cal_exx_force(const int& nat)
283283
ModuleBase::TITLE("Exx_LRI","cal_exx_force");
284284
ModuleBase::timer::tick("Exx_LRI", "cal_exx_force");
285285

286-
if (!this->exx_lri.flag_finish.D)
287-
{
288-
ModuleBase::WARNING_QUIT("Force_Stress_LCAO", "Cannot calculate EXX force when the first PBE loop is not converged.");
289-
}
290-
291286
this->force_exx.create(nat, Ndim);
292287
for(int is=0; is<PARAM.inp.nspin; ++is)
293288
{
@@ -328,6 +323,7 @@ void Exx_LRI<Tdata>::cal_exx_stress(const double& omega, const double& lat0)
328323
ModuleBase::timer::tick("Exx_LRI", "cal_exx_stress");
329324
}
330325

326+
/*
331327
template<typename Tdata>
332328
std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
333329
{
@@ -341,5 +337,6 @@ std::vector<std::vector<int>> Exx_LRI<Tdata>::get_abfs_nchis() const
341337
}
342338
return abfs_nchis;
343339
}
340+
*/
344341

345342
#endif

0 commit comments

Comments
 (0)