Skip to content

Commit 3cad424

Browse files
authored
Refactor: Combine some checking functions in DeePKS. (#6285)
* Combine some checking functions in DeePKS. * Remove check_f_delta(). * Remove check_o_delta(). * Update vdpre_ref.dat * Fix a bug. * Fix a merge bug. * Reduce the memory cost for DeePKS GO-UT.
1 parent 83aee57 commit 3cad424

File tree

122 files changed

+33268
-20427
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+33268
-20427
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ OBJS_CELL=atom_pseudo.o\
201201

202202
OBJS_DEEPKS=LCAO_deepks.o\
203203
deepks_basic.o\
204+
deepks_check.o\
204205
deepks_descriptor.o\
205206
deepks_force.o\
206207
deepks_fpre.o\

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,14 @@ void Force_LCAO<double>::ftable(const bool isforce,
272272
#ifdef __MLALGO
273273
if (PARAM.inp.deepks_scf && PARAM.inp.deepks_out_unittest)
274274
{
275-
DeePKS_domain::check_f_delta(ucell.nat, fvnl_dalpha, svnl_dalpha);
275+
std::ofstream ofs_f("F_delta.dat");
276+
std::ofstream ofs_s("stress_delta.dat");
277+
ofs_f << std::setprecision(10);
278+
ofs_s << std::setprecision(10);
279+
fvnl_dalpha.print(ofs_f);
280+
ofs_f.close();
281+
svnl_dalpha.print(ofs_s);
282+
ofs_s.close();
276283
}
277284
#endif
278285

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,14 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
312312
#ifdef __MLALGO
313313
if (PARAM.inp.deepks_scf && PARAM.inp.deepks_out_unittest)
314314
{
315-
DeePKS_domain::check_f_delta(ucell.nat, fvnl_dalpha, svnl_dalpha);
315+
std::ofstream ofs_f("F_delta.dat");
316+
std::ofstream ofs_s("stress_delta.dat");
317+
ofs_f << std::setprecision(10);
318+
ofs_s << std::setprecision(10);
319+
fvnl_dalpha.print(ofs_f);
320+
ofs_f.close();
321+
svnl_dalpha.print(ofs_s);
322+
ofs_s.close();
316323
}
317324
#endif
318325

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ if(ENABLE_MLALGO)
22
list(APPEND objects
33
LCAO_deepks.cpp
44
deepks_basic.cpp
5+
deepks_check.cpp
56
deepks_descriptor.cpp
67
deepks_force.cpp
78
deepks_fpre.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#ifdef __MLALGO
55

66
#include "deepks_basic.h"
7+
#include "deepks_check.h"
78
#include "deepks_descriptor.h"
89
#include "deepks_force.h"
910
#include "deepks_fpre.h"

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
175175

176176
if (PARAM.inp.deepks_out_unittest)
177177
{
178-
DeePKS_domain::check_gdmx(gdmx);
179-
DeePKS_domain::check_gvx(gvx, rank);
178+
DeePKS_domain::check_tensor<double>(gdmx, "gdmx.dat", rank);
179+
DeePKS_domain::check_tensor<double>(gvx, "gvx.dat", rank);
180180
}
181181
}
182182
}
@@ -198,8 +198,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
198198

199199
if (PARAM.inp.deepks_out_unittest)
200200
{
201-
DeePKS_domain::check_gdmepsl(gdmepsl);
202-
DeePKS_domain::check_gvepsl(gvepsl, rank);
201+
DeePKS_domain::check_tensor<double>(gdmepsl, "gdmepsl.dat", rank);
202+
DeePKS_domain::check_tensor<double>(gvepsl, "gvepsl.dat", rank);
203203
}
204204
}
205205
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifdef __MLALGO
2+
3+
#include "deepks_check.h"
4+
5+
template <typename T>
6+
void DeePKS_domain::check_tensor(const torch::Tensor& tensor, const std::string& filename, const int rank)
7+
{
8+
if (rank != 0)
9+
{
10+
return;
11+
}
12+
using T_tensor = typename std::conditional<std::is_same<T, std::complex<double>>::value, c10::complex<double>, T>::type;
13+
14+
std::ofstream ofs(filename.c_str());
15+
ofs << std::setprecision(10);
16+
17+
auto sizes = tensor.sizes();
18+
int ndim = sizes.size();
19+
auto data_ptr = tensor.data_ptr<T_tensor>();
20+
int64_t numel = tensor.numel();
21+
22+
// stride for each dimension
23+
std::vector<int64_t> strides(ndim, 1);
24+
for (int i = ndim - 2; i >= 0; --i) {
25+
strides[i] = strides[i + 1] * sizes[i + 1];
26+
}
27+
28+
for (int64_t idx = 0; idx < numel; ++idx) {
29+
// index to multi-dimensional indices
30+
std::vector<int64_t> indices(ndim);
31+
int64_t tmp = idx;
32+
for (int d = 0; d < ndim; ++d) {
33+
indices[d] = tmp / strides[d];
34+
tmp = tmp % strides[d];
35+
}
36+
37+
T_tensor tmp_val = data_ptr[idx];
38+
T* tmp_ptr = reinterpret_cast<T*>(&tmp_val);
39+
ofs << *tmp_ptr;
40+
41+
// print space or newline
42+
if ( ( (idx+1) % sizes[ndim-1] ) == 0 ) {
43+
ofs << std::endl;
44+
} else {
45+
ofs << " ";
46+
}
47+
}
48+
49+
ofs.close();
50+
}
51+
52+
template void DeePKS_domain::check_tensor<double>(const torch::Tensor& tensor, const std::string& filename, const int rank);
53+
template void DeePKS_domain::check_tensor<std::complex<double>>(const torch::Tensor& tensor, const std::string& filename, const int rank);
54+
55+
#endif
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#ifndef DEEPKS_CHECK_H
2+
#define DEEPKS_CHECK_H
3+
4+
#ifdef __MLALGO
5+
6+
#include <string>
7+
#include <torch/script.h>
8+
#include <torch/torch.h>
9+
10+
namespace DeePKS_domain
11+
{
12+
//------------------------
13+
// deepks_check.cpp
14+
//------------------------
15+
16+
// This file contains subroutines for checking files
17+
18+
// There are 1 subroutines in this file:
19+
// 1. check_tensor, which is used for tensor data checking
20+
21+
template <typename T>
22+
void check_tensor(const torch::Tensor& tensor, const std::string& filename, const int rank);
23+
24+
} // namespace DeePKS_domain
25+
26+
#endif
27+
#endif

