diff --git a/clang/lib/DPCT/AnalysisInfo.cpp b/clang/lib/DPCT/AnalysisInfo.cpp index fd77cacb9d14..310834529023 100644 --- a/clang/lib/DPCT/AnalysisInfo.cpp +++ b/clang/lib/DPCT/AnalysisInfo.cpp @@ -2821,11 +2821,12 @@ void CtTypeInfo::setArrayInfo(const DependentSizedArrayTypeLoc &TL, bool NeedSizeFold) { ContainSizeofType = containSizeOfType(TL.getSizeExpr()); ExprAnalysis EA; + EA.IsAnalyzingCtTypeInfo = true; EA.analyze(TL.getSizeExpr()); auto TDSI = EA.getTemplateDependentStringInfo(); if (TDSI->containsTemplateDependentMacro()) TemplateDependentMacro = true; - Range.emplace_back(EA.getTemplateDependentStringInfo()); + Range.emplace_back(TDSI); setTypeInfo(TL.getElementLoc(), NeedSizeFold); } void CtTypeInfo::setArrayInfo(const IncompleteArrayTypeLoc &TL, diff --git a/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp b/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp index acd81303760c..da94e4dcbc23 100644 --- a/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp +++ b/clang/lib/DPCT/RuleInfra/ExprAnalysis.cpp @@ -412,7 +412,27 @@ void ExprAnalysis::initSourceRange(const SourceRange &Range) { } void StringReplacements::replaceString() { - SourceStr.reserve(SourceStr.length() + ShiftLength); + for (auto &TDR : TDRs2) { + // Find items in ReplMap whose offset <= TDR.first. + // Then check length; if there is overlap, skip that item. + // Finally calculate the shift length. 
+ int Shift = 0; + auto UpperBound = ReplMap.upper_bound(TDR.first); + for (auto It = ReplMap.begin(); It != UpperBound; ++It) { + if ((It->first + It->second->getLength()) > TDR.first) { + // overlap + continue; + } + Shift += + (It->second->getReplacedText().length() - It->second->getLength()); + } + auto NewTDR = std::make_shared( + TDR.second->getSourceStr(), TDR.first, TDR.second->getLength(), + TDR.second->getTemplateIndex()); + NewTDR->shift(Shift); + TDRs.insert(std::make_pair(TDR.first + Shift, NewTDR)); + } + auto Itr = ReplMap.rbegin(); while (Itr != ReplMap.rend()) { Itr->second->replaceString(); @@ -571,7 +591,16 @@ void ExprAnalysis::analyzeExpr(const DeclRefExpr *DRE) { } if (auto TemplateDecl = dyn_cast(DRE->getDecl())) addReplacement(DRE, TemplateDecl->getIndex()); - else if (auto ECD = dyn_cast(DRE->getDecl())) { + else if (const auto *VD = dyn_cast(DRE->getDecl()); + VD && VD->isConstexpr() && IsAnalyzingCtTypeInfo) { + if (VD->getInit() && VD->getInit()->getBeginLoc().isValid() && + (VD->getInit()->getDependence() != ExprDependence::None)) { + ExprAnalysis EA(VD->getInit()); + auto TDSI = EA.getTemplateDependentStringInfo(); + auto LocInfo = getOffsetAndLength(DRE); + addReplacement(LocInfo.first, LocInfo.second, TDSI); + } + } else if (auto ECD = dyn_cast(DRE->getDecl())) { std::unordered_set targetStr = { "thread_scope_system", "thread_scope_device", "thread_scope_block", "memory_order_relaxed", "memory_order_acq_rel", "memory_order_release", diff --git a/clang/lib/DPCT/RuleInfra/ExprAnalysis.h b/clang/lib/DPCT/RuleInfra/ExprAnalysis.h index a67924a4500d..e7fa74174f21 100644 --- a/clang/lib/DPCT/RuleInfra/ExprAnalysis.h +++ b/clang/lib/DPCT/RuleInfra/ExprAnalysis.h @@ -40,6 +40,7 @@ class StringReplacement { } inline const std::string &getReplacedText() { return Text; } + inline size_t getLength() { return Length; } private: // SourceStr is the string which need replaced. 
@@ -56,14 +57,14 @@ class TemplateArgumentInfo; /// Store replacement dependent on template args class TemplateDependentReplacement { - std::string &SourceStr; + std::string SourceStr; size_t Offset; size_t Length; unsigned TemplateIndex; public: - TemplateDependentReplacement(std::string &SrcStr, size_t Offset, - size_t Length, unsigned TemplateIndex) + TemplateDependentReplacement(std::string SrcStr, size_t Offset, size_t Length, + unsigned TemplateIndex) : SourceStr(SrcStr), Offset(Offset), Length(Length), TemplateIndex(TemplateIndex) {} TemplateDependentReplacement(const TemplateDependentReplacement &rhs) @@ -71,12 +72,14 @@ class TemplateDependentReplacement { rhs.TemplateIndex) {} inline std::shared_ptr - alterSource(std::string &SrcStr) { + alterSource(std::string SrcStr) { return std::make_shared( SrcStr, Offset, Length, TemplateIndex); } + inline const std::string &getSourceStr() const { return SourceStr; } inline size_t getOffset() const { return Offset; } inline size_t getLength() const { return Length; } + inline size_t getTemplateIndex() const { return TemplateIndex; } const TemplateArgumentInfo & getTargetArgument(const std::vector &TemplateList); void replace(const std::vector &TemplateList); @@ -117,12 +120,16 @@ class TemplateDependentStringInfo { HelperFeatureSet = Set; } bool containsTemplateDependentMacro() const { return ContainsTemplateDependentMacro; } + const std::vector> & + getTDRs() const { + return TDRs; + } }; /// Store an expr source string which may need replaced and its replacements class StringReplacements { public: - StringReplacements() : ShiftLength(0) {} + StringReplacements() {} inline void init(std::string &&SrcStr) { SourceStr = std::move(SrcStr); ReplMap.clear(); @@ -130,26 +137,32 @@ class StringReplacements { inline void reset() { ReplMap.clear(); } // Add a template dependent replacement - inline void addTemplateDependentReplacement(size_t Offset, size_t Length, - unsigned TemplateIndex) { - TDRs.insert( + void 
addTemplateDependentReplacement(size_t Offset, size_t Length, + unsigned TemplateIndex) { + TDRs2.insert( std::make_pair(Offset, std::make_shared( SourceStr, Offset, Length, TemplateIndex))); } + + inline void addTemplateDependentReplacement( + size_t Offset, size_t Length, + std::shared_ptr TDSI) { + addStringReplacement(Offset, Length, TDSI->getSourceString()); + for (const auto &Item : TDSI->getTDRs()) { + std::string String = + TDSI->getSourceString().substr(Item->getOffset(), Item->getLength()); + size_t NewOffset = Offset + Item->getOffset(); + auto TDR = std::make_shared( + String, NewOffset, String.size(), Item->getTemplateIndex()); + TDRs.insert(std::make_pair(NewOffset, TDR)); + } + } + // Add a string replacement void addStringReplacement(size_t Offset, size_t Length, std::string Text) { - auto Result = ReplMap.insert(std::make_pair( + ReplMap.insert(std::make_pair( Offset, std::make_shared(SourceStr, Offset, Length, Text))); - if (Result.second) { - auto Shift = Result.first->second->getReplacedText().length() - Length; - ShiftLength += Shift; - auto TDRItr = TDRs.upper_bound(Result.first->first); - while (TDRItr != TDRs.end()) { - TDRItr->second->shift(Shift); - ++TDRItr; - } - } } // Generate replacement text info which dependent on template args. 
@@ -171,10 +184,10 @@ class StringReplacements { void replaceString(); - unsigned ShiftLength; std::string SourceStr; std::map> ReplMap; std::map> TDRs; + std::map> TDRs2; }; /// Analyze expression and generate its migrated string @@ -586,6 +599,12 @@ class ExprAnalysis { ReplSet.addTemplateDependentReplacement(Offset, Length, TemplateIndex); } + inline void + addReplacement(size_t Offset, size_t Length, + std::shared_ptr TDSI) { + ReplSet.addTemplateDependentReplacement(Offset, Length, TDSI); + } + // Analyze the expression, jump to corresponding analysis function according // to its class // Precondition: Expression != nullptr @@ -691,6 +710,9 @@ class ExprAnalysis { std::string RewritePrefix; std::string RewritePostfix; std::set HelperFeatureSet; + +public: + bool IsAnalyzingCtTypeInfo = false; }; // Analyze pointer allocated by cudaMallocManaged. diff --git a/clang/test/dpct/sharedmem_var_static.cu b/clang/test/dpct/sharedmem_var_static.cu index 6aac6e486d52..913c066e4d38 100644 --- a/clang/test/dpct/sharedmem_var_static.cu +++ b/clang/test/dpct/sharedmem_var_static.cu @@ -4,11 +4,12 @@ // RUN: dpct --format-range=none --usm-level=none -out-root %T/sharedmem_var_static %s --cuda-include-path="%cuda-path/include" --sycl-named-lambda -- -x cuda --cuda-host-only // RUN: FileCheck %s --match-full-lines --input-file %T/sharedmem_var_static/sharedmem_var_static.dp.cpp // RUN: %if build_lit %{icpx -c -fsycl -DNO_BUILD_TEST %T/sharedmem_var_static/sharedmem_var_static.dp.cpp -o %T/sharedmem_var_static/sharedmem_var_static.dp.o %} -#ifndef NO_BUILD_TEST + #include #include #define SIZE 64 +#ifndef NO_BUILD_TEST class TestObject{ public: // CHECK: static void run(int *in, int *out, int &a0) { @@ -224,3 +225,15 @@ void fooh() { fook<<<1, 1>>>(); } #endif + +constexpr int kWarpSize = 32; + +template __global__ void kerfunc() { + constexpr int kNumWarps = (2 * ThreadsPerBlock / kWarpSize * ccc); + __shared__ int smem[kNumWarps * NumWarpQ]; +} + +void foo2() { + // CHECK: 
sycl::local_accessor smem_acc_ct1(sycl::range<1>((2 * 128 / kWarpSize * 16) * 8), cgh); + kerfunc<128, 8, 16><<<32, 32>>>(); +}