Skip to content

Commit 40c46e9

Browse files
Merge replacing 2nd sort in mhp::sort() (#731)
* merge instead of second sort in mhp --------- Co-authored-by: Łukasz Ślusarczyk <lukasz.slusarczyk@intel.com>
1 parent 368cd3d commit 40c46e9

File tree

2 files changed

+49
-15
lines changed

2 files changed

+49
-15
lines changed

include/dr/mhp/algorithms/sort.hpp

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,44 @@ template <typename R, typename Compare> void local_sort(R &r, Compare &&comp) {
136136
}
137137
}
138138

139+
template <typename T, typename Compare>
140+
void local_merge(buffer<T> &v, std::vector<std::size_t> chunks,
141+
Compare &&comp) {
142+
143+
std::exclusive_scan(chunks.begin(), chunks.end(), chunks.begin(), 0);
144+
145+
while (chunks.size() > 1) {
146+
std::size_t segno = chunks.size();
147+
std::vector<std::size_t> next_chunks;
148+
for (std::size_t i = 0; i < segno / 2; i++) {
149+
auto first = v.begin() + chunks[2 * i];
150+
auto middle = v.begin() + chunks[2 * i + 1];
151+
auto last = (2 * i + 2 < segno) ? v.begin() + chunks[2 * i + 2] : v.end();
152+
if (mhp::use_sycl()) {
153+
#ifdef SYCL_LANGUAGE_VERSION
154+
auto dfirst = dr::__detail::direct_iterator(first);
155+
auto dmiddle = dr::__detail::direct_iterator(middle);
156+
auto dlast = dr::__detail::direct_iterator(last);
157+
oneapi::dpl::inplace_merge(dpl_policy(), dfirst, dmiddle, dlast, comp);
158+
#else
159+
assert(false);
160+
#endif
161+
} else {
162+
std::inplace_merge(first, middle, last, comp);
163+
}
164+
next_chunks.push_back(chunks[2 * i]);
165+
}
166+
if (segno % 2 == 1) {
167+
next_chunks.push_back(chunks[segno - 1]);
168+
}
169+
std::swap(chunks, next_chunks);
170+
}
171+
}
172+
139173
template <typename Compare>
140-
void _find_split_idx(std::size_t &vidx, std::size_t &segidx, Compare &&comp,
141-
auto &ls, auto &vec_v, auto &vec_i, auto &vec_s) {
174+
void _find_split_idx(std::size_t &vidx, Compare &&comp, auto &ls, auto &vec_v,
175+
auto &vec_i, auto &vec_s) {
176+
std::size_t segidx = 0;
142177
while (vidx < default_comm().size() && segidx < rng::size(ls)) {
143178
if (comp(vec_v[vidx - 1], ls[segidx])) {
144179
vec_i[vidx] = segidx;
@@ -205,7 +240,7 @@ void splitters(Seg &lsegment, Compare &&comp,
205240
vec_split_v[_i] = vec_gmedians[global_median_idx];
206241
}
207242

208-
std::size_t segidx = 0, vidx = 1;
243+
std::size_t vidx = 1;
209244

210245
/* The while loop is executed in host memory, and together with
211246
* sycl_copy takes most of the execution time of the sort procedure */
@@ -215,13 +250,13 @@ void splitters(Seg &lsegment, Compare &&comp,
215250
sycl_copy(rng::data(lsegment), rng::data(vec_lseg_tmp),
216251
rng::size(lsegment));
217252

218-
_find_split_idx(vidx, segidx, comp, vec_lseg_tmp, vec_split_v, vec_split_i,
253+
_find_split_idx(vidx, comp, vec_lseg_tmp, vec_split_v, vec_split_i,
219254
vec_split_s);
220255
#else
221256
assert(false);
222257
#endif
223258
} else {
224-
_find_split_idx(vidx, segidx, comp, lsegment, vec_split_v, vec_split_i,
259+
_find_split_idx(vidx, comp, lsegment, vec_split_v, vec_split_i,
225260
vec_split_s);
226261
}
227262

@@ -392,9 +427,8 @@ void dist_sort(R &r, Compare &&comp) {
392427
default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata,
393428
vec_rsizes, vec_rindices);
394429

395-
/* TODO: vec recvdata is partially sorted, implementation of merge on GPU is
396-
* desirable */
397-
__detail::local_sort(vec_recvdata, comp);
430+
__detail::local_merge(vec_recvdata, vec_rsizes, comp);
431+
398432
// MPI_Wait(&req_recvelems, MPI_STATUS_IGNORE);
399433

400434
_total_elems = std::reduce(vec_recv_elems.begin(), vec_recv_elems.end());

test/gtest/mhp/CMakeLists.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ add_executable(
5656

5757
add_executable(mhp-quick-test
5858
mhp-tests.cpp
59-
halo.cpp
59+
mhpsort.cpp
60+
../common/sort.cpp
6061
)
6162
# cmake-format: on
6263

@@ -104,12 +105,11 @@ if(ENABLE_SYCL)
104105
${sycl-exclusions}Halo3/*:Sort*:Counted/*:Mdspan*:Mdarray*:)
105106
endif()
106107

107-
add_mhp_ctest(NAME mhp-quick-test NPROC 1 SYCL)
108-
add_mhp_ctest(NAME mhp-quick-test NPROC 2 SYCL)
109-
add_mhp_ctest(
110-
NAME mhp-quick-test NPROC 1 OFFLOAD SYCL TARGS --device-memory)
111-
add_mhp_ctest(
112-
NAME mhp-quick-test NPROC 2 OFFLOAD SYCL TARGS --device-memory)
108+
foreach(nproc RANGE 1 4)
109+
add_mhp_ctest(NAME mhp-quick-test NPROC ${nproc} SYCL)
110+
add_mhp_ctest(
111+
NAME mhp-quick-test NPROC ${nproc} OFFLOAD SYCL TARGS --device-memory)
112+
endforeach()
113113

114114
add_mhp_ctest(
115115
NAME mhp-tests NPROC 2 TIMEOUT 150 OFFLOAD SYCL TARGS --device-memory

0 commit comments

Comments
 (0)