Skip to content

Commit 99990ed

Browse files
authored
[SYCLomatic] Fix the issue that __half type not migrated during the migration of BlockRadixSort API (#2797)
Signed-off-by: intwanghao <hao3.wang@intel.com>
1 parent e9b108d commit 99990ed

File tree

4 files changed: 85 additions and 12 deletions.

clang/lib/DPCT/RulesLangLib/CUBAPIMigration.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -285,11 +285,10 @@ void CubMemberCallRule::runRule(
285285
CanTy->getAs<RecordType>()->getDecl());
286286
const auto &ValueTyArg = ClassSpecDecl->getTemplateArgs()[0];
287287

288-
ValueTyArg.getAsType().getAsString();
289288
std::string Fn;
290289
llvm::raw_string_ostream OS(Fn);
291290
OS << MapNames::getDpctNamespace() << "group::" << HelpFuncName << "<"
292-
<< ValueTyArg.getAsType().getAsString();
291+
<< DpctGlobalInfo::getReplacedTypeName(ValueTyArg.getAsType());
293292
if (isBlockShuffle) {
294293
if (!ClassSpecDecl->getTemplateArgs()[1].getIsDefaulted()) {
295294
OS << ", " << ClassSpecDecl->getTemplateArgs()[1].getAsIntegral();

clang/runtime/dpct-rt/include/dpct/detail/group_utils_detail.hpp

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,9 +36,9 @@ template <int RADIX_BITS, bool DESCENDING = false> class radix_rank {
3636

3737
radix_rank(uint8_t *local_memory) : _local_memory(local_memory) {}
3838

39-
template <typename Item, int VALUES_PER_THREAD>
39+
template <typename Item, typename KT, int VALUES_PER_THREAD>
4040
__dpct_inline__ void
41-
rank_keys(const Item &item, uint32_t (&keys)[VALUES_PER_THREAD],
41+
rank_keys(const Item &item, KT (&keys)[VALUES_PER_THREAD],
4242
int (&ranks)[VALUES_PER_THREAD], int current_bit, int num_bits) {
4343

4444
digit_counter_type thread_prefixes[VALUES_PER_THREAD];
@@ -204,10 +204,23 @@ template <typename U> struct base_traits<float, U> {
204204
}
205205
};
206206

207+
// base_traits specialization for sycl::half keys. U is the unsigned integer
// type that carries the key's raw bits. twiddle_in and twiddle_out are
// inverse bit transforms of each other.
// NOTE(review): presumably the transform makes the bit patterns order like
// the half values under unsigned comparison -- confirm against the sort
// implementation, which is outside this view.
template <typename U> struct base_traits<sycl::half, U> {
208+
// Mask with only the most significant bit of U set.
static constexpr U HIGH_BIT = U(1) << ((sizeof(U) * 8) - 1);
209+
// High bit set: flip every bit. High bit clear: flip only the high bit.
static __dpct_inline__ U twiddle_in(U key) {
210+
U mask = (key & HIGH_BIT) ? U(-1) : HIGH_BIT;
211+
return key ^ mask;
212+
}
213+
// Undoes twiddle_in. High bit set: flip only the high bit. High bit clear:
// flip every bit.
static __dpct_inline__ U twiddle_out(U key) {
214+
U mask = (key & HIGH_BIT) ? HIGH_BIT : U(-1);
215+
return key ^ mask;
216+
}
217+
};
218+
207219
// traits<T> picks the base_traits instantiation for key type T. The primary
// template uses T itself as the second base_traits argument; int and float
// use uint32_t, and sycl::half uses uint16_t.
template <typename T> struct traits : base_traits<T, T> {};
208220
template <> struct traits<uint32_t> : base_traits<uint32_t, uint32_t> {};
209221
template <> struct traits<int> : base_traits<int, uint32_t> {};
210222
template <> struct traits<float> : base_traits<float, uint32_t> {};
223+
template <> struct traits<sycl::half> : base_traits<sycl::half, uint16_t> {};
211224

212225
template <int N> struct power_of_two {
213226
enum { VALUE = ((N & (N - 1)) == 0) };

clang/runtime/dpct-rt/include/dpct/group_utils.hpp

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -365,9 +365,11 @@ class group_radix_sort {
365365
helper_sort(const Item &item, T (&keys)[ElementsPerWorkItem],
366366
int begin_bit = 0, int end_bit = 8 * sizeof(T),
367367
bool is_striped = false) {
368-
369-
uint32_t(&unsigned_keys)[ElementsPerWorkItem] =
370-
reinterpret_cast<uint32_t(&)[ElementsPerWorkItem]>(keys);
368+
using UnsignedT =
369+
typename std::conditional<std::is_same<T, sycl::half>::value, uint16_t,
370+
uint32_t>::type;
371+
UnsignedT(&unsigned_keys)[ElementsPerWorkItem] =
372+
reinterpret_cast<UnsignedT(&)[ElementsPerWorkItem]>(keys);
371373

372374
#pragma unroll
373375
for (int i = 0; i < ElementsPerWorkItem; ++i) {
@@ -379,8 +381,8 @@ class group_radix_sort {
379381

380382
int ranks[ElementsPerWorkItem];
381383
detail::radix_rank<RADIX_BITS, DESCENDING>(_local_memory)
382-
.template rank_keys<Item, ElementsPerWorkItem>(item, unsigned_keys,
383-
ranks, i, pass_bits);
384+
.template rank_keys<Item, UnsignedT, ElementsPerWorkItem>(
385+
item, unsigned_keys, ranks, i, pass_bits);
384386

385387
sycl::group_barrier(item.get_group());
386388

clang/test/dpct/cub/blocklevel/blockradixsort.cu

Lines changed: 62 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,28 @@ __global__ void Sort(int *data) {
3030
BlockStore(temp_storage_store).Store(data, thread_keys);
3131
}
3232

33+
// Test kernel for __half keys: each thread loads its thread_keys[4] array
// from data with the CUB block load, sorts the keys with the CUB block radix
// sort, and writes them back to data with the CUB block store.
// The CHECK / CHECK-NEXT / CHECK-NOT comments inside the body are the test's
// expectations for the migrated code; do not insert lines between them.
__global__ void SortHalf(__half *data) {
34+
// CHECK: using BlockRadixSort = dpct::group::group_radix_sort<sycl::half, 4>;
35+
// CHECK-NEXT: using BlockLoad = dpct::group::group_load<sycl::half, 4>;
36+
// CHECK-NEXT: using BlockStore = dpct::group::group_store<sycl::half, 4>;
37+
// CHECK-NOT: __shared__ typename BlockLoad::TempStorage temp_storage_load;
38+
// CHECK-NOT: __shared__ typename BlockStore::TempStorage temp_storage_store;
39+
// CHECK-NOT: __shared__ typename BlockRadixSort::TempStorage temp_storage;
40+
using BlockRadixSort = cub::BlockRadixSort<__half, 128, 4>;
41+
using BlockLoad = cub::BlockLoad<__half, 128, 4>;
42+
using BlockStore = cub::BlockStore<__half, 128, 4>;
43+
__shared__ typename BlockLoad::TempStorage temp_storage_load;
44+
__shared__ typename BlockStore::TempStorage temp_storage_store;
45+
__shared__ typename BlockRadixSort::TempStorage temp_storage;
46+
__half thread_keys[4];
47+
// CHECK: BlockLoad(temp_storage_load).load(item_ct1, data, thread_keys);
48+
// CHECK-NEXT: BlockRadixSort(temp_storage).sort(item_ct1, thread_keys);
49+
// CHECK-NEXT: BlockStore(temp_storage_store).store(item_ct1, data, thread_keys);
50+
BlockLoad(temp_storage_load).Load(data, thread_keys);
51+
BlockRadixSort(temp_storage).Sort(thread_keys);
52+
BlockStore(temp_storage_store).Store(data, thread_keys);
53+
}
54+
3355
__global__ void SortDescending(int *data) {
3456
// CHECK: using BlockRadixSort = dpct::group::group_radix_sort<int, 4>;
3557
// CHECK-NEXT: using BlockLoad = dpct::group::group_load<int, 4, dpct::group::group_load_algorithm::blocked>;
@@ -171,8 +193,9 @@ __global__ void test_unsupported(int *data) {
171193

172194
template <typename T, int N>
173195
void print_array(T (&arr)[N]) {
174-
for (int i = 0; i < N; ++i)
175-
printf("%d%c", arr[i], (i == N - 1 ? '\n' : ','));
196+
for (int i = 0; i < N; ++i) {
197+
std::cout << (int)arr[i] << (i == N - 1 ? '\n' : ',');
198+
}
176199
}
177200

178201
bool test_sort() {
@@ -211,6 +234,42 @@ bool test_sort() {
211234
return true;
212235
}
213236

237+
// Host-side test for the __half sort kernel: fills 512 __half values, runs
// SortHalf on one block of 128 threads, copies the result back and checks
// that element i equals i. Prints a pass/fail message and returns true on
// success, false on the first mismatch.
// NOTE(review): the return codes of the CUDA runtime calls are not checked.
bool test_sorthalf() {
238+
__half data[512] = {0}, *d_data = nullptr;
239+
cudaMalloc(&d_data, sizeof(__half) * 512);
240+
// Each group of four holds two ascending values (x counts up from 0) and
// two descending values (y counts down from 511), so after 128 iterations
// the array holds every integer 0..511 once, in unsorted order.
for (int i = 0, x = 0, y = 511; i < 128; ++i) {
241+
data[i * 4 + 0] = x++;
242+
data[i * 4 + 1] = y--;
243+
data[i * 4 + 2] = x++;
244+
data[i * 4 + 3] = y--;
245+
}
246+
cudaMemcpy(d_data, data, sizeof(data), cudaMemcpyHostToDevice);
247+
// CHECK: q_ct1.submit(
248+
// CHECK-NEXT: [&](sycl::handler &cgh) {
249+
// CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_load_acc(dpct::group::group_load<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
250+
// CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_store_acc(dpct::group::group_store<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
251+
// CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_acc(dpct::group::group_radix_sort<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
252+
// CHECK-EMPTY:
253+
// CHECK-NEXT: cgh.parallel_for(
254+
// CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
255+
// CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
256+
// CHECK-NEXT: SortHalf(d_data, &temp_storage_load_acc[0], &temp_storage_store_acc[0], &temp_storage_acc[0]);
257+
// CHECK-NEXT: });
258+
// CHECK-NEXT: });
259+
SortHalf<<<1, 128>>>(d_data);
260+
cudaDeviceSynchronize();
261+
cudaMemcpy(data, d_data, sizeof(data), cudaMemcpyDeviceToHost);
262+
cudaFree(d_data);
263+
// The input was a permutation of 0..511, so a correct ascending sort leaves
// data[i] == i; the comparison is done after converting each value to int.
for (int i = 0; i < 512; ++i)
264+
if ((int)data[i] != i) {
265+
printf("test_sorthalf failed\n");
266+
print_array(data);
267+
return false;
268+
}
269+
printf("test_sorthalf pass\n");
270+
return true;
271+
}
272+
214273
bool test_sort_descending() {
215274
int data[512] = {0}, *d_data = nullptr;
216275
cudaMalloc(&d_data, sizeof(int) * 512);
@@ -610,7 +669,7 @@ bool test_sort_descending_blocked_to_striped_bit() {
610669
}
611670

612671
int main() {
613-
return !(test_sort() && test_sort_descending() &&
672+
return !(test_sort() && test_sorthalf() && test_sort_descending() &&
614673
test_sort_blocked_to_striped() &&
615674
test_sort_descending_blocked_to_striped() && test_sort_bit() &&
616675
test_sort_descending_bit() && test_sort_blocked_to_striped_bit() &&

0 commit comments

Comments
 (0)