Commit 83aee57

Refactor: Migrate ML KEDF Descriptor Calculation & Output to module_io (#6287)
* Refactor: Move the code about calculation and output of ML KEDF descriptors to module_io
* Refactor: Remove ml_data.h, ml_data.cpp, and ml_data_descriptor.cpp
* Doc: Update the doc of of_ml_gene_data.
1 parent 42042b4 commit 83aee57

File tree: 16 files changed, +973 -1053 lines changed

16 files changed

+973
-1053
lines changed

docs/advanced/input_files/input-main.md

Lines changed: 23 additions & 2 deletions
@@ -2378,8 +2378,29 @@ Warning: this function is not robust enough for the current version. Please try
 ### of_ml_gene_data
 
 - **Type**: Boolean
-- **Availability**: OFDFT
-- **Description**: Generate training data or not.
+- **Availability**: Used only for KSDFT with plane wave basis
+- **Description**: Controls the generation of machine learning training data. When enabled, training data in `.npy` format will be saved in the directory `OUT.${suffix}/MLKEDF_Descriptors/`. The generated descriptors are categorized as follows:
+  - Local/Semilocal Descriptors. Files are named `{var}.npy`, where `{var}` corresponds to the descriptor type:
+    - `gamma`: Enabled by [of_ml_gamma](#of_ml_gamma)
+    - `p`: Enabled by [of_ml_p](#of_ml_p)
+    - `q`: Enabled by [of_ml_q](#of_ml_q)
+    - `tanhp`: Enabled by [of_ml_tanhp](#of_ml_tanhp)
+    - `tanhq`: Enabled by [of_ml_tanhq](#of_ml_tanhq)
+  - Nonlocal Descriptors, generated using kernels configured via [of_ml_nkernel](#of_ml_nkernel), [of_ml_kernel](#of_ml_kernel), and [of_ml_kernel_scaling](#of_ml_kernel_scaling). Files follow the naming convention `{var}_{kernel_type}_{kernel_scaling}.npy`, where `{kernel_type}` and `{kernel_scaling}` are specified by [of_ml_kernel](#of_ml_kernel) and [of_ml_kernel_scaling](#of_ml_kernel_scaling), respectively, and `{var}` denotes the descriptor type, including
+    - `gammanl`: Enabled by [of_ml_gammanl](#of_ml_gammanl)
+    - `pnl`: Enabled by [of_ml_pnl](#of_ml_pnl)
+    - `qnl`: Enabled by [of_ml_qnl](#of_ml_qnl)
+    - `xi`: Enabled by [of_ml_xi](#of_ml_xi)
+    - `tanhxi`: Enabled by [of_ml_tanhxi](#of_ml_tanhxi)
+    - `tanhxi_nl`: Enabled by [of_ml_tanhxi_nl](#of_ml_tanhxi_nl)
+    - `tanh_pnl`: Enabled by [of_ml_tanh_pnl](#of_ml_tanh_pnl)
+    - `tanh_qnl`: Enabled by [of_ml_tanh_qnl](#of_ml_tanh_qnl)
+    - `tanhp_nl`: Enabled by [of_ml_tanhp_nl](#of_ml_tanhp_nl)
+    - `tanhq_nl`: Enabled by [of_ml_tanhq_nl](#of_ml_tanhq_nl)
+  - Training Targets, including key quantum mechanical quantities:
+    - `enhancement.npy`: Pauli energy enhancement factor $T_\theta/T_{\rm{TF}}$, where $T_{\rm{TF}}$ is the Thomas-Fermi functional
+    - `pauli.npy`: Pauli potential $V_\theta$
+    - `veff.npy`: Effective potential
 - **Default**: False
 
 ### of_ml_device
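
For context on the documented output format: each descriptor is a plain 1-D `.npy` array over the real-space grid, written with the `npy.hpp` header used elsewhere in this commit. Below is a minimal sketch of how such a file could be read back in C++; the output path, descriptor name, and `main` wrapper are illustrative assumptions, not part of this commit.

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "npy.hpp" // same single-header .npy I/O used in kedf_ml.cpp

int main()
{
    // Illustrative path: a local/semilocal descriptor written when of_ml_gene_data is enabled.
    const std::string file = "OUT.ABACUS/MLKEDF_Descriptors/gamma.npy";

    std::vector<unsigned long> shape; // filled by the loader
    bool fortran_order = false;       // filled by the loader
    std::vector<double> gamma;        // descriptor values on the real-space grid

    npy::LoadArrayFromNumpy(file, shape, fortran_order, gamma);
    std::cout << "read " << gamma.size() << " grid points from " << file << std::endl;
    return 0;
}
```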

source/module_base/global_file.cpp

Lines changed: 41 additions & 0 deletions
@@ -153,6 +153,47 @@ void ModuleBase::Global_File::make_dir_out(
 #endif
 }
 
+    if(PARAM.inp.of_ml_gene_data == 1)
+    {
+        int make_dir_descrip = 0;
+        std::string command1 = "test -d " + PARAM.globalv.global_mlkedf_descriptor_dir + " || mkdir " + PARAM.globalv.global_mlkedf_descriptor_dir;
+
+        times = 0;
+        while(times<GlobalV::NPROC)
+        {
+            if(rank==times)
+            {
+                if ( system( command1.c_str() ) == 0 )
+                {
+                    std::cout << " MAKE THE MLKEDF DESCRIPTOR DIR : " << PARAM.globalv.global_mlkedf_descriptor_dir << std::endl;
+                    make_dir_descrip = 1;
+                }
+                else
+                {
+                    std::cout << " PROC " << rank << " CAN NOT MAKE THE MLKEDF DESCRIPTOR DIR !!! " << std::endl;
+                    make_dir_descrip = 0;
+                }
+            }
+#ifdef __MPI
+            Parallel_Reduce::reduce_all(make_dir_descrip);
+#endif
+            if(make_dir_descrip > 0)
+            {
+                break;
+            }
+            ++times;
+        }
+
+#ifdef __MPI
+        if(make_dir_descrip == 0)
+        {
+            std::cout << " CAN NOT MAKE THE MLKEDF DESCRIPTOR DIR......." << std::endl;
+            ModuleBase::QUIT();
+        }
+        MPI_Barrier(MPI_COMM_WORLD);
+#endif
+    }
+
 // mohan add 2010-09-12
 if(out_alllog)
 {
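
The block added above lets each MPI rank in turn try to create the `MLKEDF_Descriptors` directory, shares the outcome via a reduction, and stops as soon as one rank succeeds. A self-contained sketch of that pattern, assuming plain MPI calls, is shown below; `make_dir_on_some_rank` is an illustrative helper, not the ABACUS code itself.

```cpp
#include <mpi.h>

#include <cstdlib>
#include <string>

// Sketch: ranks take turns running `mkdir`; an all-reduce tells everyone
// whether the current rank succeeded, so the loop can stop early.
bool make_dir_on_some_rank(const std::string& dir, MPI_Comm comm)
{
    int rank = 0, nproc = 1;
    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &nproc);

    const std::string cmd = "test -d " + dir + " || mkdir " + dir;
    int ok = 0;
    for (int turn = 0; turn < nproc && ok == 0; ++turn)
    {
        int local_ok = 0;
        if (rank == turn)
        {
            local_ok = (std::system(cmd.c_str()) == 0) ? 1 : 0;
        }
        // Every rank learns whether this turn's rank created the directory.
        MPI_Allreduce(&local_ok, &ok, 1, MPI_INT, MPI_MAX, comm);
    }
    MPI_Barrier(comm); // keep ranks in step before anyone writes into the directory
    return ok != 0;
}
```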

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 26 additions & 25 deletions
@@ -37,7 +37,7 @@
 #include <iostream>
 
 #ifdef __MLALGO
-#include "module_hamilt_pw/hamilt_ofdft/ml_data.h"
+#include "module_io/write_mlkedf_descriptors.h"
 #endif
 
 #include <ATen/kernels/blas.h>
@@ -1003,30 +1003,31 @@ void ESolver_KS_PW<T, Device>::after_all_runners(UnitCell& ucell)
 {
     this->pelec->pot->update_from_charge(&this->chr, &ucell);
 
-    ML_data ml_data;
-    ml_data.set_para(this->chr.nrxx,
-                     PARAM.inp.nelec,
-                     PARAM.inp.of_tf_weight,
-                     PARAM.inp.of_vw_weight,
-                     PARAM.inp.of_ml_chi_p,
-                     PARAM.inp.of_ml_chi_q,
-                     PARAM.inp.of_ml_chi_xi,
-                     PARAM.inp.of_ml_chi_pnl,
-                     PARAM.inp.of_ml_chi_qnl,
-                     PARAM.inp.of_ml_nkernel,
-                     PARAM.inp.of_ml_kernel,
-                     PARAM.inp.of_ml_kernel_scaling,
-                     PARAM.inp.of_ml_yukawa_alpha,
-                     PARAM.inp.of_ml_kernel_file,
-                     ucell.omega,
-                     this->pw_rho);
-
-    ml_data.generateTrainData_KS(this->kspw_psi,
-                                 this->pelec,
-                                 this->pw_wfc,
-                                 this->pw_rho,
-                                 ucell,
-                                 this->pelec->pot->get_effective_v(0));
+    ModuleIO::Write_MLKEDF_Descriptors write_mlkedf_desc;
+    write_mlkedf_desc.cal_tool->set_para(this->chr.nrxx,
+                                         PARAM.inp.nelec,
+                                         PARAM.inp.of_tf_weight,
+                                         PARAM.inp.of_vw_weight,
+                                         PARAM.inp.of_ml_chi_p,
+                                         PARAM.inp.of_ml_chi_q,
+                                         PARAM.inp.of_ml_chi_xi,
+                                         PARAM.inp.of_ml_chi_pnl,
+                                         PARAM.inp.of_ml_chi_qnl,
+                                         PARAM.inp.of_ml_nkernel,
+                                         PARAM.inp.of_ml_kernel,
+                                         PARAM.inp.of_ml_kernel_scaling,
+                                         PARAM.inp.of_ml_yukawa_alpha,
+                                         PARAM.inp.of_ml_kernel_file,
+                                         ucell.omega,
+                                         this->pw_rho);
+
+    write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir,
+                                           this->kspw_psi,
+                                           this->pelec,
+                                           this->pw_wfc,
+                                           this->pw_rho,
+                                           ucell,
+                                           this->pelec->pot->get_effective_v(0));
 }
 #endif
 }

source/module_hamilt_pw/hamilt_ofdft/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -21,8 +21,6 @@ if(ENABLE_MLALGO)
     kedf_ml.cpp
     kedf_ml_pot.cpp
     kedf_ml_label.cpp
-    ml_data.cpp
-    ml_data_descriptor.cpp
     ml_tools/nn_of.cpp
   )
 
source/module_hamilt_pw/hamilt_ofdft/kedf_ml.cpp

Lines changed: 46 additions & 38 deletions
@@ -2,6 +2,7 @@
 
 #include "kedf_ml.h"
 
+#include "npy.hpp"
 #include "module_base/parallel_reduce.h"
 #include "module_base/global_function.h"
 #include "module_hamilt_pw/hamilt_pwdft/global.h"
@@ -100,15 +101,15 @@ void KEDF_ML::set_para(
 
     if (PARAM.inp.of_kinetic == "ml" || PARAM.inp.of_ml_gene_data == 1)
     {
-        this->ml_data = new ML_data;
+        this->cal_tool = new ModuleIO::Cal_MLKEDF_Descriptors;
 
         this->chi_p = chi_p;
         this->chi_q = chi_q;
         this->chi_xi = chi_xi;
         this->chi_pnl = chi_pnl;
         this->chi_qnl = chi_qnl;
 
-        this->ml_data->set_para(nx, nelec, tf_weight, vw_weight, chi_p, chi_q,
+        this->cal_tool->set_para(nx, nelec, tf_weight, vw_weight, chi_p, chi_q,
             chi_xi, chi_pnl, chi_qnl, nkernel, kernel_type, kernel_scaling, yukawa_alpha, kernel_file, this->dV * pw_rho->nxyz, pw_rho);
     }
 }
@@ -186,7 +187,7 @@ void KEDF_ML::ml_potential(const double * const * prho, ModulePW::PW_Basis *pw_r
 */
 void KEDF_ML::generateTrainData(const double * const *prho, KEDF_WT &wt, KEDF_TF &tf, ModulePW::PW_Basis *pw_rho, const double *veff)
 {
-    this->ml_data->generateTrainData_WT(prho, wt, tf, pw_rho, veff);
+    // this->cal_tool->generateTrainData_WT(prho, wt, tf, pw_rho, veff); // Will be fixed in next pr
     if (PARAM.inp.of_kinetic == "ml")
     {
         this->updateInput(prho, pw_rho);
@@ -203,9 +204,8 @@ void KEDF_ML::generateTrainData(const double * const *prho, KEDF_WT &wt, KEDF_TF
 
         this->get_potential_(prho, pw_rho, potential);
 
-        std::cout << "dumpdump\n";
-        this->dumpTensor(enhancement, "enhancement.npy");
-        this->dumpMatrix(potential, "potential.npy");
+        this->dumpTensor("enhancement.npy", enhancement);
+        this->dumpMatrix("potential.npy", potential);
     }
 }
 
@@ -222,7 +222,7 @@ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho)
     bool fortran_order = false;
 
     std::vector<double> temp_prho(this->nx);
-    this->ml_data->loadVector("dir_of_input_rho", temp_prho);
+    this->loadVector("dir_of_input_rho", temp_prho);
     double ** prho = new double *[1];
     prho[0] = new double[this->nx];
     for (int ir = 0; ir < this->nx; ++ir) prho[0][ir] = temp_prho[ir];
@@ -232,11 +232,9 @@ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho)
             std::cout << "WARNING: rho = 0" << std::endl;
         }
     };
-    std::cout << "Load rho done" << std::endl;
     // ==============================
 
     this->updateInput(prho, pw_rho);
-    std::cout << "update done" << std::endl;
 
     this->NN_forward(prho, pw_rho, true);
 
@@ -245,16 +243,13 @@ void KEDF_ML::localTest(const double * const *pprho, ModulePW::PW_Basis *pw_rho)
     torch::Tensor gradient_cpu_tensor = this->nn->inputs.grad().to(this->device_CPU).contiguous();
     this->gradient_cpu_ptr = gradient_cpu_tensor.data_ptr<double>();
 
-    std::cout << "enhancement done" << std::endl;
-
     torch::Tensor enhancement = this->nn->F.reshape({this->nx});
     ModuleBase::matrix potential(1, this->nx);
 
     this->get_potential_(prho, pw_rho, potential);
-    std::cout << "potential done" << std::endl;
 
-    this->dumpTensor(enhancement, "enhancement-abacus.npy");
-    this->dumpMatrix(potential, "potential-abacus.npy");
+    this->dumpTensor("enhancement-abacus.npy", enhancement);
+    this->dumpMatrix("potential-abacus.npy", potential);
     exit(0);
 }
 
@@ -267,20 +262,20 @@ void KEDF_ML::set_device(std::string device_inpt)
 {
     if (device_inpt == "cpu")
     {
-        std::cout << "---------- Running NN on CPU ----------" << std::endl;
+        std::cout << "------------------- Running NN on CPU -------------------" << std::endl;
         this->device_type = torch::kCPU;
     }
     else if (device_inpt == "gpu")
     {
         if (torch::cuda::cudnn_is_available())
         {
-            std::cout << "---------- Running NN on GPU ----------" << std::endl;
+            std::cout << "------------------- Running NN on GPU -------------------" << std::endl;
             this->device_type = torch::kCUDA;
         }
         else
         {
-            std::cout << "------ Warning: GPU is unaviable ------" << std::endl;
-            std::cout << "---------- Running NN on CPU ----------" << std::endl;
+            std::cout << "--------------- Warning: GPU is unaviable ---------------" << std::endl;
+            std::cout << "------------------- Running NN on CPU -------------------" << std::endl;
             this->device_type = torch::kCPU;
         }
     }
@@ -331,19 +326,32 @@ void KEDF_ML::NN_forward(const double * const * prho, ModulePW::PW_Basis *pw_rho
     }
 }
 
+void KEDF_ML::loadVector(std::string filename, std::vector<double> &data)
+{
+    std::vector<long unsigned int> cshape = {(long unsigned) this->cal_tool->nx};
+    bool fortran_order = false;
+    npy::LoadArrayFromNumpy(filename, cshape, fortran_order, data);
+}
+
+void KEDF_ML::dumpVector(std::string filename, const std::vector<double> &data)
+{
+    const long unsigned cshape[] = {(long unsigned) this->cal_tool->nx}; // shape
+    npy::SaveArrayAsNumpy(filename, false, 1, cshape, data);
+}
+
 /**
  * @brief Dump the torch::Tensor into .npy file
  *
  * @param data torch::Tensor
  * @param filename file name
  */
-void KEDF_ML::dumpTensor(const torch::Tensor &data, std::string filename)
+void KEDF_ML::dumpTensor(std::string filename, const torch::Tensor &data)
 {
     std::cout << "Dumping " << filename << std::endl;
     torch::Tensor data_cpu = data.to(this->device_CPU).contiguous();
     std::vector<double> v(data_cpu.data_ptr<double>(), data_cpu.data_ptr<double>() + data_cpu.numel());
     // for (int ir = 0; ir < this->nx; ++ir) assert(v[ir] == data[ir].item<double>());
-    this->ml_data->dumpVector(filename, v);
+    this->dumpVector(filename, v);
 }
 
 /**
@@ -352,12 +360,12 @@ void KEDF_ML::dumpTensor(const torch::Tensor &data, std::string filename)
  * @param data matrix
  * @param filename file name
  */
-void KEDF_ML::dumpMatrix(const ModuleBase::matrix &data, std::string filename)
+void KEDF_ML::dumpMatrix(std::string filename, const ModuleBase::matrix &data)
 {
     std::cout << "Dumping " << filename << std::endl;
     std::vector<double> v(data.c, data.c + this->nx);
     // for (int ir = 0; ir < this->nx; ++ir) assert(v[ir] == data[ir].item<double>());
-    this->ml_data->dumpVector(filename, v);
+    this->dumpVector(filename, v);
 }
 
 /**
@@ -372,57 +380,57 @@ void KEDF_ML::updateInput(const double * const * prho, ModulePW::PW_Basis *pw_rh
     // std::cout << "updata_input" << std::endl;
     if (this->gene_data_label["gamma"][0])
     {
-        this->ml_data->getGamma(prho, this->gamma);
+        this->cal_tool->getGamma(prho, this->gamma);
     }
     if (this->gene_data_label["p"][0])
     {
-        this->ml_data->getNablaRho(prho, pw_rho, this->nablaRho);
-        this->ml_data->getP(prho, pw_rho, this->nablaRho, this->p);
+        this->cal_tool->getNablaRho(prho, pw_rho, this->nablaRho);
+        this->cal_tool->getP(prho, pw_rho, this->nablaRho, this->p);
     }
     if (this->gene_data_label["q"][0])
     {
-        this->ml_data->getQ(prho, pw_rho, this->q);
+        this->cal_tool->getQ(prho, pw_rho, this->q);
     }
     if (this->gene_data_label["tanhp"][0])
    {
-        this->ml_data->getTanhP(this->p, this->tanhp);
+        this->cal_tool->getTanhP(this->p, this->tanhp);
     }
     if (this->gene_data_label["tanhq"][0])
     {
-        this->ml_data->getTanhQ(this->q, this->tanhq);
+        this->cal_tool->getTanhQ(this->q, this->tanhq);
     }
 
     for (int ik = 0; ik < nkernel; ++ik)
     {
         if (this->gene_data_label["gammanl"][ik]){
-            this->ml_data->getGammanl(ik, this->gamma, pw_rho, this->gammanl[ik]);
+            this->cal_tool->getGammanl(ik, this->gamma, pw_rho, this->gammanl[ik]);
         }
         if (this->gene_data_label["pnl"][ik]){
-            this->ml_data->getPnl(ik, this->p, pw_rho, this->pnl[ik]);
+            this->cal_tool->getPnl(ik, this->p, pw_rho, this->pnl[ik]);
         }
         if (this->gene_data_label["qnl"][ik]){
-            this->ml_data->getQnl(ik, this->q, pw_rho, this->qnl[ik]);
+            this->cal_tool->getQnl(ik, this->q, pw_rho, this->qnl[ik]);
         }
         if (this->gene_data_label["xi"][ik]){
-            this->ml_data->getXi(this->gamma, this->gammanl[ik], this->xi[ik]);
+            this->cal_tool->getXi(this->gamma, this->gammanl[ik], this->xi[ik]);
         }
         if (this->gene_data_label["tanhxi"][ik]){
-            this->ml_data->getTanhXi(ik, this->gamma, this->gammanl[ik], this->tanhxi[ik]);
+            this->cal_tool->getTanhXi(ik, this->gamma, this->gammanl[ik], this->tanhxi[ik]);
         }
         if (this->gene_data_label["tanhxi_nl"][ik]){
-            this->ml_data->getTanhXi_nl(ik, this->tanhxi[ik], pw_rho, this->tanhxi_nl[ik]);
+            this->cal_tool->getTanhXi_nl(ik, this->tanhxi[ik], pw_rho, this->tanhxi_nl[ik]);
         }
         if (this->gene_data_label["tanh_pnl"][ik]){
-            this->ml_data->getTanh_Pnl(ik, this->pnl[ik], this->tanh_pnl[ik]);
+            this->cal_tool->getTanh_Pnl(ik, this->pnl[ik], this->tanh_pnl[ik]);
         }
         if (this->gene_data_label["tanh_qnl"][ik]){
-            this->ml_data->getTanh_Qnl(ik, this->qnl[ik], this->tanh_qnl[ik]);
+            this->cal_tool->getTanh_Qnl(ik, this->qnl[ik], this->tanh_qnl[ik]);
         }
         if (this->gene_data_label["tanhp_nl"][ik]){
-            this->ml_data->getTanhP_nl(ik, this->tanhp, pw_rho, this->tanhp_nl[ik]);
+            this->cal_tool->getTanhP_nl(ik, this->tanhp, pw_rho, this->tanhp_nl[ik]);
         }
         if (this->gene_data_label["tanhq_nl"][ik]){
-            this->ml_data->getTanhQ_nl(ik, this->tanhq, pw_rho, this->tanhq_nl[ik]);
+            this->cal_tool->getTanhQ_nl(ik, this->tanhq, pw_rho, this->tanhq_nl[ik]);
         }
     }
     ModuleBase::timer::tick("KEDF_ML", "updateInput");
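
The refactored `dumpTensor`/`dumpMatrix` above now route everything through the new `dumpVector`, which writes a 1-D `.npy` array via `npy.hpp`. The sketch below strings together the same libtorch and `npy.hpp` calls that appear in the diff as a standalone helper; the free function name and output file are illustrative assumptions.

```cpp
#include <string>
#include <vector>

#include <torch/torch.h>
#include "npy.hpp"

// Sketch: copy a (possibly GPU) double tensor to host memory and save it as a 1-D .npy file.
void save_tensor_as_npy(const torch::Tensor& data, const std::string& filename)
{
    // Bring the data to CPU and make it contiguous so data_ptr() walks it linearly.
    torch::Tensor cpu = data.to(torch::kCPU).contiguous();
    std::vector<double> v(cpu.data_ptr<double>(), cpu.data_ptr<double>() + cpu.numel());

    // One-dimensional shape: the number of real-space grid points.
    const unsigned long shape[] = {static_cast<unsigned long>(v.size())};
    npy::SaveArrayAsNumpy(filename, false, 1, shape, v); // false: C (row-major) order
}
```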