source/module_hamilt_lcao/module_deepks/deepks_force.cpp

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -263,32 +263,6 @@ void DeePKS_domain::cal_f_delta(const hamilt::HContainer<double>* dmr,
263263
return;
264264
}
265265

266-
// prints forces and stress from DeePKS (LCAO)
267-
void DeePKS_domain::check_f_delta(const int nat, ModuleBase::matrix& f_delta, ModuleBase::matrix& svnl_dalpha)
268-
{
269-
ModuleBase::TITLE("DeePKS_domain", "check_F_delta");
270-
271-
std::ofstream ofs("F_delta.dat");
272-
ofs << std::setprecision(10);
273-
274-
for (int iat = 0; iat < nat; iat++)
275-
{
276-
ofs << f_delta(iat, 0) << " " << f_delta(iat, 1) << " " << f_delta(iat, 2) << std::endl;
277-
}
278-
279-
std::ofstream ofs1("stress_delta.dat");
280-
ofs1 << std::setprecision(10);
281-
for (int ipol = 0; ipol < 3; ipol++)
282-
{
283-
for (int jpol = 0; jpol < 3; jpol++)
284-
{
285-
ofs1 << svnl_dalpha(ipol, jpol) << " ";
286-
}
287-
ofs1 << std::endl;
288-
}
289-
return;
290-
}
291-
292266
template void DeePKS_domain::cal_f_delta<double>(const hamilt::HContainer<double>* dmr,
293267
const UnitCell& ucell,
294268
const LCAO_Orbitals& orb,

source/module_hamilt_lcao/module_deepks/deepks_force.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ namespace DeePKS_domain
2121
// This file contains subroutines for calculating F_delta,
2222
// which is defind as sum_mu,nu rho_mu,nu d/dX (<chi_mu|alpha>V(D)<alpha|chi_nu>)
2323

24-
// There are 2 subroutines in this file:
24+
// There are 1 subroutine in this file:
2525
// 1. cal_f_delta, which is used for F_delta calculation
26-
// 2. check_f_delta, which prints F_delta into F_delta.dat for checking
2726

2827
template <typename TK>
2928
void cal_f_delta(const hamilt::HContainer<double>* dmr,
@@ -39,8 +38,6 @@ void cal_f_delta(const hamilt::HContainer<double>* dmr,
3938
ModuleBase::matrix& f_delta,
4039
const bool isstress,
4140
ModuleBase::matrix& svnl_dalpha);
42-
43-
void check_f_delta(const int nat, ModuleBase::matrix& f_delta, ModuleBase::matrix& svnl_dalpha);
4441
} // namespace DeePKS_domain
4542

4643
#endif

0 commit comments

Comments
 (0)