@@ -136,9 +136,44 @@ template <typename R, typename Compare> void local_sort(R &r, Compare &&comp) {
136
136
}
137
137
}
138
138
139
+ template <typename T, typename Compare>
140
+ void local_merge (buffer<T> &v, std::vector<std::size_t > chunks,
141
+ Compare &&comp) {
142
+
143
+ std::exclusive_scan (chunks.begin (), chunks.end (), chunks.begin (), 0 );
144
+
145
+ while (chunks.size () > 1 ) {
146
+ std::size_t segno = chunks.size ();
147
+ std::vector<std::size_t > next_chunks;
148
+ for (std::size_t i = 0 ; i < segno / 2 ; i++) {
149
+ auto first = v.begin () + chunks[2 * i];
150
+ auto middle = v.begin () + chunks[2 * i + 1 ];
151
+ auto last = (2 * i + 2 < segno) ? v.begin () + chunks[2 * i + 2 ] : v.end ();
152
+ if (mhp::use_sycl ()) {
153
+ #ifdef SYCL_LANGUAGE_VERSION
154
+ auto dfirst = dr::__detail::direct_iterator (first);
155
+ auto dmiddle = dr::__detail::direct_iterator (middle);
156
+ auto dlast = dr::__detail::direct_iterator (last);
157
+ oneapi::dpl::inplace_merge (dpl_policy (), dfirst, dmiddle, dlast, comp);
158
+ #else
159
+ assert (false );
160
+ #endif
161
+ } else {
162
+ std::inplace_merge (first, middle, last, comp);
163
+ }
164
+ next_chunks.push_back (chunks[2 * i]);
165
+ }
166
+ if (segno % 2 == 1 ) {
167
+ next_chunks.push_back (chunks[segno - 1 ]);
168
+ }
169
+ std::swap (chunks, next_chunks);
170
+ }
171
+ }
172
+
139
173
template <typename Compare>
140
- void _find_split_idx (std::size_t &vidx, std::size_t &segidx, Compare &&comp,
141
- auto &ls, auto &vec_v, auto &vec_i, auto &vec_s) {
174
+ void _find_split_idx (std::size_t &vidx, Compare &&comp, auto &ls, auto &vec_v,
175
+ auto &vec_i, auto &vec_s) {
176
+ std::size_t segidx = 0 ;
142
177
while (vidx < default_comm ().size () && segidx < rng::size (ls)) {
143
178
if (comp (vec_v[vidx - 1 ], ls[segidx])) {
144
179
vec_i[vidx] = segidx;
@@ -205,7 +240,7 @@ void splitters(Seg &lsegment, Compare &&comp,
205
240
vec_split_v[_i] = vec_gmedians[global_median_idx];
206
241
}
207
242
208
- std::size_t segidx = 0 , vidx = 1 ;
243
+ std::size_t vidx = 1 ;
209
244
210
245
/* The while loop is executed in host memory, and together with
211
246
* sycl_copy takes most of the execution time of the sort procedure */
@@ -215,13 +250,13 @@ void splitters(Seg &lsegment, Compare &&comp,
215
250
sycl_copy (rng::data (lsegment), rng::data (vec_lseg_tmp),
216
251
rng::size (lsegment));
217
252
218
- _find_split_idx (vidx, segidx, comp, vec_lseg_tmp, vec_split_v, vec_split_i,
253
+ _find_split_idx (vidx, comp, vec_lseg_tmp, vec_split_v, vec_split_i,
219
254
vec_split_s);
220
255
#else
221
256
assert (false );
222
257
#endif
223
258
} else {
224
- _find_split_idx (vidx, segidx, comp, lsegment, vec_split_v, vec_split_i,
259
+ _find_split_idx (vidx, comp, lsegment, vec_split_v, vec_split_i,
225
260
vec_split_s);
226
261
}
227
262
@@ -392,9 +427,8 @@ void dist_sort(R &r, Compare &&comp) {
392
427
default_comm ().alltoallv (lsegment, vec_split_s, vec_split_i, vec_recvdata,
393
428
vec_rsizes, vec_rindices);
394
429
395
- /* TODO: vec recvdata is partially sorted, implementation of merge on GPU is
396
- * desirable */
397
- __detail::local_sort (vec_recvdata, comp);
430
+ __detail::local_merge (vec_recvdata, vec_rsizes, comp);
431
+
398
432
// MPI_Wait(&req_recvelems, MPI_STATUS_IGNORE);
399
433
400
434
_total_elems = std::reduce (vec_recv_elems.begin (), vec_recv_elems.end ());
0 commit comments