12
12
#include < dr/concepts/concepts.hpp>
13
13
#include < dr/detail/onedpl_direct_iterator.hpp>
14
14
#include < dr/shp/init.hpp>
15
+
16
+ #include < omp.h>
15
17
#include < sycl/sycl.hpp>
16
18
17
19
namespace dr ::shp {
@@ -151,7 +153,7 @@ void sort(R &&r, Compare comp = Compare()) {
151
153
auto &&local_segment = dr::shp::__detail::local (segment);
152
154
153
155
std::size_t *splitter_i = sycl::malloc_shared<std::size_t >(
154
- n_splitters, q.get_device (), shp::context ());
156
+ n_splitters + 1 , q.get_device (), shp::context ());
155
157
splitter_indices.push_back (splitter_i);
156
158
157
159
// Local copy `medians_l` necessary due to [GSD-3893]
@@ -166,6 +168,8 @@ void sort(R &&r, Compare comp = Compare()) {
166
168
167
169
sycl::free (medians_l, shp::context ());
168
170
171
+ splitter_i[n_splitters] = rng::size (local_segment);
172
+
169
173
auto p_first = rng::begin (local_segment);
170
174
auto p_last = p_first;
171
175
for (std::size_t i = 0 ; i < n_splitters; i++) {
@@ -235,20 +239,50 @@ void sort(R &&r, Compare comp = Compare()) {
235
239
dr::shp::__detail::wait (events);
236
240
events.clear ();
237
241
238
- // Sort each of these new segments
239
- for (std::size_t i = 0 ; i < sorted_segments.size (); i++) {
240
- auto &&local_policy =
241
- dr::shp::__detail::dpl_policy (dr::ranges::rank (segments[i]));
242
- T *seg = sorted_segments[i];
243
- std::size_t n_elements = sorted_seg_sizes[i];
242
+ // merge sorted chunks within each of these new segments
244
243
245
- auto e = __detail::sort_async (local_policy, seg, seg + n_elements, comp);
244
+ #pragma omp parallel num_threads(n_segments)
245
+ {
246
+ int t = omp_get_thread_num ();
246
247
247
- events. push_back (e) ;
248
- }
248
+ std::vector<std:: size_t > chunks_ind, chunks_ind2 ;
249
+ chunks_ind. push_back ( 0 );
249
250
250
- dr::shp::__detail::wait (events);
251
- events.clear ();
251
+ std::size_t v = 0 ;
252
+ for (std::size_t i = 0 ; i < n_segments; i++) {
253
+ v += (t == 0 ) ? splitter_indices[i][0 ]
254
+ : splitter_indices[i][t] - splitter_indices[i][t - 1 ];
255
+ chunks_ind.push_back (v);
256
+ }
257
+
258
+ auto _segments = n_segments;
259
+ while (_segments > 1 ) {
260
+ chunks_ind2.push_back (0 );
261
+
262
+ for (int s = 0 ; s < _segments / 2 ; s++) {
263
+
264
+ std::size_t l = (2 * s + 2 < _segments) ? chunks_ind[2 * s + 2 ]
265
+ : sorted_seg_sizes[t];
266
+
267
+ auto first = dr::__detail::direct_iterator (sorted_segments[t] +
268
+ chunks_ind[2 * s]);
269
+ auto middle = dr::__detail::direct_iterator (sorted_segments[t] +
270
+ chunks_ind[2 * s + 1 ]);
271
+ auto last = dr::__detail::direct_iterator (sorted_segments[t] + l);
272
+
273
+ chunks_ind2.push_back (l);
274
+
275
+ oneapi::dpl::inplace_merge (
276
+ __detail::dpl_policy (dr::ranges::rank (segments[t])), first, middle,
277
+ last, std::forward<Compare>(comp));
278
+ }
279
+
280
+ _segments = (_segments + 1 ) / 2 ;
281
+
282
+ std::swap (chunks_ind, chunks_ind2);
283
+ chunks_ind2.clear ();
284
+ }
285
+ } // End of omp parallel region
252
286
253
287
// Copy the results into the output.
254
288
0 commit comments