Skip to content

Commit 72bbe64

Browse files
SYCL free function namespace support (#17585)
SYCL free function docs: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc Changes should generate right forward declarations of any function(not template) and shim functions in namespace or not.
1 parent 49c6004 commit 72bbe64

File tree

5 files changed

+891
-82
lines changed

5 files changed

+891
-82
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 129 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
#include "clang/AST/QualTypeNames.h"
1616
#include "clang/AST/RecordLayout.h"
1717
#include "clang/AST/RecursiveASTVisitor.h"
18-
#include "clang/AST/TemplateArgumentVisitor.h"
19-
#include "clang/AST/Mangle.h"
2018
#include "clang/AST/SYCLKernelInfo.h"
2119
#include "clang/AST/StmtSYCL.h"
20+
#include "clang/AST/TemplateArgumentVisitor.h"
2221
#include "clang/AST/TypeOrdering.h"
2322
#include "clang/AST/TypeVisitor.h"
2423
#include "clang/Analysis/CallGraph.h"
@@ -27,7 +26,6 @@
2726
#include "clang/Basic/Diagnostic.h"
2827
#include "clang/Basic/TargetInfo.h"
2928
#include "clang/Basic/Version.h"
30-
#include "clang/AST/SYCLKernelInfo.h"
3129
#include "clang/Sema/Attr.h"
3230
#include "clang/Sema/Initialization.h"
3331
#include "clang/Sema/ParsedAttr.h"
@@ -6426,6 +6424,120 @@ static void EmitPragmaDiagnosticPop(raw_ostream &O) {
64266424
O << "\n";
64276425
}
64286426

6427+
template <typename BeforeFn, typename AfterFn>
6428+
static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS,
6429+
const DeclContext *DC) {
6430+
if (DC->isTranslationUnit())
6431+
return;
6432+
6433+
const auto *CurDecl = cast<Decl>(DC);
6434+
// Ensure we are in the canonical version, so that we know we have the 'full'
6435+
// name of the thing.
6436+
CurDecl = CurDecl->getCanonicalDecl();
6437+
6438+
// We are intentionally skipping linkage decls and record decls. Namespaces
6439+
// can appear in a linkage decl, but not a record decl, so we don't have to
6440+
// worry about the names getting messed up from that. We handle record decls
6441+
// later when printing the name of the thing.
6442+
const auto *NS = dyn_cast<NamespaceDecl>(CurDecl);
6443+
if (NS)
6444+
Before(OS, NS);
6445+
6446+
if (const DeclContext *NewDC = CurDecl->getDeclContext())
6447+
PrintNSHelper(Before, After, OS, NewDC);
6448+
6449+
if (NS)
6450+
After(OS, NS);
6451+
}
6452+
6453+
static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC,
6454+
bool isPrintNamesOnly = false) {
6455+
PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {},
6456+
[isPrintNamesOnly](raw_ostream &OS, const NamespaceDecl *NS) {
6457+
if (!isPrintNamesOnly) {
6458+
if (NS->isInline())
6459+
OS << "inline ";
6460+
OS << "namespace ";
6461+
}
6462+
if (!NS->isAnonymousNamespace()) {
6463+
OS << NS->getName();
6464+
if (isPrintNamesOnly)
6465+
OS << "::";
6466+
else
6467+
OS << " ";
6468+
}
6469+
if (!isPrintNamesOnly) {
6470+
OS << "{\n";
6471+
}
6472+
},
6473+
OS, DC);
6474+
}
6475+
6476+
static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
6477+
PrintNSHelper(
6478+
[](raw_ostream &OS, const NamespaceDecl *NS) {
6479+
OS << "} // ";
6480+
if (NS->isInline())
6481+
OS << "inline ";
6482+
6483+
OS << "namespace ";
6484+
if (!NS->isAnonymousNamespace())
6485+
OS << NS->getName();
6486+
6487+
OS << '\n';
6488+
},
6489+
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
6490+
}
6491+
6492+
class FreeFunctionPrinter {
6493+
raw_ostream &O;
6494+
const PrintingPolicy &Policy;
6495+
bool NSInserted = false;
6496+
6497+
public:
6498+
FreeFunctionPrinter(raw_ostream &O, const PrintingPolicy &Policy)
6499+
: O(O), Policy(Policy) {}
6500+
6501+
/// Emits the function declaration of a free function.
6502+
/// \param FD The function declaration to print.
6503+
/// \param Args The arguments of the function.
6504+
void printFreeFunctionDeclaration(const FunctionDecl *FD,
6505+
const std::string &Args) {
6506+
const DeclContext *DC = FD->getDeclContext();
6507+
if (DC) {
6508+
// if function in namespace, print namespace
6509+
if (isa<NamespaceDecl>(DC)) {
6510+
PrintNamespaces(O, FD);
6511+
// Set flag to print closing braces for namespaces and namespace in shim
6512+
// function
6513+
NSInserted = true;
6514+
}
6515+
O << FD->getReturnType().getAsString() << " ";
6516+
O << FD->getNameAsString() << "(" << Args << ");";
6517+
if (NSInserted) {
6518+
O << "\n";
6519+
PrintNSClosingBraces(O, FD);
6520+
}
6521+
O << "\n";
6522+
}
6523+
}
6524+
6525+
/// Emits free function shim function.
6526+
/// \param FD The function declaration to print.
6527+
/// \param ShimCounter The counter for the shim function.
6528+
/// \param ParmList The parameter list of the function.
6529+
void printFreeFunctionShim(const FunctionDecl *FD, const unsigned ShimCounter,
6530+
const std::string &ParmList) {
6531+
// Generate a shim function that returns the address of the free function.
6532+
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
6533+
O << " return (void (*)(" << ParmList << "))";
6534+
6535+
if (NSInserted)
6536+
PrintNamespaces(O, FD, /*isPrintNamesOnly=*/true);
6537+
O << FD->getIdentifier()->getName().data();
6538+
}
6539+
};
6540+
64296541
void SYCLIntegrationHeader::emit(raw_ostream &O) {
64306542
O << "// This is auto-generated SYCL integration header.\n";
64316543
O << "\n";
@@ -6714,16 +6826,25 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
67146826
if (K.SyclKernel->getLanguageLinkage() == CLanguageLinkage)
67156827
O << "extern \"C\" ";
67166828
std::string ParmList;
6829+
std::string ParmListWithNames;
67176830
bool FirstParam = true;
67186831
Policy.SuppressDefaultTemplateArgs = false;
67196832
Policy.PrintCanonicalTypes = true;
6833+
llvm::raw_string_ostream ParmListWithNamesOstream{ParmListWithNames};
67206834
for (ParmVarDecl *Param : K.SyclKernel->parameters()) {
67216835
if (FirstParam)
67226836
FirstParam = false;
6723-
else
6837+
else {
67246838
ParmList += ", ";
6839+
ParmListWithNamesOstream << ", ";
6840+
}
6841+
Policy.SuppressTagKeyword = true;
6842+
Param->getType().print(ParmListWithNamesOstream, Policy);
6843+
Policy.SuppressTagKeyword = false;
6844+
ParmListWithNamesOstream << " " << Param->getNameAsString();
67256845
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
67266846
}
6847+
ParmListWithNamesOstream.flush();
67276848
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
67286849
Policy.PrintCanonicalTypes = false;
67296850
Policy.SuppressDefinition = true;
@@ -6757,17 +6878,15 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
67576878
// template arguments that match default template arguments while printing
67586879
// template-ids, even if the source code doesn't reference them.
67596880
Policy.EnforceDefaultTemplateArgs = true;
6881+
FreeFunctionPrinter FFPrinter(O, Policy);
67606882
if (FTD) {
67616883
FTD->print(O, Policy);
6884+
O << ";\n";
67626885
} else {
6763-
K.SyclKernel->print(O, Policy);
6886+
FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames);
67646887
}
6765-
O << ";\n";
67666888

