Skip to content

Commit b5dadfc

Browse files
authored
merge instead of 2nd sort in shp::sort (#740)
1 parent 2ecbcbc commit b5dadfc

File tree

1 file changed

+46
-12
lines changed

1 file changed

+46
-12
lines changed

include/dr/shp/algorithms/sort.hpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <dr/concepts/concepts.hpp>
1313
#include <dr/detail/onedpl_direct_iterator.hpp>
1414
#include <dr/shp/init.hpp>
15+
16+
#include <omp.h>
1517
#include <sycl/sycl.hpp>
1618

1719
namespace dr::shp {
@@ -151,7 +153,7 @@ void sort(R &&r, Compare comp = Compare()) {
151153
auto &&local_segment = dr::shp::__detail::local(segment);
152154

153155
std::size_t *splitter_i = sycl::malloc_shared<std::size_t>(
154-
n_splitters, q.get_device(), shp::context());
156+
n_splitters + 1, q.get_device(), shp::context());
155157
splitter_indices.push_back(splitter_i);
156158

157159
// Local copy `medians_l` necessary due to [GSD-3893]
@@ -166,6 +168,8 @@ void sort(R &&r, Compare comp = Compare()) {
166168

167169
sycl::free(medians_l, shp::context());
168170

171+
splitter_i[n_splitters] = rng::size(local_segment);
172+
169173
auto p_first = rng::begin(local_segment);
170174
auto p_last = p_first;
171175
for (std::size_t i = 0; i < n_splitters; i++) {
@@ -235,20 +239,50 @@ void sort(R &&r, Compare comp = Compare()) {
235239
dr::shp::__detail::wait(events);
236240
events.clear();
237241

238-
// Sort each of these new segments
239-
for (std::size_t i = 0; i < sorted_segments.size(); i++) {
240-
auto &&local_policy =
241-
dr::shp::__detail::dpl_policy(dr::ranges::rank(segments[i]));
242-
T *seg = sorted_segments[i];
243-
std::size_t n_elements = sorted_seg_sizes[i];
242+
// merge sorted chunks within each of these new segments
244243

245-
auto e = __detail::sort_async(local_policy, seg, seg + n_elements, comp);
244+
#pragma omp parallel num_threads(n_segments)
245+
{
246+
int t = omp_get_thread_num();
246247

247-
events.push_back(e);
248-
}
248+
std::vector<std::size_t> chunks_ind, chunks_ind2;
249+
chunks_ind.push_back(0);
249250

250-
dr::shp::__detail::wait(events);
251-
events.clear();
251+
std::size_t v = 0;
252+
for (std::size_t i = 0; i < n_segments; i++) {
253+
v += (t == 0) ? splitter_indices[i][0]
254+
: splitter_indices[i][t] - splitter_indices[i][t - 1];
255+
chunks_ind.push_back(v);
256+
}
257+
258+
auto _segments = n_segments;
259+
while (_segments > 1) {
260+
chunks_ind2.push_back(0);
261+
262+
for (int s = 0; s < _segments / 2; s++) {
263+
264+
std::size_t l = (2 * s + 2 < _segments) ? chunks_ind[2 * s + 2]
265+
: sorted_seg_sizes[t];
266+
267+
auto first = dr::__detail::direct_iterator(sorted_segments[t] +
268+
chunks_ind[2 * s]);
269+
auto middle = dr::__detail::direct_iterator(sorted_segments[t] +
270+
chunks_ind[2 * s + 1]);
271+
auto last = dr::__detail::direct_iterator(sorted_segments[t] + l);
272+
273+
chunks_ind2.push_back(l);
274+
275+
oneapi::dpl::inplace_merge(
276+
__detail::dpl_policy(dr::ranges::rank(segments[t])), first, middle,
277+
last, std::forward<Compare>(comp));
278+
}
279+
280+
_segments = (_segments + 1) / 2;
281+
282+
std::swap(chunks_ind, chunks_ind2);
283+
chunks_ind2.clear();
284+
}
285+
} // End of omp parallel region
252286

253287
// Copy the results into the output.
254288

0 commit comments

Comments
 (0)