Skip to content

Commit ab87e38

Browse files
authored
refactor stress_ewa (#6669)
1 parent 4e37d33 commit ab87e38

File tree

1 file changed

+14
-43
lines changed

1 file changed

+14
-43
lines changed

source/source_pw/module_pwdft/stress_func_ewa.cpp

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -64,32 +64,18 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
6464
}
6565
// else fact=1.0;
6666

67-
#ifdef _OPENMP
6867
#pragma omp parallel
6968
{
70-
int num_threads = omp_get_num_threads();
71-
int thread_id = omp_get_thread_num();
7269
ModuleBase::matrix local_sigma(3, 3);
73-
FPTYPE local_sdewald = 0.0;
74-
#else
75-
int num_threads = 1;
76-
int thread_id = 0;
77-
ModuleBase::matrix& local_sigma = sigma;
78-
FPTYPE& local_sdewald = sdewald;
79-
#endif
80-
81-
// Calculate ig range of this thread, avoid thread sync
82-
int ig=0;
83-
int ig_end=0;
84-
ModuleBase::TASK_DIST_1D(num_threads, thread_id, rho_basis->npw, ig, ig_end);
85-
ig_end = ig + ig_end;
70+
FPTYPE local_sdewald = 0;
8671

8772
FPTYPE g2,g2a;
8873
FPTYPE arg;
8974
std::complex<FPTYPE> rhostar;
9075
FPTYPE sewald;
9176

92-
for(; ig < ig_end; ig++)
77+
#pragma omp for
78+
for(int ig = 0; ig < rho_basis->npw; ig++)
9379
{
9480
if(ig == ig0)
9581
{
@@ -136,34 +122,28 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
136122

137123
if(ig0 >= 0)
138124
{
139-
r = new ModuleBase::Vector3<FPTYPE>[mxr];
140-
r2 = new FPTYPE[mxr];
141-
irr = new int[mxr];
125+
std::vector<ModuleBase::Vector3<FPTYPE>> r(mxr);
126+
std::vector<FPTYPE> r2(mxr);
127+
std::vector<int> irr(mxr);
142128

143129
FPTYPE sqa = sqrt(alpha);
144130
FPTYPE sq8a_2pi = sqrt(8 * alpha / (ModuleBase::TWO_PI));
145131
rmax = 4.0/sqa/ucell.lat0;
146132

147-
// collapse it, ia, jt, ja loop into a single loop
148-
long long ijat;
149-
long long ijat_end;
150-
int it=0;
151-
int i=0;
152-
int jt=0;
153-
int j=0;
154-
155-
ModuleBase::TASK_DIST_1D(num_threads, thread_id, (long long)ucell.nat * ucell.nat, ijat, ijat_end);
156-
ijat_end = ijat + ijat_end;
157-
ucell.ijat2iaitjajt(ijat, &i, &it, &j, &jt);
158-
159-
while (ijat < ijat_end)
133+
#pragma omp for
134+
for(long long ijat = 0; ijat < ucell.nat * ucell.nat; ijat++)
160135
{
136+
int it=0;
137+
int i=0;
138+
int jt=0;
139+
int j=0;
140+
ucell.ijat2iaitjajt(ijat, &i, &it, &j, &jt);
161141
if (ucell.atoms[it].na != 0 && ucell.atoms[jt].na != 0)
162142
{
163143
//calculate tau[na]-tau[nb]
164144
d_tau = ucell.atoms[it].tau[i] - ucell.atoms[jt].tau[j];
165145
//generates nearest-neighbors shells
166-
H_Ewald_pw::rgen(d_tau, rmax, irr, ucell.latvec, ucell.G, r, r2, nrm);
146+
H_Ewald_pw::rgen(d_tau, rmax, irr.data(), ucell.latvec, ucell.G, r.data(), r2.data(), nrm);
167147
for(int nr=0; nr<nrm; nr++)
168148
{
169149
rr=sqrt(r2[nr]) * ucell.lat0;
@@ -183,17 +163,9 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
183163
}//end l
184164
}//end nr
185165
}
186-
187-
++ijat;
188-
ucell.step_jajtiait(&j, &jt, &i, &it);
189166
}
190-
191-
delete[] r;
192-
delete[] r2;
193-
delete[] irr;
194167
}//end if
195168

196-
#ifdef _OPENMP
197169
#pragma omp critical(stress_ewa_reduce)
198170
{
199171
sdewald += local_sdewald;
@@ -206,7 +178,6 @@ void Stress_Func<FPTYPE, Device>::stress_ewa(const UnitCell& ucell,
206178
}
207179
}
208180
}
209-
#endif
210181

211182
for(int l=0;l<3;l++)
212183
{

0 commit comments

Comments
 (0)