Skip to content

Commit 2b1e662

Browse files
authored
Add Mt_Gemm for the nonlocal_pw (#6253)
* change globalv * add dsp for the nonlocal_pw * modify basis name
1 parent 262be9e commit 2b1e662

File tree

4 files changed

+29
-3
lines changed

4 files changed

+29
-3
lines changed

source/module_hamilt_pw/hamilt_pwdft/VNL_in_pw.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,13 @@ void pseudopot_cell_vnl::init(const UnitCell& ucell,
271271
resmem_sh_op()(s_tab, this->tab.getSize());
272272
resmem_ch_op()(c_vkb, nkb * npwx);
273273
}
274+
#ifdef __DSP
275+
base_device::memory::resize_memory_op_mt<std::complex<double>, base_device::DEVICE_CPU>()
276+
(this->z_vkb, this->vkb.size, "Nonlocal<PW>::ps");
277+
memcpy(this->z_vkb,this->vkb.c,this->vkb.size*16);
278+
#else
274279
this->z_vkb = this->vkb.c;
280+
#endif
275281
this->d_tab = this->tab.ptr;
276282
// There's no need to delete double precision pointers while in a CPU environment.
277283
}

source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,12 @@ void Nonlocal<OperatorPW<T, Device>>::add_nonlocal_pp(T *hpsi_in, const T *becp,
185185
int npm = m;
186186
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
187187
// denghui replace 2022-10-20
188-
gemm_op()(
188+
#ifdef __DSP
189+
ModuleBase::gemm_op_mt<T, Device>()
190+
#else
191+
gemm_op()
192+
#endif
193+
(
189194
transa,
190195
transb,
191196
this->npw,
@@ -259,7 +264,12 @@ void Nonlocal<OperatorPW<T, Device>>::act(
259264
int npm = nbands;
260265
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
261266
// denghui replace 2022-10-20
262-
gemm_op()(
267+
#ifdef __DSP
268+
ModuleBase::gemm_op_mt<T, Device>()
269+
#else
270+
gemm_op()
271+
#endif
272+
(
263273
transa,
264274
transb,
265275
nkb,

source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,13 @@ class Nonlocal<OperatorPW<T, Device>> : public OperatorPW<T, Device>
8989
using gemm_op = ModuleBase::gemm_op<T, Device>;
9090
using nonlocal_op = nonlocal_pw_op<Real, Device>;
9191
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;
92+
#ifdef __DSP
93+
using resmem_complex_op = base_device::memory::resize_memory_op_mt<T, Device>;
94+
using delmem_complex_op = base_device::memory::delete_memory_op_mt<T, Device>;
95+
#else
9296
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
9397
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
98+
#endif
9499
using syncmem_complex_h2d_op = base_device::memory::synchronize_memory_op<T, Device, base_device::DEVICE_CPU>;
95100

96101
T one{1, 0};

source/module_io/read_input_item_system.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,12 @@ void ReadInput::item_system()
812812
const std::string warningstr = nofound_str(avail_list, "precision");
813813
ModuleBase::WARNING_QUIT("ReadInput", warningstr);
814814
}
815-
815+
if (para.inp.precision == "single" && para.inp.basis_type == "lcao")
816+
{
817+
ModuleBase::WARNING_QUIT(
818+
"ReadInput",
819+
"Single precision is not supported for NAO basis,\nPlease use double precision for NAO basis.\n");
820+
}
816821
// cpu single precision is not supported while float_fftw lib is not available
817822
if (para.inp.device == "cpu" && para.inp.precision == "single")
818823
{

0 commit comments

Comments
 (0)