|
15 | 15 | #include "clang/AST/QualTypeNames.h"
|
16 | 16 | #include "clang/AST/RecordLayout.h"
|
17 | 17 | #include "clang/AST/RecursiveASTVisitor.h"
|
18 |
| -#include "clang/AST/TemplateArgumentVisitor.h" |
19 |
| -#include "clang/AST/Mangle.h" |
20 | 18 | #include "clang/AST/SYCLKernelInfo.h"
|
21 | 19 | #include "clang/AST/StmtSYCL.h"
|
| 20 | +#include "clang/AST/TemplateArgumentVisitor.h" |
22 | 21 | #include "clang/AST/TypeOrdering.h"
|
23 | 22 | #include "clang/AST/TypeVisitor.h"
|
24 | 23 | #include "clang/Analysis/CallGraph.h"
|
|
27 | 26 | #include "clang/Basic/Diagnostic.h"
|
28 | 27 | #include "clang/Basic/TargetInfo.h"
|
29 | 28 | #include "clang/Basic/Version.h"
|
30 |
| -#include "clang/AST/SYCLKernelInfo.h" |
31 | 29 | #include "clang/Sema/Attr.h"
|
32 | 30 | #include "clang/Sema/Initialization.h"
|
33 | 31 | #include "clang/Sema/ParsedAttr.h"
|
@@ -6426,6 +6424,120 @@ static void EmitPragmaDiagnosticPop(raw_ostream &O) {
|
6426 | 6424 | O << "\n";
|
6427 | 6425 | }
|
6428 | 6426 |
|
| 6427 | +template <typename BeforeFn, typename AfterFn> |
| 6428 | +static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS, |
| 6429 | + const DeclContext *DC) { |
| 6430 | + if (DC->isTranslationUnit()) |
| 6431 | + return; |
| 6432 | + |
| 6433 | + const auto *CurDecl = cast<Decl>(DC); |
| 6434 | + // Ensure we are in the canonical version, so that we know we have the 'full' |
| 6435 | + // name of the thing. |
| 6436 | + CurDecl = CurDecl->getCanonicalDecl(); |
| 6437 | + |
| 6438 | + // We are intentionally skipping linkage decls and record decls. Namespaces |
| 6439 | + // can appear in a linkage decl, but not a record decl, so we don't have to |
| 6440 | + // worry about the names getting messed up from that. We handle record decls |
| 6441 | + // later when printing the name of the thing. |
| 6442 | + const auto *NS = dyn_cast<NamespaceDecl>(CurDecl); |
| 6443 | + if (NS) |
| 6444 | + Before(OS, NS); |
| 6445 | + |
| 6446 | + if (const DeclContext *NewDC = CurDecl->getDeclContext()) |
| 6447 | + PrintNSHelper(Before, After, OS, NewDC); |
| 6448 | + |
| 6449 | + if (NS) |
| 6450 | + After(OS, NS); |
| 6451 | +} |
| 6452 | + |
| 6453 | +static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC, |
| 6454 | + bool isPrintNamesOnly = false) { |
| 6455 | + PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {}, |
| 6456 | + [isPrintNamesOnly](raw_ostream &OS, const NamespaceDecl *NS) { |
| 6457 | + if (!isPrintNamesOnly) { |
| 6458 | + if (NS->isInline()) |
| 6459 | + OS << "inline "; |
| 6460 | + OS << "namespace "; |
| 6461 | + } |
| 6462 | + if (!NS->isAnonymousNamespace()) { |
| 6463 | + OS << NS->getName(); |
| 6464 | + if (isPrintNamesOnly) |
| 6465 | + OS << "::"; |
| 6466 | + else |
| 6467 | + OS << " "; |
| 6468 | + } |
| 6469 | + if (!isPrintNamesOnly) { |
| 6470 | + OS << "{\n"; |
| 6471 | + } |
| 6472 | + }, |
| 6473 | + OS, DC); |
| 6474 | +} |
| 6475 | + |
| 6476 | +static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { |
| 6477 | + PrintNSHelper( |
| 6478 | + [](raw_ostream &OS, const NamespaceDecl *NS) { |
| 6479 | + OS << "} // "; |
| 6480 | + if (NS->isInline()) |
| 6481 | + OS << "inline "; |
| 6482 | + |
| 6483 | + OS << "namespace "; |
| 6484 | + if (!NS->isAnonymousNamespace()) |
| 6485 | + OS << NS->getName(); |
| 6486 | + |
| 6487 | + OS << '\n'; |
| 6488 | + }, |
| 6489 | + [](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC); |
| 6490 | +} |
| 6491 | + |
| 6492 | +class FreeFunctionPrinter { |
| 6493 | + raw_ostream &O; |
| 6494 | + const PrintingPolicy &Policy; |
| 6495 | + bool NSInserted = false; |
| 6496 | + |
| 6497 | +public: |
| 6498 | + FreeFunctionPrinter(raw_ostream &O, const PrintingPolicy &Policy) |
| 6499 | + : O(O), Policy(Policy) {} |
| 6500 | + |
| 6501 | + /// Emits the function declaration of a free function. |
| 6502 | + /// \param FD The function declaration to print. |
| 6503 | + /// \param Args The arguments of the function. |
| 6504 | + void printFreeFunctionDeclaration(const FunctionDecl *FD, |
| 6505 | + const std::string &Args) { |
| 6506 | + const DeclContext *DC = FD->getDeclContext(); |
| 6507 | + if (DC) { |
| 6508 | + // if function in namespace, print namespace |
| 6509 | + if (isa<NamespaceDecl>(DC)) { |
| 6510 | + PrintNamespaces(O, FD); |
| 6511 | + // Set flag to print closing braces for namespaces and namespace in shim |
| 6512 | + // function |
| 6513 | + NSInserted = true; |
| 6514 | + } |
| 6515 | + O << FD->getReturnType().getAsString() << " "; |
| 6516 | + O << FD->getNameAsString() << "(" << Args << ");"; |
| 6517 | + if (NSInserted) { |
| 6518 | + O << "\n"; |
| 6519 | + PrintNSClosingBraces(O, FD); |
| 6520 | + } |
| 6521 | + O << "\n"; |
| 6522 | + } |
| 6523 | + } |
| 6524 | + |
| 6525 | + /// Emits free function shim function. |
| 6526 | + /// \param FD The function declaration to print. |
| 6527 | + /// \param ShimCounter The counter for the shim function. |
| 6528 | + /// \param ParmList The parameter list of the function. |
| 6529 | + void printFreeFunctionShim(const FunctionDecl *FD, const unsigned ShimCounter, |
| 6530 | + const std::string &ParmList) { |
| 6531 | + // Generate a shim function that returns the address of the free function. |
| 6532 | + O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n"; |
| 6533 | + O << " return (void (*)(" << ParmList << "))"; |
| 6534 | + |
| 6535 | + if (NSInserted) |
| 6536 | + PrintNamespaces(O, FD, /*isPrintNamesOnly=*/true); |
| 6537 | + O << FD->getIdentifier()->getName().data(); |
| 6538 | + } |
| 6539 | +}; |
| 6540 | + |
6429 | 6541 | void SYCLIntegrationHeader::emit(raw_ostream &O) {
|
6430 | 6542 | O << "// This is auto-generated SYCL integration header.\n";
|
6431 | 6543 | O << "\n";
|
@@ -6714,16 +6826,25 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
|
6714 | 6826 | if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage)
|
6715 | 6827 | O << "extern \"C\" ";
|
6716 | 6828 | std::string ParmList;
|
| 6829 | + std::string ParmListWithNames; |
6717 | 6830 | bool FirstParam = true;
|
6718 | 6831 | Policy.SuppressDefaultTemplateArgs = false;
|
6719 | 6832 | Policy.PrintCanonicalTypes = true;
|
| 6833 | + llvm::raw_string_ostream ParmListWithNamesOstream{ParmListWithNames}; |
6720 | 6834 | for (ParmVarDecl *Param : K.SyclKernel->parameters()) {
|
6721 | 6835 | if (FirstParam)
|
6722 | 6836 | FirstParam = false;
|
6723 |
| - else |
| 6837 | + else { |
6724 | 6838 | ParmList += ", ";
|
| 6839 | + ParmListWithNamesOstream << ", "; |
| 6840 | + } |
| 6841 | + Policy.SuppressTagKeyword = true; |
| 6842 | + Param->getType().print(ParmListWithNamesOstream, Policy); |
| 6843 | + Policy.SuppressTagKeyword = false; |
| 6844 | + ParmListWithNamesOstream << " " << Param->getNameAsString(); |
6725 | 6845 | ParmList += Param->getType().getCanonicalType().getAsString(Policy);
|
6726 | 6846 | }
|
| 6847 | + ParmListWithNamesOstream.flush(); |
6727 | 6848 | FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
|
6728 | 6849 | Policy.PrintCanonicalTypes = false;
|
6729 | 6850 | Policy.SuppressDefinition = true;
|
@@ -6757,17 +6878,15 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
|
6757 | 6878 | // template arguments that match default template arguments while printing
|
6758 | 6879 | // template-ids, even if the source code doesn't reference them.
|
6759 | 6880 | Policy.EnforceDefaultTemplateArgs = true;
|
| 6881 | + FreeFunctionPrinter FFPrinter(O, Policy); |
6760 | 6882 | if (FTD) {
|
6761 | 6883 | FTD->print(O, Policy);
|
| 6884 | + O << ";\n"; |
6762 | 6885 | } else {
|
6763 |
| - K.SyclKernel->print(O, Policy); |
| 6886 | + FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames); |
6764 | 6887 | }
|
6765 |
| - O << ";\n"; |
6766 | 6888 |
|
6767 |
| - // Generate a shim function that returns the address of the free function. |
6768 |
| - O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n"; |
6769 |
| - O << " return (void (*)(" << ParmList << "))" |
6770 |
| - << K.SyclKernel->getIdentifier()->getName().data(); |
| 6889 | + FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList); |
6771 | 6890 | if (FTD) {
|
6772 | 6891 | const TemplateArgumentList *TAL =
|
6773 | 6892 | K.SyclKernel->getTemplateSpecializationArgs();
|
@@ -6936,61 +7055,6 @@ bool SYCLIntegrationFooter::emit(StringRef IntHeaderName) {
|
6936 | 7055 | return emit(Out);
|
6937 | 7056 | }
|
6938 | 7057 |
|
6939 |
| -template <typename BeforeFn, typename AfterFn> |
6940 |
| -static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS, |
6941 |
| - const DeclContext *DC) { |
6942 |
| - if (DC->isTranslationUnit()) |
6943 |
| - return; |
6944 |
| - |
6945 |
| - const auto *CurDecl = cast<Decl>(DC); |
6946 |
| - // Ensure we are in the canonical version, so that we know we have the 'full' |
6947 |
| - // name of the thing. |
6948 |
| - CurDecl = CurDecl->getCanonicalDecl(); |
6949 |
| - |
6950 |
| - // We are intentionally skipping linkage decls and record decls. Namespaces |
6951 |
| - // can appear in a linkage decl, but not a record decl, so we don't have to |
6952 |
| - // worry about the names getting messed up from that. We handle record decls |
6953 |
| - // later when printing the name of the thing. |
6954 |
| - const auto *NS = dyn_cast<NamespaceDecl>(CurDecl); |
6955 |
| - if (NS) |
6956 |
| - Before(OS, NS); |
6957 |
| - |
6958 |
| - if (const DeclContext *NewDC = CurDecl->getDeclContext()) |
6959 |
| - PrintNSHelper(Before, After, OS, NewDC); |
6960 |
| - |
6961 |
| - if (NS) |
6962 |
| - After(OS, NS); |
6963 |
| -} |
6964 |
| - |
6965 |
| -static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC) { |
6966 |
| - PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {}, |
6967 |
| - [](raw_ostream &OS, const NamespaceDecl *NS) { |
6968 |
| - if (NS->isInline()) |
6969 |
| - OS << "inline "; |
6970 |
| - OS << "namespace "; |
6971 |
| - if (!NS->isAnonymousNamespace()) |
6972 |
| - OS << NS->getName() << " "; |
6973 |
| - OS << "{\n"; |
6974 |
| - }, |
6975 |
| - OS, DC); |
6976 |
| -} |
6977 |
| - |
6978 |
| -static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) { |
6979 |
| - PrintNSHelper( |
6980 |
| - [](raw_ostream &OS, const NamespaceDecl *NS) { |
6981 |
| - OS << "} // "; |
6982 |
| - if (NS->isInline()) |
6983 |
| - OS << "inline "; |
6984 |
| - |
6985 |
| - OS << "namespace "; |
6986 |
| - if (!NS->isAnonymousNamespace()) |
6987 |
| - OS << NS->getName(); |
6988 |
| - |
6989 |
| - OS << '\n'; |
6990 |
| - }, |
6991 |
| - [](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC); |
6992 |
| -} |
6993 |
| - |
6994 | 7058 | static std::string EmitShim(raw_ostream &OS, unsigned &ShimCounter,
|
6995 | 7059 | const std::string &LastShim,
|
6996 | 7060 | const NamespaceDecl *AnonNS) {
|
|
0 commit comments