diff --git a/source/module_hamilt_lcao/module_deepks/deepks_check.cpp b/source/module_hamilt_lcao/module_deepks/deepks_check.cpp index f8dc5c4269..f7f30d79c5 100644 --- a/source/module_hamilt_lcao/module_deepks/deepks_check.cpp +++ b/source/module_hamilt_lcao/module_deepks/deepks_check.cpp @@ -49,6 +49,7 @@ void DeePKS_domain::check_tensor(const torch::Tensor& tensor, const std::string& ofs.close(); } +template void DeePKS_domain::check_tensor(const torch::Tensor& tensor, const std::string& filename, const int rank); template void DeePKS_domain::check_tensor(const torch::Tensor& tensor, const std::string& filename, const int rank); template void DeePKS_domain::check_tensor>(const torch::Tensor& tensor, const std::string& filename, const int rank); diff --git a/source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp b/source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp index bb355837b8..d6e4cd4bc6 100644 --- a/source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp +++ b/source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp @@ -48,8 +48,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal, torch::Tensor v_delta_pdm = torch::zeros({nks, nlocal, nlocal, inlmax, (2 * lmaxd + 1), (2 * lmaxd + 1)}, torch::dtype(dtype)); - auto accessor - = v_delta_pdm.accessor::value, double, c10::complex>, 6>(); + auto accessor = v_delta_pdm.accessor(); DeePKS_domain::iterate_ad2( ucell, @@ -108,7 +107,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal, = (kvec_d[ik] * ModuleBase::Vector3(dR1 - dR2)) * ModuleBase::TWO_PI; kphase = std::complex(cos(arg), sin(arg)); } - TK_tensor* kpase_ptr = reinterpret_cast(&kphase); + TK* kpase_ptr = reinterpret_cast(&kphase); for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0) { for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0) @@ -119,9 +118,10 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal, { for (int m2 = 0; m2 < nm; ++m2) // nm = 1 for s, 3 for p, 5 for d { - TK_tensor tmp = overlap_1->get_value(iw1, ib + m1) + TK tmp = overlap_1->get_value(iw1, ib + m1) * overlap_2->get_value(iw2, ib + m2) * *kpase_ptr; - accessor[ik][iw1_all][iw2_all][inl][m1][m2] += tmp; + TK_tensor tmp_tensor = TK_tensor(tmp); + accessor[ik][iw1_all][iw2_all][inl][m1][m2] += tmp_tensor; } } ib += nm; @@ -193,8 +193,7 @@ void DeePKS_domain::prepare_phialpha(const int nlocal, int nlmax = inlmax / nat; int mmax = 2 * lmaxd + 1; phialpha_out = torch::zeros({nat, nlmax, nks, nlocal, mmax}, dtype); - auto accessor - = phialpha_out.accessor::value, double, c10::complex>, 5>(); + auto accessor = phialpha_out.accessor(); DeePKS_domain::iterate_ad1( ucell,