From 7cfefb79471a2cb799f7dd38fb7c3e1806bfa1f5 Mon Sep 17 00:00:00 2001 From: Ziran Zhang Date: Mon, 20 Jan 2025 09:55:46 +0800 Subject: [PATCH] [SYCLomatic] Refine CallExprRewriter and ExprAnalysis to cover more macro cases Signed-off-by: Ziran Zhang --- clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp | 8 +- clang/lib/DPCT/RuleInfra/CallExprRewriter.h | 59 ++++++++---- clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp | 95 ++++++++++++++++--- clang/lib/DPCT/RuleInfra/ExprAnalysis.h | 14 +-- .../RulesLang/Math/CallExprRewriterMath.cpp | 30 +++--- .../RulesLang/Math/CallExprRewriterMath.h | 12 +-- clang/lib/DPCT/RulesLang/RulesLangTexture.cpp | 2 +- 7 files changed, 150 insertions(+), 70 deletions(-) diff --git a/clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp b/clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp index 0eb661727c5b..d4ef03f23782 100644 --- a/clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp +++ b/clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp @@ -82,7 +82,6 @@ DerefExpr::DerefExpr(const Expr *E, const CallExpr *C) { } std::string CallExprRewriter::getMigratedArg(unsigned Idx) { - Analyzer.setCallSpelling(Call); Analyzer.analyze(Call->getArg(Idx)); return Analyzer.getRewritePrefix() + Analyzer.getRewriteString() + Analyzer.getRewritePostfix(); @@ -96,13 +95,14 @@ std::string CallExprRewriter::getMigratedArgWithExtraParens(unsigned Idx) { std::vector CallExprRewriter::getMigratedArgs() { std::vector ArgList; - Analyzer.setCallSpelling(Call); for (unsigned i = 0; i < Call->getNumArgs(); ++i) ArgList.emplace_back(getMigratedArg(i)); return ArgList; } -std::optional FuncCallExprRewriter::rewrite() { +std::optional +FuncCallExprRewriter::rewrite(ExprAnalysis *Analysis) { + ParentAnalysisGuard Guard(Analysis); RewriteArgList = getMigratedArgs(); return buildRewriteString(); } @@ -128,6 +128,8 @@ std::unique_ptr>>(); +ExprAnalysis *CallExprRewriter::ParentAnalysis = nullptr; + void CallExprRewriterFactoryBase::initRewriterMap() { if (DpctGlobalInfo::useSYCLCompat()) { initRewriterMapSYCLcompat(*RewriterMap); diff --git a/clang/lib/DPCT/RuleInfra/CallExprRewriter.h b/clang/lib/DPCT/RuleInfra/CallExprRewriter.h index 6d5a508dfb88..63cf7fc7a8c2 100644 --- a/clang/lib/DPCT/RuleInfra/CallExprRewriter.h +++ b/clang/lib/DPCT/RuleInfra/CallExprRewriter.h @@ -12,6 +12,8 @@ #include "Diagnostics/Diagnostics.h" #include "RuleInfra/ExprAnalysis.h" +#include "llvm/Support/SaveAndRestore.h" + namespace clang { namespace dpct { @@ -154,7 +156,9 @@ class CallExprRewriter { // factories. As a result, the access modifiers of the constructors are // supposed to be protected instead of public. CallExprRewriter(const CallExpr *Call, StringRef SourceCalleeName) - : Call(Call), SourceCalleeName(SourceCalleeName) {} + : Call(Call), SourceCalleeName(SourceCalleeName) { + Analyzer.setCallSpelling(Call); + } bool NoRewrite = false; public: @@ -163,7 +167,7 @@ class CallExprRewriter { /// This function should be overwritten to implement call expression /// rewriting. - virtual std::optional rewrite() = 0; + virtual std::optional rewrite(ExprAnalysis *Parent) = 0; // Emits a warning/error/note and/or comment depending on MsgID. For details // see Diagnostics.inc, Diagnostics.h and Diagnostics.cpp template @@ -183,13 +187,21 @@ class CallExprRewriter { return BlockLevelFormatFlag; } + static ExprAnalysis *getParentAnalysis() { return ParentAnalysis; } + protected: + struct ParentAnalysisGuard : llvm::SaveAndRestore { + ParentAnalysisGuard(ExprAnalysis *Parent) + : llvm::SaveAndRestore(ParentAnalysis, Parent) {} + }; bool BlockLevelFormatFlag = false; std::vector getMigratedArgs(); std::string getMigratedArg(unsigned Index); std::string getMigratedArgWithExtraParens(unsigned Index); StringRef getSourceCalleeName() { return SourceCalleeName; } + + static ExprAnalysis *ParentAnalysis; }; class ConditionalRewriterFactory : public CallExprRewriterFactoryBase { @@ -339,8 +351,8 @@ class AssignableRewriter : public CallExprRewriter { requestFeature(HelperFeatureEnum::device_ext); } - std::optional rewrite() override { - std::optional &&Result = Inner->rewrite(); + std::optional rewrite(ExprAnalysis *Analysis) override { + std::optional &&Result = Inner->rewrite(Analysis); if (Result.has_value()) { if ((CheckAssigned && IsAssigned) || (CheckInRetStmt && IsInRetStmt)) { if (UseDpctCheckError) { @@ -372,8 +384,8 @@ class InsertAroundRewriter : public CallExprRewriter { : CallExprRewriter(C, ""), Prefix(Prefix), Suffix(Suffix), Inner(InnerRewriter) {} - std::optional rewrite() override { - std::optional &&Result = Inner->rewrite(); + std::optional rewrite(ExprAnalysis *Analysis) override { + std::optional &&Result = Inner->rewrite(Analysis); if (Result.has_value()) return Prefix + Result.value() + Suffix; return Result; @@ -391,7 +403,7 @@ class RemoveAPIRewriter : public CallExprRewriter { : CallExprRewriter(C, CalleeName), IsAssigned(isAssigned(C)), CalleeName(CalleeName), Message(Message) {} - std::optional rewrite() override { + std::optional rewrite(ExprAnalysis *Analysis) override { std::string Msg = Message.empty() ? "this functionality is redundant in SYCL." : Message; if (IsAssigned) { @@ -424,10 +436,10 @@ class IfElseRewriter : public CallExprRewriter { Indent = getIndent(getStmtExpansionSourceRange(C).getBegin(), SM); } - std::optional rewrite() override { - std::optional &&PredStr = Pred->rewrite(); - std::optional &&IfBlockStr = IfBlock->rewrite(); - std::optional &&ElseBlockStr = ElseBlock->rewrite(); + std::optional rewrite(ExprAnalysis *Analysis) override { + std::optional &&PredStr = Pred->rewrite(Analysis); + std::optional &&IfBlockStr = IfBlock->rewrite(Analysis); + std::optional &&ElseBlockStr = ElseBlock->rewrite(Analysis); return "if(" + PredStr.value() + "){" + NL.str() + Indent.str() + Indent.str() + IfBlockStr.value() + ";" + NL.str() + Indent.str() + "} else {" + NL.str() + Indent.str() + Indent.str() + @@ -555,7 +567,7 @@ class FuncCallExprRewriter : public CallExprRewriter { public: virtual ~FuncCallExprRewriter() {} - virtual std::optional rewrite() override; + virtual std::optional rewrite(ExprAnalysis *Analysis) override; friend FuncCallExprRewriterFactory; @@ -581,7 +593,7 @@ class NoRewriteFuncNameRewriter : public CallExprRewriter { NoRewrite = true; } - std::optional rewrite() override { return NewFuncName; } + std::optional rewrite(ExprAnalysis *Analysis) override { return NewFuncName; } }; struct ThrustFunctor { @@ -1175,7 +1187,8 @@ template class DeleterCallExprRewriter : public CallExprRewriter { DeleterCallExprRewriter(const CallExpr *C, StringRef Source, std::function ArgCreator) : CallExprRewriter(C, Source), Arg(ArgCreator(C)) {} - std::optional rewrite() override { + std::optional rewrite(ExprAnalysis *Analysis) override { + ParentAnalysisGuard Guard(Analysis); std::string Result; llvm::raw_string_ostream OS(Result); OS << "delete "; @@ -1191,7 +1204,8 @@ template class ToStringExprRewriter : public CallExprRewriter { ToStringExprRewriter(const CallExpr *C, StringRef Source, std::function ArgCreator) : CallExprRewriter(C, Source), Arg(ArgCreator(C)) {} - std::optional rewrite() override { + std::optional rewrite(ExprAnalysis *Analysis) override { + ParentAnalysisGuard Guard(Analysis); std::string Result; llvm::raw_string_ostream OS(Result); print(OS, Arg); @@ -1375,7 +1389,8 @@ class PrinterRewriter : Printer, public CallExprRewriter { PrinterRewriter(const CallExpr *C, StringRef Source, const std::function &...ArgCreators) : PrinterRewriter(C, Source, ArgCreators(C)...) {} - std::optional rewrite() override { + std::optional rewrite(ExprAnalysis *Analysis) override { + ParentAnalysisGuard Guard(Analysis); std::string Result; llvm::raw_string_ostream OS(Result); Printer::print(OS); @@ -1398,7 +1413,8 @@ class PrinterRewriter> const CallExpr *C, StringRef Source, const std::function &...PrinterCreators) : PrinterRewriter(C, Source, PrinterCreators(C)...) {} - std::optional rewrite() override { + std::optional rewrite(ExprAnalysis *Analysis) override { + ParentAnalysisGuard Guard(Analysis); std::string Result; llvm::raw_string_ostream OS(Result); Base::print(OS); @@ -1479,7 +1495,8 @@ class SimpleCallExprRewriter : public CallExprRewriter { const std::function(const CallExpr *)> &PrinterFunctor) : CallExprRewriter(C, Source), Printer(PrinterFunctor(C)) {} - std::optional rewrite() override { + std::optional rewrite(ExprAnalysis *Analysis) override { + ParentAnalysisGuard Guard(Analysis); std::string Result; llvm::raw_string_ostream OS(Result); Printer.print(OS); @@ -1582,7 +1599,7 @@ class UnsupportFunctionRewriter : public CallExprRewriter { report(MsgID, false, getMsgArg(Args, CE)...); } - std::optional rewrite() override { return std::nullopt; } + std::optional rewrite(ExprAnalysis *Analysis) override { return std::nullopt; } friend UnsupportFunctionRewriterFactory; }; @@ -1609,7 +1626,7 @@ class UserDefinedRewriter : public CallExprRewriter { buildRewriterStr(Call, OS, OB); OS.flush(); } - std::optional rewrite() override { + std::optional rewrite(ExprAnalysis *Analysis) override { return ResultStr; } @@ -1701,7 +1718,7 @@ class UserDefinedRewriter : public CallExprRewriter { struct NullRewriter : public CallExprRewriter { NullRewriter(const CallExpr *C, StringRef Name) : CallExprRewriter(C, Name) {} - std::optional rewrite() override { return std::nullopt; } + std::optional rewrite(ExprAnalysis *Analysis) override { return std::nullopt; } }; struct NullRewriterFactory : public CallExprRewriterFactoryBase { diff --git a/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp b/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp index 95006034c8d4..4220737f2eb6 100644 --- a/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp +++ b/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp @@ -339,7 +339,7 @@ std::pair ExprAnalysis::getOffsetAndLength(const Expr *E, Source } else { // If the Expr is FileID or is macro arg // e.g. CALL(expr) - auto Range = getStmtExpansionSourceRange(E); + auto Range = getDefinitionRange(E->getBeginLoc(), E->getEndLoc()); BeginLoc = Range.getBegin(); EndLoc = Range.getEnd(); End = getOffset(EndLoc) + Lexer::MeasureTokenLength(EndLoc, SM, Context.getLangOpts()); @@ -885,7 +885,7 @@ void ExprAnalysis::analyzeExpr(const CallExpr *CE) { } auto Rewriter = Itr->second->create(CE); - auto Result = Rewriter->rewrite(); + auto Result = Rewriter->rewrite(this); BlockLevelFormatFlag = Rewriter->getBlockLevelFormatFlag(); if (Rewriter->isNoRewrite()) { @@ -1001,7 +1001,7 @@ void ExprAnalysis::analyzeExpr(const CXXMemberCallExpr *CMCE) { BaseType + "." + MethodName); if (Itr != CallExprRewriterFactoryBase::MethodRewriterMap->end()) { auto Rewriter = Itr->second->create(CMCE); - auto Result = Rewriter->rewrite(); + auto Result = Rewriter->rewrite(this); if (Result.has_value()) { auto ResultStr = Result.value(); addReplacement(CMCE, ResultStr); @@ -1271,9 +1271,57 @@ void ExprAnalysis::applyAllSubExprRepl() { for (std::shared_ptr Repl : SubExprRepl) { if (BlockLevelFormatFlag) Repl->setBlockLevelFormatFlag(); - DpctGlobalInfo::getInstance().addReplacement(Repl); } + SubExprRepl.clear(); +} + +bool needCleanUp(const Expr *E){ + return DpctGlobalInfo::findAncestor( + E, [&](const DynTypedNode &Node) { + return Node.get() || !Node.get(); + }); +} + +TextModification *removeWithCleanUp(SourceLocation Begin, unsigned Length, + const SourceManager &SM) { + Token NextToken; + if (!Lexer::getRawToken(Begin.getLocWithOffset(Length), NextToken, SM, + DpctGlobalInfo::getContext().getLangOpts(), true) && + NextToken.getKind() == tok::semi) { + Length += NextToken.getLength(); + auto BeginData = SM.getCharacterData(Begin), + EndData = BeginData + Length; + unsigned Indent = 0, Trailing = 0; + auto Ch= *--BeginData; + while (isspace(Ch) && Ch != '\n' && Ch != '\r' && Ch != '{') { + ++Indent; + Ch = *--BeginData; + } + Ch = *EndData; + while (isspace(Ch) || Ch == '\\') { + ++Trailing; + if (Ch == '\n' || Ch == '\r') { + Ch = *++EndData; + if (Ch == '\n' || Ch == '\r') + ++Trailing; + break; + } + Ch = *++EndData; + } + Begin = Begin.getLocWithOffset(-Indent); + Length += Indent + Trailing; + } + return new ReplaceText(Begin, Length, ""); +} + +TextModification *ExprAnalysis::getReplacement() { + if (!hasReplacement()) + return nullptr; + std::string Repl = getReplacedString(); + if (E && Repl.empty() && needCleanUp(E)) + return removeWithCleanUp(SrcBeginLoc, SrcLength, SM); + return new ReplaceText(SrcBeginLoc, SrcLength, std::move(Repl)); } const std::string &ArgumentAnalysis::getDefaultArgument(const Expr *E) { @@ -2168,6 +2216,25 @@ void KernelConfigAnalysis::analyze(const Expr *E, unsigned int Idx, ArgumentAnalysis::analyze(E); } +void ExprAnalysis::applySubExprReplToParent() { + if (auto Parent = CallExprRewriter::getParentAnalysis()) { + for (const auto &Repl : SubExprRepl) { + auto File = SM.getFileManager().getFileRef(Repl->getFilePath()); + if (!File || Parent->FileId != SM.translateFile(File.get()) || + Repl->getOffset() < Parent->SrcBegin || + Repl->getOffset() + Repl->getLength() > + Parent->SrcBegin + Parent->SrcLength) { + Parent->addExtReplacement(Repl); + } else { + Parent->addReplacement(Repl->getOffset() - Parent->SrcBegin, + Repl->getLength(), + Repl->getReplacementText().str()); + } + } + SubExprRepl.clear(); + } +} + std::string ArgumentAnalysis::getRewriteString() { // Find rewrite range auto RewriteRange = getLocInCallSpelling(getTargetExpr()); @@ -2182,16 +2249,20 @@ std::string ArgumentAnalysis::getRewriteString() { StringReplacements SRs; SRs.init(std::move(OriginalStr)); - for (std::shared_ptr SubRepl : SubExprRepl) { - if (isInRange(RewriteRangeBegin, RewriteRangeEnd, SubRepl->getFilePath(), - SubRepl->getOffset()) && - isInRange(RewriteRangeBegin, RewriteRangeEnd, SubRepl->getFilePath(), - SubRepl->getOffset() + SubRepl->getLength())) { - SRs.addStringReplacement( - SubRepl->getOffset() - SM.getDecomposedLoc(RewriteRangeBegin).second, - SubRepl->getLength(), SubRepl->getReplacementText().str()); + for (auto Iter = SubExprRepl.begin(); Iter != SubExprRepl.end();) { + auto &Repl = **Iter; + if (isInRange(RewriteRangeBegin, RewriteRangeEnd, Repl.getFilePath(), + Repl.getOffset()) && + isInRange(RewriteRangeBegin, RewriteRangeEnd, Repl.getFilePath(), + Repl.getOffset() + Repl.getLength())) { + SRs.addStringReplacement(Repl.getOffset() - DL.second, Repl.getLength(), + Repl.getReplacementText().str()); + Iter = SubExprRepl.erase(Iter); + } else { + ++Iter; } } + applySubExprReplToParent(); return SRs.getReplacedString(); } diff --git a/clang/lib/DPCT/RuleInfra/ExprAnalysis.h b/clang/lib/DPCT/RuleInfra/ExprAnalysis.h index 2576b1089a7f..7bffa9af1c92 100644 --- a/clang/lib/DPCT/RuleInfra/ExprAnalysis.h +++ b/clang/lib/DPCT/RuleInfra/ExprAnalysis.h @@ -239,18 +239,7 @@ class ExprAnalysis { // This function is not re-enterable, if caller need to check if it returns // nullptr, caller need to use temp variable to save the return value, then // check. Don't call twice for same Replacement. - inline TextModification *getReplacement() { - bool hasRepl = hasReplacement(); - std::string Repl = getReplacedString(); - if (E) { - auto Range = getDefinitionRange(E->getBeginLoc(), E->getEndLoc()); - if (!isSameLocation(Range.getBegin(), Range.getEnd())) { - return hasRepl ? new ReplaceStmt(E, true, Repl) : nullptr; - } - } - return hasRepl ? new ReplaceText(SrcBeginLoc, SrcLength, std::move(Repl)) - : nullptr; - } + TextModification *getReplacement(); inline void clearReplacement() { ReplSet.reset(); } @@ -385,6 +374,7 @@ class ExprAnalysis { } void applyAllSubExprRepl(); + void applySubExprReplToParent(); inline std::vector> &getSubExprRepl() { return SubExprRepl; }; diff --git a/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp b/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp index 94e58fc5c6df..e0325dfe3eee 100644 --- a/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp +++ b/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp @@ -12,13 +12,14 @@ namespace clang { namespace dpct { -std::optional MathFuncNameRewriter::rewrite() { +std::optional MathFuncNameRewriter::rewrite(ExprAnalysis *Analysis) { // If the function is not a target math function, do not migrate it if (!isTargetMathFunction(Call->getDirectCallee())) { // No actions needed here, just return an empty string return {}; } + ParentAnalysisGuard Guard(Analysis); reportUnsupportedRoundingMode(); RewriteArgList = getMigratedArgs(); auto NewFuncName = getNewFuncName(); @@ -158,7 +159,6 @@ std::string MathFuncNameRewriter::getNewFuncName() { std::string ArgT = Arg->IgnoreImplicit()->getType().getCanonicalType().getAsString( PrintingPolicy(LO)); - std::string ArgExpr = Arg->getStmtClassName(); auto DRE = dyn_cast(Arg->IgnoreCasts()); auto IL = dyn_cast(Arg->IgnoreCasts()); std::string ParamType = "float"; @@ -192,7 +192,6 @@ std::string MathFuncNameRewriter::getNewFuncName() { std::string ArgT = Arg->IgnoreImplicit()->getType().getCanonicalType().getAsString( PrintingPolicy(LO)); - std::string ArgExpr = Arg->getStmtClassName(); auto DRE = dyn_cast(Arg->IgnoreCasts()); auto IL = dyn_cast(Arg->IgnoreCasts()); std::string ParamType = "double"; @@ -255,7 +254,8 @@ std::string MathFuncNameRewriter::getNewFuncName() { return NewFuncName; } -std::optional MathCallExprRewriter::rewrite() { +std::optional MathCallExprRewriter::rewrite(ExprAnalysis *Analysis) { + ParentAnalysisGuard Guard(Analysis); RewriteArgList = getMigratedArgs(); setTargetCalleeName(SourceCalleeName.str()); return buildRewriteString(); @@ -268,17 +268,18 @@ void MathCallExprRewriter::reportUnsupportedRoundingMode() { } } -std::optional MathUnsupportedRewriter::rewrite() { +std::optional MathUnsupportedRewriter::rewrite(ExprAnalysis *Analysis) { report(Diagnostics::API_NOT_MIGRATED, false, MapNames::ITFName.at(SourceCalleeName.str())); - return Base::rewrite(); + return Base::rewrite(Analysis); } -std::optional MathTypeCastRewriter::rewrite() { +std::optional MathTypeCastRewriter::rewrite(ExprAnalysis *Analysis) { auto FD = Call->getDirectCallee(); if (!FD || !FD->hasAttr()) - return Base::rewrite(); + return Base::rewrite(Analysis); + ParentAnalysisGuard Guard(Analysis); using SSMap = std::map; static SSMap RoundingModeMap{{"", "automatic"}, {"rd", "rtn"}, @@ -361,7 +362,7 @@ std::optional MathTypeCastRewriter::rewrite() { return ReplStr; } -std::optional MathSimulatedRewriter::rewrite() { +std::optional MathSimulatedRewriter::rewrite(ExprAnalysis *Analysis) { std::string NamespaceStr; auto DRE = dyn_cast(Call->getCallee()->IgnoreImpCasts()); if (DRE) { @@ -377,7 +378,7 @@ std::optional MathSimulatedRewriter::rewrite() { auto FD = Call->getDirectCallee(); if (!FD) - return Base::rewrite(); + return Base::rewrite(Analysis); if (dpct::DpctGlobalInfo::isInAnalysisScope(FD->getBeginLoc())) { return {}; @@ -393,7 +394,7 @@ std::optional MathSimulatedRewriter::rewrite() { if (!FD->hasAttr() && ContextFD && !ContextFD->hasAttr() && !ContextFD->hasAttr()) - return Base::rewrite(); + return Base::rewrite(Analysis); // Do not need to report warnings for pow, funnelshift, or drcp migrations if (SourceCalleeName != "pow" && SourceCalleeName != "powf" && @@ -403,6 +404,7 @@ std::optional MathSimulatedRewriter::rewrite() { report(Diagnostics::MATH_EMULATION, false, MapNames::ITFName.at(SourceCalleeName.str()), TargetCalleeName); + ParentAnalysisGuard Guard(Analysis); const std::string FuncName = SourceCalleeName.str(); std::string ReplStr; llvm::raw_string_ostream OS(ReplStr); @@ -412,7 +414,6 @@ std::optional MathSimulatedRewriter::rewrite() { auto Arg = Call->getArg(0); std::string ArgT = Arg->IgnoreImplicit()->getType().getAsString( PrintingPolicy(LangOptions())); - std::string ArgExpr = Arg->getStmtClassName(); auto DRE = dyn_cast(Arg->IgnoreCasts()); if (ArgT == "int") { if (FuncName == "frexpf") { @@ -462,7 +463,6 @@ std::optional MathSimulatedRewriter::rewrite() { auto Arg = Call->getArg(0); std::string ArgT = Arg->IgnoreImplicit()->getType().getAsString( PrintingPolicy(LangOptions())); - std::string ArgExpr = Arg->getStmtClassName(); auto DRE = dyn_cast(Arg->IgnoreCasts()); if (ArgT == "int") { if (FuncName == "remquof") { @@ -483,7 +483,6 @@ std::optional MathSimulatedRewriter::rewrite() { auto Arg = Call->getArg(1); std::string ArgT = Arg->IgnoreImplicit()->getType().getAsString( PrintingPolicy(LangOptions())); - std::string ArgExpr = Arg->getStmtClassName(); auto DRE = dyn_cast(Arg->IgnoreCasts()); if (ArgT == "int") { if (FuncName == "remquof") { @@ -621,8 +620,9 @@ std::optional MathSimulatedRewriter::rewrite() { return ReplStr; } -std::optional MathBinaryOperatorRewriter::rewrite() { +std::optional MathBinaryOperatorRewriter::rewrite(ExprAnalysis *Analysis) { reportUnsupportedRoundingMode(); + ParentAnalysisGuard Guard(Analysis); if (SourceCalleeName == "__hneg" || SourceCalleeName == "__hneg2") { setLHS(""); setRHS(getMigratedArgWithExtraParens(0)); diff --git a/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.h b/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.h index 0558a02fbde5..04a62fa95045 100644 --- a/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.h +++ b/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.h @@ -35,7 +35,7 @@ using NoRewriteFuncNameRewriterFactory = /// Base class for rewriting math function calls class MathCallExprRewriter : public FuncCallExprRewriter { public: - virtual std::optional rewrite() override; + virtual std::optional rewrite(ExprAnalysis *) override; protected: MathCallExprRewriter(const CallExpr *Call, StringRef SourceCalleeName, @@ -53,7 +53,7 @@ class MathUnsupportedRewriter : public MathCallExprRewriter { StringRef TargetCalleeName) : Base(Call, SourceCalleeName, TargetCalleeName) {} - virtual std::optional rewrite() override; + virtual std::optional rewrite(ExprAnalysis *) override; friend MathUnsupportedRewriterFactory; }; @@ -66,7 +66,7 @@ class MathTypeCastRewriter : public MathCallExprRewriter { StringRef TargetCalleeName) : Base(Call, SourceCalleeName, TargetCalleeName) {} - virtual std::optional rewrite() override; + virtual std::optional rewrite(ExprAnalysis *) override; friend MathTypeCastRewriterFactory; }; @@ -79,7 +79,7 @@ class MathSimulatedRewriter : public MathCallExprRewriter { StringRef TargetCalleeName) : Base(Call, SourceCalleeName, TargetCalleeName) {} - virtual std::optional rewrite() override; + virtual std::optional rewrite(ExprAnalysis *) override; friend MathSimulatedRewriterFactory; }; @@ -98,7 +98,7 @@ class MathBinaryOperatorRewriter : public MathCallExprRewriter { public: virtual ~MathBinaryOperatorRewriter() {} - virtual std::optional rewrite() override; + virtual std::optional rewrite(ExprAnalysis *) override; protected: void setLHS(std::string L) { LHS = L; } @@ -122,7 +122,7 @@ class MathFuncNameRewriter : public MathCallExprRewriter { : MathCallExprRewriter(Call, SourceCalleeName, TargetCalleeName) {} public: - virtual std::optional rewrite() override; + virtual std::optional rewrite(ExprAnalysis *) override; protected: std::string getNewFuncName(); diff --git a/clang/lib/DPCT/RulesLang/RulesLangTexture.cpp b/clang/lib/DPCT/RulesLang/RulesLangTexture.cpp index 0aec74a8fa5c..efa274e68c1c 100644 --- a/clang/lib/DPCT/RulesLang/RulesLangTexture.cpp +++ b/clang/lib/DPCT/RulesLang/RulesLangTexture.cpp @@ -934,7 +934,7 @@ void TextureRule::runRule(const MatchFinder::MatchResult &Result) { const Expr *, RenameWithSuffix, false, StringRef>>>( CE, Name, CE->getArg(0), true, RenameWithSuffix("set", MethodName), Value)); - std::optional Result = Rewriter->rewrite(); + std::optional Result = Rewriter->rewrite(nullptr); if (Result.has_value()) emplaceTransformation( new ReplaceStmt(CE, true, std::move(Result).value()));