@@ -30,6 +30,28 @@ __global__ void Sort(int *data) {
30
30
BlockStore (temp_storage_store).Store (data, thread_keys);
31
31
}
32
32
33
// Kernel: block-wide ascending radix sort of half-precision keys.
//
// Written for a single block of 128 threads in which every thread owns 4 keys,
// so `data` must hold at least 128 * 4 = 512 __half values. The keys are read
// with cub::BlockLoad, sorted across the block with cub::BlockRadixSort and
// written back in place with cub::BlockStore.
//
// The directive comments inside the body describe the expected migrated SYCL
// code for the statements that follow them; they must stay adjacent to those
// statements and the identifiers they mention must not be renamed.
__global__ void SortHalf(__half *data) {
  // CHECK: using BlockRadixSort = dpct::group::group_radix_sort<sycl::half, 4>;
  // CHECK-NEXT: using BlockLoad = dpct::group::group_load<sycl::half, 4>;
  // CHECK-NEXT: using BlockStore = dpct::group::group_store<sycl::half, 4>;
  // CHECK-NOT: __shared__ typename BlockLoad::TempStorage temp_storage_load;
  // CHECK-NOT: __shared__ typename BlockStore::TempStorage temp_storage_store;
  // CHECK-NOT: __shared__ typename BlockRadixSort::TempStorage temp_storage;
  using BlockRadixSort = cub::BlockRadixSort<__half, 128, 4>;
  using BlockLoad = cub::BlockLoad<__half, 128, 4>;
  using BlockStore = cub::BlockStore<__half, 128, 4>;
  __shared__ typename BlockLoad::TempStorage temp_storage_load;
  __shared__ typename BlockStore::TempStorage temp_storage_store;
  __shared__ typename BlockRadixSort::TempStorage temp_storage;
  __half thread_keys[4];
  // CHECK: BlockLoad(temp_storage_load).load(item_ct1, data, thread_keys);
  // CHECK-NEXT: BlockRadixSort(temp_storage).sort(item_ct1, thread_keys);
  // CHECK-NEXT: BlockStore(temp_storage_store).store(item_ct1, data, thread_keys);
  BlockLoad(temp_storage_load).Load(data, thread_keys);
  BlockRadixSort(temp_storage).Sort(thread_keys);
  BlockStore(temp_storage_store).Store(data, thread_keys);
}
54
+
33
55
__global__ void SortDescending (int *data) {
34
56
// CHECK: using BlockRadixSort = dpct::group::group_radix_sort<int, 4>;
35
57
// CHECK-NEXT: using BlockLoad = dpct::group::group_load<int, 4, dpct::group::group_load_algorithm::blocked>;
@@ -171,8 +193,9 @@ __global__ void test_unsupported(int *data) {
171
193
172
194
// Prints the N elements of `arr` to stdout on a single line: elements are
// separated by ',' and the last one is followed by '\n'.
//
// Every element is converted to int before printing, so the same helper works
// for int arrays and for element types that only offer a conversion to int
// (fractional parts are truncated by that conversion).
template <typename T, int N>
void print_array(T (&arr)[N]) {
  for (int i = 0; i < N; ++i) {
    // Keep the separator a plain char: a multi-character literal would have
    // type int and be streamed as a number instead of a character.
    const char separator = (i == N - 1) ? '\n' : ',';
    std::cout << static_cast<int>(arr[i]) << separator;
  }
}
177
200
178
201
bool test_sort () {
@@ -211,6 +234,42 @@ bool test_sort() {
211
234
return true ;
212
235
}
213
236
237
// Host-side test for the SortHalf kernel.
//
// Builds 512 half-precision keys in an interleaved pattern (an ascending run
// 0..255 mixed with a descending run 511..256), sorts them on the device with
// one block of 128 threads, and verifies that the result is exactly
// 0, 1, ..., 511. All of these integers are exactly representable in half
// precision, so the comparison after converting back to int is exact.
//
// Returns true when the array is sorted; on the first mismatch it prints a
// failure message followed by the whole array and returns false.
bool test_sorthalf() {
  __half data[512] = {0}, *d_data = nullptr;
  cudaMalloc(&d_data, sizeof(__half) * 512);
  // x counts up from 0 and y counts down from 511; each iteration emits two of
  // each, so after 128 iterations every value in [0, 511] appears once.
  for (int i = 0, x = 0, y = 511; i < 128; ++i) {
    data[i * 4 + 0] = x++;
    data[i * 4 + 1] = y--;
    data[i * 4 + 2] = x++;
    data[i * 4 + 3] = y--;
  }
  cudaMemcpy(d_data, data, sizeof(data), cudaMemcpyHostToDevice);
  // CHECK: q_ct1.submit(
  // CHECK-NEXT: [&](sycl::handler &cgh) {
  // CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_load_acc(dpct::group::group_load<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
  // CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_store_acc(dpct::group::group_store<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
  // CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_acc(dpct::group::group_radix_sort<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
  // CHECK-EMPTY:
  // CHECK-NEXT: cgh.parallel_for(
  // CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
  // CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
  // CHECK-NEXT: SortHalf(d_data, &temp_storage_load_acc[0], &temp_storage_store_acc[0], &temp_storage_acc[0]);
  // CHECK-NEXT: });
  // CHECK-NEXT: });
  SortHalf<<<1, 128>>>(d_data);
  cudaDeviceSynchronize();
  cudaMemcpy(data, d_data, sizeof(data), cudaMemcpyDeviceToHost);
  // The device buffer is no longer needed once the results are on the host.
  cudaFree(d_data);
  for (int i = 0; i < 512; ++i) {
    if ((int)data[i] != i) {
      printf("test_sorthalf failed\n");
      print_array(data);
      return false;
    }
  }
  printf("test_sorthalf pass\n");
  return true;
}
272
+
214
273
bool test_sort_descending () {
215
274
int data[512 ] = {0 }, *d_data = nullptr ;
216
275
cudaMalloc (&d_data, sizeof (int ) * 512 );
@@ -610,7 +669,7 @@ bool test_sort_descending_blocked_to_striped_bit() {
610
669
}
611
670
612
671
int main () {
613
- return !(test_sort () && test_sort_descending () &&
672
+ return !(test_sort () && test_sorthalf () && test_sort_descending () &&
614
673
test_sort_blocked_to_striped () &&
615
674
test_sort_descending_blocked_to_striped () && test_sort_bit () &&
616
675
test_sort_descending_bit () && test_sort_blocked_to_striped_bit () &&
0 commit comments