Skip to content

Commit 8c62c24

Browse files
[clang] Add support for SYCL templated free functions declared in namespaces (#17936)
This PR adds support of templated free function declared both in namespace or not. Support of accessor types should be added in a separate PR because it requires effort not only on the frontend side. Docs: https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc --------- Co-authored-by: Mariya Podchishchaeva <mariya.podchishchaeva@intel.com>
1 parent 22ed675 commit 8c62c24

File tree

11 files changed

+1171
-36
lines changed

11 files changed

+1171
-36
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12670,6 +12670,8 @@ def err_registered_kernels_name_already_registered : Error<
1267012670
"free function kernel has already been registered with '%0'; cannot register with '%1'">;
1267112671
def err_not_sycl_free_function : Error<
1267212672
"attempting to register a function that is not a SYCL free function as '%0'">;
12673+
def err_free_function_variadic_args: Error<
12674+
"free function kernel cannot be a variadic function">;
1267312675

1267412676
// SYCL kernel entry point diagnostics
1267512677
def err_sycl_entry_point_invalid : Error<

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 128 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5793,6 +5793,10 @@ void SemaSYCL::MarkDevices() {
57935793

57945794
void SemaSYCL::ProcessFreeFunction(FunctionDecl *FD) {
57955795
if (isFreeFunction(FD)) {
5796+
if (FD->isVariadic()) {
5797+
Diag(FD->getLocation(), diag::err_free_function_variadic_args);
5798+
return;
5799+
}
57965800
SyclKernelDecompMarker DecompMarker(*this);
57975801
SyclKernelFieldChecker FieldChecker(*this);
57985802
SyclKernelUnionChecker UnionChecker(*this);
@@ -6491,16 +6495,36 @@ static void PrintNSClosingBraces(raw_ostream &OS, const DeclContext *DC) {
64916495

64926496
class FreeFunctionPrinter {
64936497
raw_ostream &O;
6498+
PrintingPolicy &Policy;
64946499
bool NSInserted = false;
64956500

64966501
public:
6497-
FreeFunctionPrinter(raw_ostream &O) : O(O) {}
6502+
FreeFunctionPrinter(raw_ostream &O, PrintingPolicy &PrintPolicy)
6503+
: O(O), Policy(PrintPolicy) {}
6504+
6505+
/// Emits the function declaration of template free function.
6506+
/// \param FTD The function declaration to print.
6507+
/// \param S Sema object.
6508+
void printFreeFunctionDeclaration(FunctionTemplateDecl *FTD,
6509+
clang::SemaSYCL &S) {
6510+
const FunctionDecl *TemplatedDecl = FTD->getTemplatedDecl();
6511+
if (!TemplatedDecl)
6512+
return;
6513+
const std::string TemplatedDeclParams =
6514+
getTemplatedParamList(TemplatedDecl->parameters(), Policy);
6515+
const std::string TemplateParams =
6516+
getTemplateParameters(FTD->getTemplateParameters(), S);
6517+
printFreeFunctionDeclaration(TemplatedDecl, TemplatedDeclParams,
6518+
TemplateParams);
6519+
}
64986520

64996521
/// Emits the function declaration of a free function.
65006522
/// \param FD The function declaration to print.
65016523
/// \param Args The arguments of the function.
6524+
/// \param TemplateParameters The template parameters of the function.
65026525
void printFreeFunctionDeclaration(const FunctionDecl *FD,
6503-
const std::string &Args) {
6526+
const std::string &Args,
6527+
std::string_view TemplateParameters = "") {
65046528
const DeclContext *DC = FD->getDeclContext();
65056529
if (DC) {
65066530
// if function in namespace, print namespace
@@ -6510,6 +6534,7 @@ class FreeFunctionPrinter {
65106534
// function
65116535
NSInserted = true;
65126536
}
6537+
O << TemplateParameters;
65136538
O << FD->getReturnType().getAsString() << " ";
65146539
O << FD->getNameAsString() << "(" << Args << ");";
65156540
if (NSInserted) {
@@ -6533,6 +6558,95 @@ class FreeFunctionPrinter {
65336558
if (NSInserted)
65346559
PrintNamespaces(O, FD, /*isPrintNamesOnly=*/true);
65356560
O << FD->getIdentifier()->getName().data();
6561+
if (FD->getPrimaryTemplate()) {
6562+
std::string Buffer;
6563+
llvm::raw_string_ostream StringStream(Buffer);
6564+
const TemplateArgumentList *TAL = FD->getTemplateSpecializationArgs();
6565+
ArrayRef<TemplateArgument> A = TAL->asArray();
6566+
bool FirstParam = true;
6567+
for (const auto &X : A) {
6568+
if (FirstParam)
6569+
FirstParam = false;
6570+
else if (X.getKind() == TemplateArgument::Pack) {
6571+
for (const auto &PackArg : X.pack_elements()) {
6572+
StringStream << ", ";
6573+
PackArg.print(Policy, StringStream, true);
6574+
}
6575+
continue;
6576+
} else {
6577+
StringStream << ", ";
6578+
}
6579+
6580+
X.print(Policy, StringStream, true);
6581+
}
6582+
StringStream.flush();
6583+
if (Buffer.front() != '<')
6584+
Buffer = "<" + Buffer + ">";
6585+
O << Buffer;
6586+
}
6587+
}
6588+
6589+
private:
6590+
/// Helper method to get arguments of templated function as a string
6591+
/// \param Parameters Array of parameters of the function.
6592+
/// \param Policy Printing policy.
6593+
/// returned string Example:
6594+
/// \code
6595+
/// template <typename T1, typename T2>
6596+
/// void foo(T1 a, T2 b);
6597+
/// \endcode
6598+
/// returns string "T1 a, T2 b"
6599+
std::string
6600+
getTemplatedParamList(const llvm::ArrayRef<clang::ParmVarDecl *> Parameters,
6601+
PrintingPolicy Policy) {
6602+
bool FirstParam = true;
6603+
llvm::SmallString<128> ParamList;
6604+
llvm::raw_svector_ostream ParmListOstream{ParamList};
6605+
Policy.SuppressTagKeyword = true;
6606+
for (ParmVarDecl *Param : Parameters) {
6607+
if (FirstParam)
6608+
FirstParam = false;
6609+
else
6610+
ParmListOstream << ", ";
6611+
ParmListOstream << Param->getType().getAsString(Policy);
6612+
ParmListOstream << " " << Param->getNameAsString();
6613+
}
6614+
return ParamList.str().str();
6615+
}
6616+
6617+
/// Helper method to get text representation of the template parameters.
6618+
/// Throws an error if the last parameter is a pack.
6619+
/// \param TPL The template parameter list.
6620+
/// \param S The SemaSYCL object.
6621+
/// Example:
6622+
/// \code
6623+
/// template <typename T1, class T2>
6624+
/// void foo(T1 a, T2 b);
6625+
/// \endcode
6626+
/// returns string "template <typename T1, class T2> "
6627+
std::string getTemplateParameters(const clang::TemplateParameterList *TPL,
6628+
SemaSYCL &S) {
6629+
std::string TemplateParams{"template <"};
6630+
bool FirstParam{true};
6631+
for (NamedDecl *Param : *TPL) {
6632+
if (!FirstParam)
6633+
TemplateParams += ", ";
6634+
FirstParam = false;
6635+
if (const auto *TemplateParam = dyn_cast<TemplateTypeParmDecl>(Param)) {
6636+
TemplateParams +=
6637+
TemplateParam->wasDeclaredWithTypename() ? "typename " : "class ";
6638+
if (TemplateParam->isParameterPack())
6639+
TemplateParams += "... ";
6640+
TemplateParams += TemplateParam->getNameAsString();
6641+
} else if (const auto *NonTypeParam =
6642+
dyn_cast<NonTypeTemplateParmDecl>(Param)) {
6643+
TemplateParams += NonTypeParam->getType().getAsString();
6644+
TemplateParams += " ";
6645+
TemplateParams += NonTypeParam->getNameAsString();
6646+
}
6647+
}
6648+
TemplateParams += "> ";
6649+
return TemplateParams;
65366650
}
65376651
};
65386652

@@ -6836,11 +6950,16 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
68366950
ParmList += ", ";
68376951
ParmListWithNamesOstream << ", ";
68386952
}
6839-
Policy.SuppressTagKeyword = true;
6840-
Param->getType().print(ParmListWithNamesOstream, Policy);
6841-
Policy.SuppressTagKeyword = false;
6842-
ParmListWithNamesOstream << " " << Param->getNameAsString();
6843-
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
6953+
if (Param->isParameterPack()) {
6954+
ParmListWithNamesOstream << "Args... args";
6955+
ParmList += "Args ...";
6956+
} else {
6957+
Policy.SuppressTagKeyword = true;
6958+
Param->getType().print(ParmListWithNamesOstream, Policy);
6959+
Policy.SuppressTagKeyword = false;
6960+
ParmListWithNamesOstream << " " << Param->getNameAsString();
6961+
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
6962+
}
68446963
}
68456964
ParmListWithNamesOstream.flush();
68466965
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
@@ -6876,30 +6995,14 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
68766995
// template arguments that match default template arguments while printing
68776996
// template-ids, even if the source code doesn't reference them.
68786997
Policy.EnforceDefaultTemplateArgs = true;
6879-
FreeFunctionPrinter FFPrinter(O);
6998+
FreeFunctionPrinter FFPrinter(O, Policy);
68806999
if (FTD) {
6881-
FTD->print(O, Policy);
6882-
O << ";\n";
7000+
FFPrinter.printFreeFunctionDeclaration(FTD, S);
68837001
} else {
68847002
FFPrinter.printFreeFunctionDeclaration(K.SyclKernel, ParmListWithNames);
68857003
}
68867004

68877005
FFPrinter.printFreeFunctionShim(K.SyclKernel, ShimCounter, ParmList);
6888-
if (FTD) {
6889-
const TemplateArgumentList *TAL =
6890-
K.SyclKernel->getTemplateSpecializationArgs();
6891-
ArrayRef<TemplateArgument> A = TAL->asArray();
6892-
bool FirstParam = true;
6893-
O << "<";
6894-
for (const auto &X : A) {
6895-
if (FirstParam)
6896-
FirstParam = false;
6897-
else
6898-
O << ", ";
6899-
X.print(Policy, O, true);
6900-
}
6901-
O << ">";
6902-
}
69037006
O << ";\n";
69047007
O << "}\n";
69057008
Policy.SuppressDefaultTemplateArgs = true;

0 commit comments

Comments
 (0)