Skip to content

[SYCLomatic] Refine CallExprRewriter and ExprAnalysis to cover more macro cases #2628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: SYCLomatic
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ DerefExpr::DerefExpr(const Expr *E, const CallExpr *C) {
}

std::string CallExprRewriter::getMigratedArg(unsigned Idx) {
Analyzer.setCallSpelling(Call);
Analyzer.analyze(Call->getArg(Idx));
return Analyzer.getRewritePrefix() + Analyzer.getRewriteString() +
Analyzer.getRewritePostfix();
Expand All @@ -96,13 +95,14 @@ std::string CallExprRewriter::getMigratedArgWithExtraParens(unsigned Idx) {

std::vector<std::string> CallExprRewriter::getMigratedArgs() {
std::vector<std::string> ArgList;
Analyzer.setCallSpelling(Call);
for (unsigned i = 0; i < Call->getNumArgs(); ++i)
ArgList.emplace_back(getMigratedArg(i));
return ArgList;
}

std::optional<std::string> FuncCallExprRewriter::rewrite() {
std::optional<std::string>
FuncCallExprRewriter::rewrite(ExprAnalysis *Analysis) {
ParentAnalysisGuard Guard(Analysis);
RewriteArgList = getMigratedArgs();
return buildRewriteString();
}
Expand All @@ -128,6 +128,8 @@ std::unique_ptr<std::unordered_map<
CallExprRewriterFactoryBase::MethodRewriterMap = std::make_unique<std::unordered_map<
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();

ExprAnalysis *CallExprRewriter::ParentAnalysis = nullptr;

void CallExprRewriterFactoryBase::initRewriterMap() {
if (DpctGlobalInfo::useSYCLCompat()) {
initRewriterMapSYCLcompat(*RewriterMap);
Expand Down
59 changes: 38 additions & 21 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include "Diagnostics/Diagnostics.h"
#include "RuleInfra/ExprAnalysis.h"

#include "llvm/Support/SaveAndRestore.h"

namespace clang {
namespace dpct {

Expand Down Expand Up @@ -154,7 +156,9 @@ class CallExprRewriter {
// factories. As a result, the access modifiers of the constructors are
// supposed to be protected instead of public.
CallExprRewriter(const CallExpr *Call, StringRef SourceCalleeName)
: Call(Call), SourceCalleeName(SourceCalleeName) {}
: Call(Call), SourceCalleeName(SourceCalleeName) {
Analyzer.setCallSpelling(Call);
}
bool NoRewrite = false;

public:
Expand All @@ -163,7 +167,7 @@ class CallExprRewriter {

/// This function should be overwritten to implement call expression
/// rewriting.
virtual std::optional<std::string> rewrite() = 0;
virtual std::optional<std::string> rewrite(ExprAnalysis *Parent) = 0;
// Emits a warning/error/note and/or comment depending on MsgID. For details
// see Diagnostics.inc, Diagnostics.h and Diagnostics.cpp
template <typename IDTy, typename... Ts>
Expand All @@ -183,13 +187,21 @@ class CallExprRewriter {
return BlockLevelFormatFlag;
}

static ExprAnalysis *getParentAnalysis() { return ParentAnalysis; }

protected:
struct ParentAnalysisGuard : llvm::SaveAndRestore<ExprAnalysis *> {
ParentAnalysisGuard(ExprAnalysis *Parent)
: llvm::SaveAndRestore<ExprAnalysis *>(ParentAnalysis, Parent) {}
};
bool BlockLevelFormatFlag = false;
std::vector<std::string> getMigratedArgs();
std::string getMigratedArg(unsigned Index);
std::string getMigratedArgWithExtraParens(unsigned Index);

StringRef getSourceCalleeName() { return SourceCalleeName; }

static ExprAnalysis *ParentAnalysis;
};

class ConditionalRewriterFactory : public CallExprRewriterFactoryBase {
Expand Down Expand Up @@ -339,8 +351,8 @@ class AssignableRewriter : public CallExprRewriter {
requestFeature(HelperFeatureEnum::device_ext);
}

std::optional<std::string> rewrite() override {
std::optional<std::string> &&Result = Inner->rewrite();
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
std::optional<std::string> &&Result = Inner->rewrite(Analysis);
if (Result.has_value()) {
if ((CheckAssigned && IsAssigned) || (CheckInRetStmt && IsInRetStmt)) {
if (UseDpctCheckError) {
Expand Down Expand Up @@ -372,8 +384,8 @@ class InsertAroundRewriter : public CallExprRewriter {
: CallExprRewriter(C, ""), Prefix(Prefix), Suffix(Suffix),
Inner(InnerRewriter) {}

std::optional<std::string> rewrite() override {
std::optional<std::string> &&Result = Inner->rewrite();
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
std::optional<std::string> &&Result = Inner->rewrite(Analysis);
if (Result.has_value())
return Prefix + Result.value() + Suffix;
return Result;
Expand All @@ -391,7 +403,7 @@ class RemoveAPIRewriter : public CallExprRewriter {
: CallExprRewriter(C, CalleeName), IsAssigned(isAssigned(C)),
CalleeName(CalleeName), Message(Message) {}

std::optional<std::string> rewrite() override {
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
std::string Msg =
Message.empty() ? "this functionality is redundant in SYCL." : Message;
if (IsAssigned) {
Expand Down Expand Up @@ -424,10 +436,10 @@ class IfElseRewriter : public CallExprRewriter {
Indent = getIndent(getStmtExpansionSourceRange(C).getBegin(), SM);
}

std::optional<std::string> rewrite() override {
std::optional<std::string> &&PredStr = Pred->rewrite();
std::optional<std::string> &&IfBlockStr = IfBlock->rewrite();
std::optional<std::string> &&ElseBlockStr = ElseBlock->rewrite();
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
std::optional<std::string> &&PredStr = Pred->rewrite(Analysis);
std::optional<std::string> &&IfBlockStr = IfBlock->rewrite(Analysis);
std::optional<std::string> &&ElseBlockStr = ElseBlock->rewrite(Analysis);
return "if(" + PredStr.value() + "){" + NL.str() + Indent.str() +
Indent.str() + IfBlockStr.value() + ";" + NL.str() +
Indent.str() + "} else {" + NL.str() + Indent.str() + Indent.str() +
Expand Down Expand Up @@ -555,7 +567,7 @@ class FuncCallExprRewriter : public CallExprRewriter {
public:
virtual ~FuncCallExprRewriter() {}

virtual std::optional<std::string> rewrite() override;
virtual std::optional<std::string> rewrite(ExprAnalysis *Analysis) override;

friend FuncCallExprRewriterFactory;

Expand All @@ -581,7 +593,7 @@ class NoRewriteFuncNameRewriter : public CallExprRewriter {
NoRewrite = true;
}

std::optional<std::string> rewrite() override { return NewFuncName; }
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override { return NewFuncName; }
};

struct ThrustFunctor {
Expand Down Expand Up @@ -1175,7 +1187,8 @@ template <class ArgT> class DeleterCallExprRewriter : public CallExprRewriter {
DeleterCallExprRewriter(const CallExpr *C, StringRef Source,
std::function<ArgT(const CallExpr *)> ArgCreator)
: CallExprRewriter(C, Source), Arg(ArgCreator(C)) {}
std::optional<std::string> rewrite() override {
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
ParentAnalysisGuard Guard(Analysis);
std::string Result;
llvm::raw_string_ostream OS(Result);
OS << "delete ";
Expand All @@ -1191,7 +1204,8 @@ template <class ArgT> class ToStringExprRewriter : public CallExprRewriter {
ToStringExprRewriter(const CallExpr *C, StringRef Source,
std::function<ArgT(const CallExpr *)> ArgCreator)
: CallExprRewriter(C, Source), Arg(ArgCreator(C)) {}
std::optional<std::string> rewrite() override {
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
ParentAnalysisGuard Guard(Analysis);
std::string Result;
llvm::raw_string_ostream OS(Result);
print(OS, Arg);
Expand Down Expand Up @@ -1375,7 +1389,8 @@ class PrinterRewriter : Printer, public CallExprRewriter {
PrinterRewriter(const CallExpr *C, StringRef Source,
const std::function<ArgsT(const CallExpr *)> &...ArgCreators)
: PrinterRewriter(C, Source, ArgCreators(C)...) {}
std::optional<std::string> rewrite() override {
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
ParentAnalysisGuard Guard(Analysis);
std::string Result;
llvm::raw_string_ostream OS(Result);
Printer::print(OS);
Expand All @@ -1398,7 +1413,8 @@ class PrinterRewriter<MultiStmtsPrinter<StmtPrinters...>>
const CallExpr *C, StringRef Source,
const std::function<StmtPrinters(const CallExpr *)> &...PrinterCreators)
: PrinterRewriter(C, Source, PrinterCreators(C)...) {}
std::optional<std::string> rewrite() override {
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
ParentAnalysisGuard Guard(Analysis);
std::string Result;
llvm::raw_string_ostream OS(Result);
Base::print(OS);
Expand Down Expand Up @@ -1479,7 +1495,8 @@ class SimpleCallExprRewriter : public CallExprRewriter {
const std::function<CallExprPrinter<CalleeT, ArgsT...>(const CallExpr *)>
&PrinterFunctor)
: CallExprRewriter(C, Source), Printer(PrinterFunctor(C)) {}
std::optional<std::string> rewrite() override {
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
ParentAnalysisGuard Guard(Analysis);
std::string Result;
llvm::raw_string_ostream OS(Result);
Printer.print(OS);
Expand Down Expand Up @@ -1582,7 +1599,7 @@ class UnsupportFunctionRewriter : public CallExprRewriter {
report(MsgID, false, getMsgArg(Args, CE)...);
}

std::optional<std::string> rewrite() override { return std::nullopt; }
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override { return std::nullopt; }

friend UnsupportFunctionRewriterFactory<MsgArgs...>;
};
Expand All @@ -1609,7 +1626,7 @@ class UserDefinedRewriter : public CallExprRewriter {
buildRewriterStr(Call, OS, OB);
OS.flush();
}
std::optional<std::string> rewrite() override {
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override {
return ResultStr;
}

Expand Down Expand Up @@ -1701,7 +1718,7 @@ class UserDefinedRewriter : public CallExprRewriter {
struct NullRewriter : public CallExprRewriter {
NullRewriter(const CallExpr *C, StringRef Name) : CallExprRewriter(C, Name) {}

std::optional<std::string> rewrite() override { return std::nullopt; }
std::optional<std::string> rewrite(ExprAnalysis *Analysis) override { return std::nullopt; }
};

struct NullRewriterFactory : public CallExprRewriterFactoryBase {
Expand Down
95 changes: 83 additions & 12 deletions clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ std::pair<size_t, size_t> ExprAnalysis::getOffsetAndLength(const Expr *E, Source
} else {
// If the Expr is FileID or is macro arg
// e.g. CALL(expr)
auto Range = getStmtExpansionSourceRange(E);
auto Range = getDefinitionRange(E->getBeginLoc(), E->getEndLoc());
BeginLoc = Range.getBegin();
EndLoc = Range.getEnd();
End = getOffset(EndLoc) + Lexer::MeasureTokenLength(EndLoc, SM, Context.getLangOpts());
Expand Down Expand Up @@ -885,7 +885,7 @@ void ExprAnalysis::analyzeExpr(const CallExpr *CE) {
}

auto Rewriter = Itr->second->create(CE);
auto Result = Rewriter->rewrite();
auto Result = Rewriter->rewrite(this);
BlockLevelFormatFlag = Rewriter->getBlockLevelFormatFlag();

if (Rewriter->isNoRewrite()) {
Expand Down Expand Up @@ -1001,7 +1001,7 @@ void ExprAnalysis::analyzeExpr(const CXXMemberCallExpr *CMCE) {
BaseType + "." + MethodName);
if (Itr != CallExprRewriterFactoryBase::MethodRewriterMap->end()) {
auto Rewriter = Itr->second->create(CMCE);
auto Result = Rewriter->rewrite();
auto Result = Rewriter->rewrite(this);
if (Result.has_value()) {
auto ResultStr = Result.value();
addReplacement(CMCE, ResultStr);
Expand Down Expand Up @@ -1271,9 +1271,57 @@ void ExprAnalysis::applyAllSubExprRepl() {
for (std::shared_ptr<ExtReplacement> Repl : SubExprRepl) {
if (BlockLevelFormatFlag)
Repl->setBlockLevelFormatFlag();

DpctGlobalInfo::getInstance().addReplacement(Repl);
}
SubExprRepl.clear();
}

bool needCleanUp(const Expr *E){
return DpctGlobalInfo::findAncestor<CompoundStmt>(
E, [&](const DynTypedNode &Node) {
return Node.get<CompoundStmt>() || !Node.get<ExprWithCleanups>();
});
}

TextModification *removeWithCleanUp(SourceLocation Begin, unsigned Length,
const SourceManager &SM) {
Token NextToken;
if (!Lexer::getRawToken(Begin.getLocWithOffset(Length), NextToken, SM,
DpctGlobalInfo::getContext().getLangOpts(), true) &&
NextToken.getKind() == tok::semi) {
Length += NextToken.getLength();
auto BeginData = SM.getCharacterData(Begin),
EndData = BeginData + Length;
unsigned Indent = 0, Trailing = 0;
auto Ch= *--BeginData;
while (isspace(Ch) && Ch != '\n' && Ch != '\r' && Ch != '{') {
++Indent;
Ch = *--BeginData;
}
Ch = *EndData;
while (isspace(Ch) || Ch == '\\') {
++Trailing;
if (Ch == '\n' || Ch == '\r') {
Ch = *++EndData;
if (Ch == '\n' || Ch == '\r')
++Trailing;
break;
}
Ch = *++EndData;
}
Begin = Begin.getLocWithOffset(-Indent);
Length += Indent + Trailing;
}
return new ReplaceText(Begin, Length, "");
}

TextModification *ExprAnalysis::getReplacement() {
if (!hasReplacement())
return nullptr;
std::string Repl = getReplacedString();
if (E && Repl.empty() && needCleanUp(E))
return removeWithCleanUp(SrcBeginLoc, SrcLength, SM);
return new ReplaceText(SrcBeginLoc, SrcLength, std::move(Repl));
}

const std::string &ArgumentAnalysis::getDefaultArgument(const Expr *E) {
Expand Down Expand Up @@ -2168,6 +2216,25 @@ void KernelConfigAnalysis::analyze(const Expr *E, unsigned int Idx,
ArgumentAnalysis::analyze(E);
}

void ExprAnalysis::applySubExprReplToParent() {
if (auto Parent = CallExprRewriter::getParentAnalysis()) {
for (const auto &Repl : SubExprRepl) {
auto File = SM.getFileManager().getFileRef(Repl->getFilePath());
if (!File || Parent->FileId != SM.translateFile(File.get()) ||
Repl->getOffset() < Parent->SrcBegin ||
Repl->getOffset() + Repl->getLength() >
Parent->SrcBegin + Parent->SrcLength) {
Parent->addExtReplacement(Repl);
} else {
Parent->addReplacement(Repl->getOffset() - Parent->SrcBegin,
Repl->getLength(),
Repl->getReplacementText().str());
}
}
SubExprRepl.clear();
}
}

std::string ArgumentAnalysis::getRewriteString() {
// Find rewrite range
auto RewriteRange = getLocInCallSpelling(getTargetExpr());
Expand All @@ -2182,16 +2249,20 @@ std::string ArgumentAnalysis::getRewriteString() {

StringReplacements SRs;
SRs.init(std::move(OriginalStr));
for (std::shared_ptr<ExtReplacement> SubRepl : SubExprRepl) {
if (isInRange(RewriteRangeBegin, RewriteRangeEnd, SubRepl->getFilePath(),
SubRepl->getOffset()) &&
isInRange(RewriteRangeBegin, RewriteRangeEnd, SubRepl->getFilePath(),
SubRepl->getOffset() + SubRepl->getLength())) {
SRs.addStringReplacement(
SubRepl->getOffset() - SM.getDecomposedLoc(RewriteRangeBegin).second,
SubRepl->getLength(), SubRepl->getReplacementText().str());
for (auto Iter = SubExprRepl.begin(); Iter != SubExprRepl.end();) {
auto &Repl = **Iter;
if (isInRange(RewriteRangeBegin, RewriteRangeEnd, Repl.getFilePath(),
Repl.getOffset()) &&
isInRange(RewriteRangeBegin, RewriteRangeEnd, Repl.getFilePath(),
Repl.getOffset() + Repl.getLength())) {
SRs.addStringReplacement(Repl.getOffset() - DL.second, Repl.getLength(),
Repl.getReplacementText().str());
Iter = SubExprRepl.erase(Iter);
} else {
++Iter;
}
}
applySubExprReplToParent();
return SRs.getReplacedString();
}

Expand Down
14 changes: 2 additions & 12 deletions clang/lib/DPCT/RuleInfra/ExprAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,7 @@ class ExprAnalysis {
// This function is not re-enterable, if caller need to check if it returns
// nullptr, caller need to use temp variable to save the return value, then
// check. Don't call twice for same Replacement.
inline TextModification *getReplacement() {
bool hasRepl = hasReplacement();
std::string Repl = getReplacedString();
if (E) {
auto Range = getDefinitionRange(E->getBeginLoc(), E->getEndLoc());
if (!isSameLocation(Range.getBegin(), Range.getEnd())) {
return hasRepl ? new ReplaceStmt(E, true, Repl) : nullptr;
}
}
return hasRepl ? new ReplaceText(SrcBeginLoc, SrcLength, std::move(Repl))
: nullptr;
}
TextModification *getReplacement();

inline void clearReplacement() { ReplSet.reset(); }

Expand Down Expand Up @@ -385,6 +374,7 @@ class ExprAnalysis {
}

void applyAllSubExprRepl();
void applySubExprReplToParent();
inline std::vector<std::shared_ptr<ExtReplacement>> &getSubExprRepl() {
return SubExprRepl;
};
Expand Down
Loading
Loading