Skip to content

Commit 7a28a5b

Browse files
authored
[SYCLomatic] For math function used in template function, only migrate the template version when there is no instantiation (#2779)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent 3eb2338 commit 7a28a5b

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

clang/lib/DPCT/RulesLang/RulesLang.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7514,11 +7514,20 @@ void MathFunctionsRule::registerMatcher(MatchFinder &MF) {
75147514
}
75157515

75167516
void MathFunctionsRule::runRule(const MatchFinder::MatchResult &Result) {
7517-
const CallExpr *CE = getAssistNodeAsType<CallExpr>(Result, "math");
7518-
if (!CE)
7519-
CE = getNodeAsType<CallExpr>(Result, "unresolved");
7520-
if (!CE)
7521-
return;
7517+
const CallExpr *CE = getAssistNodeAsType<CallExpr>(Result, "math");
7518+
bool IsUnresolved = false;
7519+
if (!CE) {
7520+
CE = getNodeAsType<CallExpr>(Result, "unresolved");
7521+
IsUnresolved = true;
7522+
}
7523+
if (!CE)
7524+
return;
7525+
if (IsUnresolved) {
7526+
const auto *FTD = DpctGlobalInfo::findAncestor<FunctionTemplateDecl>(CE);
7527+
if (FTD && (FTD->spec_begin() != FTD->spec_end())) {
7528+
return;
7529+
}
7530+
}
75227531

75237532
ExprAnalysis EA(CE);
75247533
emplaceTransformation(EA.getReplacement());

clang/test/dpct/math-functions.cu

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -663,17 +663,32 @@ __device__ A min(A a, A b) { return a; }
663663
__device__ A max(A a, A b) { return a; }
664664

665665
template <class T> __device__ T clamp(T x, T a, T b) {
666-
// CHECK: /*
667-
// CHECK-NEXT: DPCT1064:{{[0-9]+}}: Migrated min call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
668-
// CHECK-NEXT: */
669-
// CHECK-NEXT: /*
670-
// CHECK-NEXT: DPCT1064:{{[0-9]+}}: Migrated max call is used in a macro/template definition and may not be valid for all macro/template uses. Adjust the code.
671-
// CHECK-NEXT: */
672-
// CHECK-NEXT: return dpct::min(dpct::max(x, a), b);
666+
// CHECK: return min(max(x, a), b);
673667
return min(max(x, a), b);
674668
}
675669

676670
__global__ void kernel_2() {
677671
A a;
678672
clamp(a, a, a);
679673
}
674+
675+
template <class T> struct InternalType;
676+
template <> struct InternalType<float> {
677+
typedef float scalar_t;
678+
typedef float2 vec2_t;
679+
typedef float4 vec4_t;
680+
__device__ __forceinline__ static vec4_t zero_vec4(void) {
681+
return make_float4(0, 0, 0, 0);
682+
}
683+
};
684+
685+
template <class T> __global__ void foo5() {
686+
typedef typename InternalType<T>::vec4_t vec4_t;
687+
vec4_t v = InternalType<T>::zero_vec4();
688+
// CHECK: sycl::fabs(v.x());
689+
fabsf(v.x);
690+
}
691+
692+
void foo6() {
693+
foo5<float><<<1, 1>>>();
694+
}

0 commit comments

Comments
 (0)