Skip to content

Commit bdda806

Browse files
committed
[SYCLomatic] Refine CallExprRewriter and ExprAnalysis to cover more macro cases
Signed-off-by: Ziran Zhang <ziran.zhang@intel.com>
1 parent a1a8876 commit bdda806

File tree

7 files changed

+149
-69
lines changed

7 files changed

+149
-69
lines changed

clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ DerefExpr::DerefExpr(const Expr *E, const CallExpr *C) {
8282
}
8383

8484
std::string CallExprRewriter::getMigratedArg(unsigned Idx) {
85-
Analyzer.setCallSpelling(Call);
8685
Analyzer.analyze(Call->getArg(Idx));
8786
return Analyzer.getRewritePrefix() + Analyzer.getRewriteString() +
8887
Analyzer.getRewritePostfix();
@@ -96,13 +95,14 @@ std::string CallExprRewriter::getMigratedArgWithExtraParens(unsigned Idx) {
9695

9796
std::vector<std::string> CallExprRewriter::getMigratedArgs() {
9897
std::vector<std::string> ArgList;
99-
Analyzer.setCallSpelling(Call);
10098
for (unsigned i = 0; i < Call->getNumArgs(); ++i)
10199
ArgList.emplace_back(getMigratedArg(i));
102100
return ArgList;
103101
}
104102

105-
std::optional<std::string> FuncCallExprRewriter::rewrite() {
103+
std::optional<std::string>
104+
FuncCallExprRewriter::rewrite(ExprAnalysis *Analysis) {
105+
ParentAnalysisGuard Guard(Analysis);
106106
RewriteArgList = getMigratedArgs();
107107
return buildRewriteString();
108108
}
@@ -128,6 +128,8 @@ std::unique_ptr<std::unordered_map<
128128
CallExprRewriterFactoryBase::MethodRewriterMap = std::make_unique<std::unordered_map<
129129
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
130130

131+
ExprAnalysis *CallExprRewriter::ParentAnalysis = nullptr;
132+
131133
void CallExprRewriterFactoryBase::initRewriterMap() {
132134
if (DpctGlobalInfo::useSYCLCompat()) {
133135
initRewriterMapSYCLcompat(*RewriterMap);

clang/lib/DPCT/RuleInfra/CallExprRewriter.h

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "Diagnostics/Diagnostics.h"
1313
#include "RuleInfra/ExprAnalysis.h"
1414

15+
#include "llvm/Support/SaveAndRestore.h"
16+
1517
namespace clang {
1618
namespace dpct {
1719

@@ -154,7 +156,9 @@ class CallExprRewriter {
154156
// factories. As a result, the access modifiers of the constructors are
155157
// supposed to be protected instead of public.
156158
CallExprRewriter(const CallExpr *Call, StringRef SourceCalleeName)
157-
: Call(Call), SourceCalleeName(SourceCalleeName) {}
159+
: Call(Call), SourceCalleeName(SourceCalleeName) {
160+
Analyzer.setCallSpelling(Call);
161+
}
158162
bool NoRewrite = false;
159163

160164
public:
@@ -163,7 +167,7 @@ class CallExprRewriter {
163167

164168
/// This function should be overwritten to implement call expression
165169
/// rewriting.
166-
virtual std::optional<std::string> rewrite() = 0;
170+
virtual std::optional<std::string> rewrite(ExprAnalysis *Parent) = 0;
167171
// Emits a warning/error/note and/or comment depending on MsgID. For details
168172
// see Diagnostics.inc, Diagnostics.h and Diagnostics.cpp
169173
template <typename IDTy, typename... Ts>
@@ -183,13 +187,21 @@ class CallExprRewriter {
183187
return BlockLevelFormatFlag;
184188
}
185189

190+
static ExprAnalysis *getParentAnalysis() { return ParentAnalysis; }
191+
186192
protected:
193+
struct ParentAnalysisGuard : llvm::SaveAndRestore<ExprAnalysis *> {
194+
ParentAnalysisGuard(ExprAnalysis *Parent)
195+
: llvm::SaveAndRestore<ExprAnalysis *>(ParentAnalysis, Parent) {}
196+
};
187197
bool BlockLevelFormatFlag = false;
188198
std::vector<std::string> getMigratedArgs();
189199
std::string getMigratedArg(unsigned Index);
190200
std::string getMigratedArgWithExtraParens(unsigned Index);
191201

192202
StringRef getSourceCalleeName() { return SourceCalleeName; }
203+
204+
static ExprAnalysis *ParentAnalysis;
193205
};
194206

195207
class ConditionalRewriterFactory : public CallExprRewriterFactoryBase {
@@ -339,8 +351,8 @@ class AssignableRewriter : public CallExprRewriter {
339351
requestFeature(HelperFeatureEnum::device_ext);
340352
}
341353

342-
std::optional<std::string> rewrite() override {
343-
std::optional<std::string> &&Result = Inner->rewrite();
354+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
355+
std::optional<std::string> &&Result = Inner->rewrite(Analysis);
344356
if (Result.has_value()) {
345357
if ((CheckAssigned && IsAssigned) || (CheckInRetStmt && IsInRetStmt)) {
346358
if (UseDpctCheckError) {
@@ -372,8 +384,8 @@ class InsertAroundRewriter : public CallExprRewriter {
372384
: CallExprRewriter(C, ""), Prefix(Prefix), Suffix(Suffix),
373385
Inner(InnerRewriter) {}
374386

375-
std::optional<std::string> rewrite() override {
376-
std::optional<std::string> &&Result = Inner->rewrite();
387+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
388+
std::optional<std::string> &&Result = Inner->rewrite(Analysis);
377389
if (Result.has_value())
378390
return Prefix + Result.value() + Suffix;
379391
return Result;
@@ -391,7 +403,7 @@ class RemoveAPIRewriter : public CallExprRewriter {
391403
: CallExprRewriter(C, CalleeName), IsAssigned(isAssigned(C)),
392404
CalleeName(CalleeName), Message(Message) {}
393405

394-
std::optional<std::string> rewrite() override {
406+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
395407
std::string Msg =
396408
Message.empty() ? "this functionality is redundant in SYCL." : Message;
397409
if (IsAssigned) {
@@ -424,10 +436,10 @@ class IfElseRewriter : public CallExprRewriter {
424436
Indent = getIndent(getStmtExpansionSourceRange(C).getBegin(), SM);
425437
}
426438

427-
std::optional<std::string> rewrite() override {
428-
std::optional<std::string> &&PredStr = Pred->rewrite();
429-
std::optional<std::string> &&IfBlockStr = IfBlock->rewrite();
430-
std::optional<std::string> &&ElseBlockStr = ElseBlock->rewrite();
439+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
440+
std::optional<std::string> &&PredStr = Pred->rewrite(Analysis);
441+
std::optional<std::string> &&IfBlockStr = IfBlock->rewrite(Analysis);
442+
std::optional<std::string> &&ElseBlockStr = ElseBlock->rewrite(Analysis);
431443
return "if(" + PredStr.value() + "){" + NL.str() + Indent.str() +
432444
Indent.str() + IfBlockStr.value() + ";" + NL.str() +
433445
Indent.str() + "} else {" + NL.str() + Indent.str() + Indent.str() +
@@ -555,7 +567,7 @@ class FuncCallExprRewriter : public CallExprRewriter {
555567
public:
556568
virtual ~FuncCallExprRewriter() {}
557569

558-
virtual std::optional<std::string> rewrite() override;
570+
virtual std::optional<std::string> rewrite(ExprAnalysis *Analysis) override;
559571

560572
friend FuncCallExprRewriterFactory;
561573

@@ -581,7 +593,7 @@ class NoRewriteFuncNameRewriter : public CallExprRewriter {
581593
NoRewrite = true;
582594
}
583595

584-
std::optional<std::string> rewrite() override { return NewFuncName; }
596+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override { return NewFuncName; }
585597
};
586598

587599
struct ThrustFunctor {
@@ -1175,7 +1187,8 @@ template <class ArgT> class DeleterCallExprRewriter : public CallExprRewriter {
11751187
DeleterCallExprRewriter(const CallExpr *C, StringRef Source,
11761188
std::function<ArgT(const CallExpr *)> ArgCreator)
11771189
: CallExprRewriter(C, Source), Arg(ArgCreator(C)) {}
1178-
std::optional<std::string> rewrite() override {
1190+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
1191+
ParentAnalysisGuard Guard(Analysis);
11791192
std::string Result;
11801193
llvm::raw_string_ostream OS(Result);
11811194
OS << "delete ";
@@ -1191,7 +1204,8 @@ template <class ArgT> class ToStringExprRewriter : public CallExprRewriter {
11911204
ToStringExprRewriter(const CallExpr *C, StringRef Source,
11921205
std::function<ArgT(const CallExpr *)> ArgCreator)
11931206
: CallExprRewriter(C, Source), Arg(ArgCreator(C)) {}
1194-
std::optional<std::string> rewrite() override {
1207+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
1208+
ParentAnalysisGuard Guard(Analysis);
11951209
std::string Result;
11961210
llvm::raw_string_ostream OS(Result);
11971211
print(OS, Arg);
@@ -1375,7 +1389,8 @@ class PrinterRewriter : Printer, public CallExprRewriter {
13751389
PrinterRewriter(const CallExpr *C, StringRef Source,
13761390
const std::function<ArgsT(const CallExpr *)> &...ArgCreators)
13771391
: PrinterRewriter(C, Source, ArgCreators(C)...) {}
1378-
std::optional<std::string> rewrite() override {
1392+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
1393+
ParentAnalysisGuard Guard(Analysis);
13791394
std::string Result;
13801395
llvm::raw_string_ostream OS(Result);
13811396
Printer::print(OS);
@@ -1398,7 +1413,8 @@ class PrinterRewriter<MultiStmtsPrinter<StmtPrinters...>>
13981413
const CallExpr *C, StringRef Source,
13991414
const std::function<StmtPrinters(const CallExpr *)> &...PrinterCreators)
14001415
: PrinterRewriter(C, Source, PrinterCreators(C)...) {}
1401-
std::optional<std::string> rewrite() override {
1416+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
1417+
ParentAnalysisGuard Guard(Analysis);
14021418
std::string Result;
14031419
llvm::raw_string_ostream OS(Result);
14041420
Base::print(OS);
@@ -1479,7 +1495,8 @@ class SimpleCallExprRewriter : public CallExprRewriter {
14791495
const std::function<CallExprPrinter<CalleeT, ArgsT...>(const CallExpr *)>
14801496
&PrinterFunctor)
14811497
: CallExprRewriter(C, Source), Printer(PrinterFunctor(C)) {}
1482-
std::optional<std::string> rewrite() override {
1498+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
1499+
ParentAnalysisGuard Guard(Analysis);
14831500
std::string Result;
14841501
llvm::raw_string_ostream OS(Result);
14851502
Printer.print(OS);
@@ -1582,7 +1599,7 @@ class UnsupportFunctionRewriter : public CallExprRewriter {
15821599
report(MsgID, false, getMsgArg(Args, CE)...);
15831600
}
15841601

1585-
std::optional<std::string> rewrite() override { return std::nullopt; }
1602+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override { return std::nullopt; }
15861603

15871604
friend UnsupportFunctionRewriterFactory<MsgArgs...>;
15881605
};
@@ -1609,7 +1626,7 @@ class UserDefinedRewriter : public CallExprRewriter {
16091626
buildRewriterStr(Call, OS, OB);
16101627
OS.flush();
16111628
}
1612-
std::optional<std::string> rewrite() override {
1629+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
16131630
return ResultStr;
16141631
}
16151632

@@ -1701,7 +1718,7 @@ class UserDefinedRewriter : public CallExprRewriter {
17011718
struct NullRewriter : public CallExprRewriter {
17021719
NullRewriter(const CallExpr *C, StringRef Name) : CallExprRewriter(C, Name) {}
17031720

1704-
std::optional<std::string> rewrite() override { return std::nullopt; }
1721+
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override { return std::nullopt; }
17051722
};
17061723

17071724
struct NullRewriterFactory : public CallExprRewriterFactoryBase {

clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@ void ExprAnalysis::analyzeExpr(const CallExpr *CE) {
885885
}
886886

887887
auto Rewriter = Itr->second->create(CE);
888-
auto Result = Rewriter->rewrite();
888+
auto Result = Rewriter->rewrite(this);
889889
BlockLevelFormatFlag = Rewriter->getBlockLevelFormatFlag();
890890

891891
if (Rewriter->isNoRewrite()) {
@@ -1001,7 +1001,7 @@ void ExprAnalysis::analyzeExpr(const CXXMemberCallExpr *CMCE) {
10011001
BaseType + "." + MethodName);
10021002
if (Itr != CallExprRewriterFactoryBase::MethodRewriterMap->end()) {
10031003
auto Rewriter = Itr->second->create(CMCE);
1004-
auto Result = Rewriter->rewrite();
1004+
auto Result = Rewriter->rewrite(this);
10051005
if (Result.has_value()) {
10061006
auto ResultStr = Result.value();
10071007
addReplacement(CMCE, ResultStr);
@@ -1271,9 +1271,57 @@ void ExprAnalysis::applyAllSubExprRepl() {
12711271
for (std::shared_ptr<ExtReplacement> Repl : SubExprRepl) {
12721272
if (BlockLevelFormatFlag)
12731273
Repl->setBlockLevelFormatFlag();
1274-
12751274
DpctGlobalInfo::getInstance().addReplacement(Repl);
12761275
}
1276+
SubExprRepl.clear();
1277+
}
1278+
1279+
bool needCleanUp(const Expr *E){
1280+
return DpctGlobalInfo::findAncestor<CompoundStmt>(
1281+
E, [&](const DynTypedNode &Node) {
1282+
return Node.get<CompoundStmt>() || !Node.get<ExprWithCleanups>();
1283+
});
1284+
}
1285+
1286+
TextModification *removeWithCleanUp(SourceLocation Begin, unsigned Length,
1287+
const SourceManager &SM) {
1288+
Token NextToken;
1289+
if (!Lexer::getRawToken(Begin.getLocWithOffset(Length), NextToken, SM,
1290+
DpctGlobalInfo::getContext().getLangOpts(), true) &&
1291+
NextToken.getKind() == tok::semi) {
1292+
Length += NextToken.getLength();
1293+
auto BeginData = SM.getCharacterData(Begin),
1294+
EndData = BeginData + Length;
1295+
unsigned Indent = 0, Trailing = 0;
1296+
auto Ch= *--BeginData;
1297+
while (isspace(Ch) && Ch != '\n' && Ch != '\r' && Ch != '{') {
1298+
++Indent;
1299+
Ch = *--BeginData;
1300+
}
1301+
Ch = *EndData;
1302+
while (isspace(Ch) || Ch == '\\') {
1303+
++Trailing;
1304+
if (Ch == '\n' || Ch == '\r') {
1305+
Ch = *++EndData;
1306+
if (Ch == '\n' || Ch == '\r')
1307+
++Trailing;
1308+
break;
1309+
}
1310+
Ch = *++EndData;
1311+
}
1312+
Begin = Begin.getLocWithOffset(-Indent);
1313+
Length += Indent + Trailing;
1314+
}
1315+
return new ReplaceText(Begin, Length, "");
1316+
}
1317+
1318+
TextModification *ExprAnalysis::getReplacement() {
1319+
if (!hasReplacement())
1320+
return nullptr;
1321+
std::string Repl = getReplacedString();
1322+
if (E && Repl.empty() && needCleanUp(E))
1323+
return removeWithCleanUp(SrcBeginLoc, SrcLength, SM);
1324+
return new ReplaceText(SrcBeginLoc, SrcLength, std::move(Repl));
12771325
}
12781326

12791327
const std::string &ArgumentAnalysis::getDefaultArgument(const Expr *E) {
@@ -2168,6 +2216,25 @@ void KernelConfigAnalysis::analyze(const Expr *E, unsigned int Idx,
21682216
ArgumentAnalysis::analyze(E);
21692217
}
21702218

2219+
void ExprAnalysis::applySubExprReplToParent() {
2220+
if (auto Parent = CallExprRewriter::getParentAnalysis()) {
2221+
for (const auto &Repl : SubExprRepl) {
2222+
auto File = SM.getFileManager().getFileRef(Repl->getFilePath());
2223+
if (!File || Parent->FileId != SM.translateFile(File.get()) ||
2224+
Repl->getOffset() < Parent->SrcBegin ||
2225+
Repl->getOffset() + Repl->getLength() >
2226+
Parent->SrcBegin + Parent->SrcLength) {
2227+
Parent->addExtReplacement(Repl);
2228+
} else {
2229+
Parent->addReplacement(Repl->getOffset() - Parent->SrcBegin,
2230+
Repl->getLength(),
2231+
Repl->getReplacementText().str());
2232+
}
2233+
}
2234+
SubExprRepl.clear();
2235+
}
2236+
}
2237+
21712238
std::string ArgumentAnalysis::getRewriteString() {
21722239
// Find rewrite range
21732240
auto RewriteRange = getLocInCallSpelling(getTargetExpr());
@@ -2182,16 +2249,20 @@ std::string ArgumentAnalysis::getRewriteString() {
21822249

21832250
StringReplacements SRs;
21842251
SRs.init(std::move(OriginalStr));
2185-
for (std::shared_ptr<ExtReplacement> SubRepl : SubExprRepl) {
2186-
if (isInRange(RewriteRangeBegin, RewriteRangeEnd, SubRepl->getFilePath(),
2187-
SubRepl->getOffset()) &&
2188-
isInRange(RewriteRangeBegin, RewriteRangeEnd, SubRepl->getFilePath(),
2189-
SubRepl->getOffset() + SubRepl->getLength())) {
2190-
SRs.addStringReplacement(
2191-
SubRepl->getOffset() - SM.getDecomposedLoc(RewriteRangeBegin).second,
2192-
SubRepl->getLength(), SubRepl->getReplacementText().str());
2252+
for (auto Iter = SubExprRepl.begin(); Iter != SubExprRepl.end();) {
2253+
auto &Repl = **Iter;
2254+
if (isInRange(RewriteRangeBegin, RewriteRangeEnd, Repl.getFilePath(),
2255+
Repl.getOffset()) &&
2256+
isInRange(RewriteRangeBegin, RewriteRangeEnd, Repl.getFilePath(),
2257+
Repl.getOffset() + Repl.getLength())) {
2258+
SRs.addStringReplacement(Repl.getOffset() - DL.second, Repl.getLength(),
2259+
Repl.getReplacementText().str());
2260+
Iter = SubExprRepl.erase(Iter);
2261+
} else {
2262+
++Iter;
21932263
}
21942264
}
2265+
applySubExprReplToParent();
21952266
return SRs.getReplacedString();
21962267
}
21972268

clang/lib/DPCT/RuleInfra/ExprAnalysis.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,18 +239,7 @@ class ExprAnalysis {
239239
// This function is not re-enterable, if caller need to check if it returns
240240
// nullptr, caller need to use temp variable to save the return value, then
241241
// check. Don't call twice for same Replacement.
242-
inline TextModification *getReplacement() {
243-
bool hasRepl = hasReplacement();
244-
std::string Repl = getReplacedString();
245-
if (E) {
246-
auto Range = getDefinitionRange(E->getBeginLoc(), E->getEndLoc());
247-
if (!isSameLocation(Range.getBegin(), Range.getEnd())) {
248-
return hasRepl ? new ReplaceStmt(E, true, Repl) : nullptr;
249-
}
250-
}
251-
return hasRepl ? new ReplaceText(SrcBeginLoc, SrcLength, std::move(Repl))
252-
: nullptr;
253-
}
242+
TextModification *getReplacement();
254243

255244
inline void clearReplacement() { ReplSet.reset(); }
256245

@@ -385,6 +374,7 @@ class ExprAnalysis {
385374
}
386375

387376
void applyAllSubExprRepl();
377+
void applySubExprReplToParent();
388378
inline std::vector<std::shared_ptr<ExtReplacement>> &getSubExprRepl() {
389379
return SubExprRepl;
390380
};

0 commit comments

Comments
 (0)