Skip to content

Commit 060b37e

Browse files
authored
[SYCLomatic] Refine migration for thrust::max and thrust::min and add them to api-query-mapping (#2923)
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent bd7920c commit 060b37e

File tree

6 files changed

+145
-4
lines changed

6 files changed

+145
-4
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include <thrust/extrema.h>
2+
3+
void max_test() {
4+
// clang-format off
5+
// Start
6+
struct key_value {
7+
int key;
8+
int value;
9+
};
10+
struct compare_key_value {
11+
__host__ __device__ bool operator()(key_value lhs, key_value rhs) {
12+
return lhs.key < rhs.key;
13+
}
14+
};
15+
key_value a = {13, 0};
16+
key_value b = {7, 1};
17+
key_value smaller = thrust::max(a, b, compare_key_value());
18+
int value = thrust::max(1, 2);
19+
// End
20+
// clang-format on
21+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include <thrust/extrema.h>
2+
3+
void min_test() {
4+
// clang-format off
5+
// Start
6+
struct key_value {
7+
int key;
8+
int value;
9+
};
10+
struct compare_key_value {
11+
__host__ __device__ bool operator()(key_value lhs, key_value rhs) {
12+
return lhs.key < rhs.key;
13+
}
14+
};
15+
key_value a = {13, 0};
16+
key_value b = {7, 1};
17+
key_value smaller = thrust::min(a, b, compare_key_value());
18+
int value = thrust::min(1, 2);
19+
// End
20+
// clang-format on
21+
}

clang/lib/DPCT/RulesLangLib/APINamesThrust.inc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,10 +1172,18 @@ thrustFactory("thrust::detail::vector_equal",
11721172
{{3,PolicyState::NoPolicy,3,"oneapi::dpl::equal", HelperFeatureEnum::none}}),
11731173

11741174
// thrust::max
1175-
CALL_FACTORY_ENTRY("thrust::max", CALL("std::max", ARG(0), ARG(1)))
1175+
CONDITIONAL_FACTORY_ENTRY(
1176+
CheckArgCount(2),
1177+
CALL_FACTORY_ENTRY("thrust::max", CALL("std::max", ARG(0), ARG(1))),
1178+
CALL_FACTORY_ENTRY("thrust::max", CALL("std::max", ARG(0), ARG(1), ARG(2)))
1179+
)
11761180

11771181
// thrust::min
1178-
CALL_FACTORY_ENTRY("thrust::min", CALL("std::min",ARG(0), ARG(1)))
1182+
CONDITIONAL_FACTORY_ENTRY(
1183+
CheckArgCount(2),
1184+
CALL_FACTORY_ENTRY("thrust::min", CALL("std::min", ARG(0), ARG(1))),
1185+
CALL_FACTORY_ENTRY("thrust::min", CALL("std::min", ARG(0), ARG(1), ARG(2)))
1186+
)
11791187

11801188
// thrust::tie
11811189
CALL_FACTORY_ENTRY("thrust::tie", CALL("std::tie",ARG(0), ARG(1)))

clang/test/dpct/query_api_mapping/Thrust/thrust_api_test_p3.cu

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,4 +302,62 @@
302302
// thrust_get_temporary_buffer-NEXT: dpct::device_sys_tag device_sys;
303303
// thrust_get_temporary_buffer-NEXT: ptr_and_size_t ptr_and_size = dpct::get_temporary_allocation<int>(device_sys, N);
304304

305-
305+
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=thrust::max --extra-arg="-std=c++14"| FileCheck %s -check-prefix=thrust_max
306+
// thrust_max: CUDA API:
307+
// thrust_max-NEXT: struct key_value {
308+
// thrust_max-NEXT: int key;
309+
// thrust_max-NEXT: int value;
310+
// thrust_max-NEXT: };
311+
// thrust_max-NEXT: struct compare_key_value {
312+
// thrust_max-NEXT: __host__ __device__ bool operator()(key_value lhs, key_value rhs) {
313+
// thrust_max-NEXT: return lhs.key < rhs.key;
314+
// thrust_max-NEXT: }
315+
// thrust_max-NEXT: };
316+
// thrust_max-NEXT: key_value a = {13, 0};
317+
// thrust_max-NEXT: key_value b = {7, 1};
318+
// thrust_max-NEXT: key_value smaller = thrust::max(a, b, compare_key_value());
319+
// thrust_max-NEXT: int value = thrust::max(1, 2);
320+
// thrust_max-NEXT: Is migrated to:
321+
// thrust_max-NEXT: struct key_value {
322+
// thrust_max-NEXT: int key;
323+
// thrust_max-NEXT: int value;
324+
// thrust_max-NEXT: };
325+
// thrust_max-NEXT: struct compare_key_value {
326+
// thrust_max-NEXT: bool operator()(key_value lhs, key_value rhs) {
327+
// thrust_max-NEXT: return lhs.key < rhs.key;
328+
// thrust_max-NEXT: }
329+
// thrust_max-NEXT: };
330+
// thrust_max-NEXT: key_value a = {13, 0};
331+
// thrust_max-NEXT: key_value b = {7, 1};
332+
// thrust_max-NEXT: key_value smaller = std::max(a, b, compare_key_value());
333+
// thrust_max-NEXT: int value = std::max(1, 2);
334+
335+
// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=thrust::min --extra-arg="-std=c++14"| FileCheck %s -check-prefix=thrust_min
336+
// thrust_min: CUDA API:
337+
// thrust_min-NEXT: struct key_value {
338+
// thrust_min-NEXT: int key;
339+
// thrust_min-NEXT: int value;
340+
// thrust_min-NEXT: };
341+
// thrust_min-NEXT: struct compare_key_value {
342+
// thrust_min-NEXT: __host__ __device__ bool operator()(key_value lhs, key_value rhs) {
343+
// thrust_min-NEXT: return lhs.key < rhs.key;
344+
// thrust_min-NEXT: }
345+
// thrust_min-NEXT: };
346+
// thrust_min-NEXT: key_value a = {13, 0};
347+
// thrust_min-NEXT: key_value b = {7, 1};
348+
// thrust_min-NEXT: key_value smaller = thrust::min(a, b, compare_key_value());
349+
// thrust_min-NEXT: int value = thrust::min(1, 2);
350+
// thrust_min-NEXT: Is migrated to:
351+
// thrust_min-NEXT: struct key_value {
352+
// thrust_min-NEXT: int key;
353+
// thrust_min-NEXT: int value;
354+
// thrust_min-NEXT: };
355+
// thrust_min-NEXT: struct compare_key_value {
356+
// thrust_min-NEXT: bool operator()(key_value lhs, key_value rhs) {
357+
// thrust_min-NEXT: return lhs.key < rhs.key;
358+
// thrust_min-NEXT: }
359+
// thrust_min-NEXT: };
360+
// thrust_min-NEXT: key_value a = {13, 0};
361+
// thrust_min-NEXT: key_value b = {7, 1};
362+
// thrust_min-NEXT: key_value smaller = std::min(a, b, compare_key_value());
363+
// thrust_min-NEXT: int value = std::min(1, 2);

clang/test/dpct/query_api_mapping/test_all.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,9 +2476,11 @@
24762476
// CHECK-NEXT: thrust::make_tuple
24772477
// CHECK-NEXT: thrust::make_zip_iterator
24782478
// CHECK-NEXT: thrust::malloc
2479+
// CHECK-NEXT: thrust::max
24792480
// CHECK-NEXT: thrust::max_element
24802481
// CHECK-NEXT: thrust::merge
24812482
// CHECK-NEXT: thrust::merge_by_key
2483+
// CHECK-NEXT: thrust::min
24822484
// CHECK-NEXT: thrust::min_element
24832485
// CHECK-NEXT: thrust::minmax_element
24842486
// CHECK-NEXT: thrust::mismatch

clang/test/dpct/thrust-for-h2o4gpu.cu

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <thrust/execution_policy.h>
1515
#include <thrust/random.h>
1616
#include <thrust/reduce.h>
17-
#include <algorithm>
17+
#include <thrust/extrema.h>
1818
#include <thrust/inner_product.h>
1919
#include <thrust/extrema.h>
2020
#include <thrust/host_vector.h>
@@ -651,3 +651,34 @@ template <class T> auto foo2(T t) {
651651
// CHECK: return dpct::make_constant_iterator(std::make_tuple(t, false));
652652
return thrust::make_constant_iterator(thrust::make_tuple<T, bool>(t, false));
653653
}
654+
655+
struct key_value {
656+
int key;
657+
int value;
658+
};
659+
660+
// CHECK: struct compare_key_value {
661+
// CHECK-NEXT: bool operator()(key_value lhs, key_value rhs) {
662+
// CHECK-NEXT: return lhs.key < rhs.key;
663+
// CHECK-NEXT: }
664+
// CHECK-NEXT: };
665+
struct compare_key_value {
666+
__host__ __device__ bool operator()(key_value lhs, key_value rhs) {
667+
return lhs.key < rhs.key;
668+
}
669+
};
670+
671+
void thrust_max_min() {
672+
key_value a = {13, 0};
673+
key_value b = {7, 1};
674+
675+
// CHECK: key_value smaller = std::min(a, b, compare_key_value());
676+
// CHECK-NEXT: key_value maxer = std::max(a, b, compare_key_value());
677+
key_value smaller = thrust::min(a, b, compare_key_value());
678+
key_value maxer = thrust::max(a, b, compare_key_value());
679+
680+
// CHECK: int min = std::min(1, 2);
681+
// CHECK-NEXT: int max = std::max(1, 2);
682+
int min = thrust::min(1, 2);
683+
int max = thrust::max(1, 2);
684+
}

0 commit comments

Comments
 (0)