Skip to content

Commit 46798de

Browse files
authored
[SYCLomatic] Fix __ldg used in macro (#2706)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent b27a756 commit 46798de

File tree

5 files changed

+32
-7
lines changed

5 files changed

+32
-7
lines changed

clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,24 @@ AddrOfExpr::AddrOfExpr(const Expr *E, const CallExpr *C) {
6363
}
6464

6565
DerefExpr::DerefExpr(const Expr *E, const CallExpr *C) {
66+
const auto &SM = DpctGlobalInfo::getSourceManager();
6667
this->C = C;
6768
// If E is a UnaryOperator or CXXOperatorCallExpr, D.E will have a value
6869
this->E = getDereferencedExpr(E);
6970
if (this->E) {
71+
if (C) {
72+
// Check that the addrof symbol (&) is in the parent range, since only then
73+
// will it be merged with the deref symbol (*)
74+
if (E->getBeginLoc().isMacroID()) {
75+
auto Range = getDefinitionRange(C->getBeginLoc(), C->getEndLoc());
76+
if (!isInRange(Range.getBegin(), Range.getEnd(),
77+
SM.getSpellingLoc(E->getBeginLoc()))) {
78+
this->E = E;
79+
this->NeedParens = true;
80+
return;
81+
}
82+
}
83+
}
7084
this->E = this->E->IgnoreParens();
7185
this->AddrOfRemoved = true;
7286
} else {

clang/lib/DPCT/RuleInfra/CallExprRewriter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,8 @@ class DerefExpr {
800800

801801
public:
802802
DerefExpr(const Expr *E, const CallExpr *C = nullptr);
803+
DerefExpr(std::pair<const CallExpr *, const Expr *> P)
804+
: DerefExpr(P.second, P.first) {}
803805
template <class StreamT>
804806
void printArg(StreamT &Stream, ArgumentAnalysis &A) const {
805807
print(Stream);

clang/lib/DPCT/RulesLang/Math/RewriterHalfPrecisionConversionAndDataMovement.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
797797
EMPTY_FACTORY_ENTRY("__ldca"),
798798
EMPTY_FACTORY_ENTRY("__ldca"),
799799
WARNING_FACTORY_ENTRY(
800-
"__ldca", DEREF_FACTORY_ENTRY("__ldca", ARG(0)),
800+
"__ldca", DEREF_FACTORY_ENTRY("__ldca", ARG_WC(0)),
801801
Diagnostics::MATH_EMULATION_EXPRESSION,
802802
std::string("__ldca"), std::string("'*'"))))
803803
// __ldcg
@@ -807,7 +807,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
807807
EMPTY_FACTORY_ENTRY("__ldcg"),
808808
EMPTY_FACTORY_ENTRY("__ldcg"),
809809
WARNING_FACTORY_ENTRY(
810-
"__ldcg", DEREF_FACTORY_ENTRY("__ldcg", ARG(0)),
810+
"__ldcg", DEREF_FACTORY_ENTRY("__ldcg", ARG_WC(0)),
811811
Diagnostics::MATH_EMULATION_EXPRESSION,
812812
std::string("__ldcg"), std::string("'*'"))))
813813
// __ldcs
@@ -817,7 +817,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
817817
EMPTY_FACTORY_ENTRY("__ldcs"),
818818
EMPTY_FACTORY_ENTRY("__ldcs"),
819819
WARNING_FACTORY_ENTRY(
820-
"__ldcs", DEREF_FACTORY_ENTRY("__ldcs", ARG(0)),
820+
"__ldcs", DEREF_FACTORY_ENTRY("__ldcs", ARG_WC(0)),
821821
Diagnostics::MATH_EMULATION_EXPRESSION,
822822
std::string("__ldcs"), std::string("'*'"))))
823823
// __ldcv
@@ -827,7 +827,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
827827
EMPTY_FACTORY_ENTRY("__ldcv"),
828828
EMPTY_FACTORY_ENTRY("__ldcv"),
829829
WARNING_FACTORY_ENTRY(
830-
"__ldcv", DEREF_FACTORY_ENTRY("__ldcv", ARG(0)),
830+
"__ldcv", DEREF_FACTORY_ENTRY("__ldcv", ARG_WC(0)),
831831
Diagnostics::MATH_EMULATION_EXPRESSION,
832832
std::string("__ldcv"), std::string("'*'"))))
833833
// __ldg
@@ -837,7 +837,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
837837
EMPTY_FACTORY_ENTRY("__ldg"),
838838
EMPTY_FACTORY_ENTRY("__ldg"),
839839
WARNING_FACTORY_ENTRY(
840-
"__ldg", DEREF_FACTORY_ENTRY("__ldg", ARG(0)),
840+
"__ldg", DEREF_FACTORY_ENTRY("__ldg", ARG_WC(0)),
841841
Diagnostics::MATH_EMULATION_EXPRESSION,
842842
std::string("__ldg"), std::string("'*'"))))
843843
// __ldlu
@@ -847,7 +847,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
847847
EMPTY_FACTORY_ENTRY("__ldlu"),
848848
EMPTY_FACTORY_ENTRY("__ldlu"),
849849
WARNING_FACTORY_ENTRY(
850-
"__ldlu", DEREF_FACTORY_ENTRY("__ldlu", ARG(0)),
850+
"__ldlu", DEREF_FACTORY_ENTRY("__ldlu", ARG_WC(0)),
851851
Diagnostics::MATH_EMULATION_EXPRESSION,
852852
std::string("__ldlu"), std::string("'*'"))))
853853
// __ll2half_rd

clang/test/dpct/macro_test.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,4 +1520,13 @@ extern EXPLICIT_DECL(half);
15201520
#undef FROMDEV3
15211521
#undef FROMDEV2
15221522

1523+
// CHECK: #define VLLM_LDG(arg) *(arg)
1524+
// CHECK-NEXT: void foo46(const float *__restrict__ input) {
1525+
// CHECK-NEXT: const float x = VLLM_LDG(&input[13]);
1526+
// CHECK-NEXT: }
1527+
#define VLLM_LDG(arg) __ldg(arg)
1528+
__global__ void foo46(const float *__restrict__ input) {
1529+
const float x = VLLM_LDG(&input[13]);
1530+
}
1531+
15231532
#endif

clang/test/dpct/memory_management.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ void checkError(cudaError_t err) {
198198

199199
void cuCheckError(CUresult err) {
200200
}
201-
// CHECK: #define PITCH(a, b, c, d) a = (float *)dpct::dpct_malloc(b, c, d);
201+
// CHECK: #define PITCH(a, b, c, d) *(a) = (float *)dpct::dpct_malloc(*(b), c, d);
202202
#define PITCH(a, b, c, d) cudaMallocPitch(a, b, c, d);
203203

204204
void testCommas() {

0 commit comments

Comments
 (0)