6767-
// Generate a shim function that returns the address of the free function.
6768-
O << "static constexpr auto __sycl_shim" << ShimCounter << "() {\n";
6769-
O << " return (void (*)(" << ParmList << "))"
6770-
<< K.SyclKernel->getIdentifier()->getName().data();
6889+
FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList);
67716890
if (FTD) {
67726891
const TemplateArgumentList *TAL =
67736892
K.SyclKernel->getTemplateSpecializationArgs();
@@ -6936,61 +7055,6 @@ bool SYCLIntegrationFooter::emit(StringRef IntHeaderName) {
69367055
return emit(Out);
69377056
}
69387057

6939-
template <typename BeforeFn, typename AfterFn>
6940-
static void PrintNSHelper(BeforeFn Before, AfterFn After, raw_ostream &OS,
6941-
const DeclContext *DC) {
6942-
if (DC->isTranslationUnit())
6943-
return;
6944-
6945-
const auto *CurDecl = cast<Decl>(DC);
6946-
// Ensure we are in the canonical version, so that we know we have the 'full'
6947-
// name of the thing.
6948-
CurDecl = CurDecl->getCanonicalDecl();
6949-
6950-
// We are intentionally skipping linkage decls and record decls. Namespaces
6951-
// can appear in a linkage decl, but not a record decl, so we don't have to
6952-
// worry about the names getting messed up from that. We handle record decls
6953-
// later when printing the name of the thing.
6954-
const auto *NS = dyn_cast<NamespaceDecl>(CurDecl);
6955-
if (NS)
6956-
Before(OS, NS);
6957-
6958-
if (const DeclContext *NewDC = CurDecl->getDeclContext())
6959-
PrintNSHelper(Before, After, OS, NewDC);
6960-
6961-
if (NS)
6962-
After(OS, NS);
6963-
}
6964-
6965-
static void PrintNamespaces(raw_ostream &OS, const DeclContext *DC) {
6966-
PrintNSHelper([](raw_ostream &OS, const NamespaceDecl *NS) {},
6967-
[](raw_ostream &OS, const NamespaceDecl *NS) {
6968-
if (NS->isInline())
6969-
OS << "inline ";
6970-
OS << "namespace ";
6971-
if (!NS->isAnonymousNamespace())
6972-
OS << NS->getName() << " ";
6973-
OS << "{\n";
6974-
},
6975-
OS, DC);
6976-
}
6977-
6978-
static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
6979-
PrintNSHelper(
6980-
[](raw_ostream &OS, const NamespaceDecl *NS) {
6981-
OS << "} // ";
6982-
if (NS->isInline())
6983-
OS << "inline ";
6984-
6985-
OS << "namespace ";
6986-
if (!NS->isAnonymousNamespace())
6987-
OS << NS->getName();
6988-
6989-
OS << '\n';
6990-
},
6991-
[](raw_ostream &OS, const NamespaceDecl *NS) {}, OS, DC);
6992-
}
6993-
69947058
static std::string EmitShim(raw_ostream &OS, unsigned &ShimCounter,
69957059
const std::string &LastShim,
69967060
const NamespaceDecl *AnonNS) {

clang/test/CodeGenSYCL/free_function_default_template_arguments.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,20 @@ foo(Arg1<int> arg) {
8686
// CHECK-NEXT: template <typename T, typename, int a, typename, typename ...TS> struct Arg;
8787
// CHECK-NEXT: }
8888

89-
// CHECK: void ns::simple(ns::Arg<char, int, 12, ns::notatuple>);
90-
// CHECK-NEXT: static constexpr auto __sycl_shim1() {
91-
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))simple;
89+
// CHECK: namespace ns {
90+
// CHECK-NEXT: void simple(ns::Arg<char, int, 12, ns::notatuple> );
91+
// CHECK-NEXT: } // namespace ns
92+
// CHECK: static constexpr auto __sycl_shim1() {
93+
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))ns::simple;
9294
// CHECK-NEXT: }
9395

9496
// CHECK: Forward declarations of kernel and its argument types:
9597
// CHECK: namespace ns {
9698
// CHECK: namespace ns1 {
9799
// CHECK-NEXT: template <typename A> class hasDefaultArg;
98-
// CHECK-NEXT: }
100+
// CHECK-NEXT: }}
99101

100-
// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple>);
102+
// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple> );
101103
// CHECK-NEXT: static constexpr auto __sycl_shim2() {
102104
// CHECK-NEXT: return (void (*)(struct ns::Arg<class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, 12, struct ns::notatuple>))simple1;
103105
// CHECK-NEXT: }

0 commit comments

Comments
 (0)