Skip to content

Commit 08a2edc

Browse files
authored
[SYCL][clang] Emit default template arguments in integration header (#16005)
For free function kernels support clang forward declares the kernel itself as well as its parameter types. In case a free function kernel has a parameter that is templated and has a default template argument, all template arguments including arguments that match default arguments must be printed in kernel's forward declarations, for example ``` template <typename T, typename = int> struct Arg { T val; }; // For the kernel SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( (ext::oneapi::experimental::nd_range_kernel<1>)) void foo(Arg<int> arg) { arg.val = 42; } // Integration header must contain void foo(Arg<int, int> arg); ``` Unfortunately, even though integration header emission already has extensive support for forward declarations priting, some modifications to clang's type printing are still required, since neither of existing PrintingPolicy flags help to reach the correct result. Using `SuppressDefaultTemplateArgs = true` doesn't help without printing canonical types, printing canonical types for the case like ``` template <typename T> SYCL_EXT_ONEAPI_FUNCTION_PROPERTY( (ext::oneapi::experimental::nd_range_kernel<1>)) void foo(Arg<T> arg) { arg.val = 42; } // Printing canonical types is causing the following integration header template <typename T> void foo(Arg<type-parameter-0-0, int> arg); ``` Using `SkipCanonicalizationOfTemplateTypeParms` field of printing policy doesn't help here since at the one point where it is checked we take canonical type of `Arg`, not its parameters and it will contain template argument types in canonical type after that.
1 parent 01f7e44 commit 08a2edc

File tree

4 files changed

+185
-27
lines changed

4 files changed

+185
-27
lines changed

clang/include/clang/AST/PrettyPrinter.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,18 @@ struct PrintingPolicy {
6868
SuppressStrongLifetime(false), SuppressLifetimeQualifiers(false),
6969
SuppressTypedefs(false), SuppressFinalSpecifier(false),
7070
SuppressTemplateArgsInCXXConstructors(false),
71-
SuppressDefaultTemplateArgs(true), Bool(LO.Bool),
72-
Nullptr(LO.CPlusPlus11 || LO.C23), NullptrTypeInNamespace(LO.CPlusPlus),
73-
Restrict(LO.C99), Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
71+
SuppressDefaultTemplateArgs(true), EnforceDefaultTemplateArgs(false),
72+
Bool(LO.Bool), Nullptr(LO.CPlusPlus11 || LO.C23),
73+
NullptrTypeInNamespace(LO.CPlusPlus), Restrict(LO.C99),
74+
Alignof(LO.CPlusPlus11), UnderscoreAlignof(LO.C11),
7475
UseVoidForZeroParams(!LO.CPlusPlus),
7576
SplitTemplateClosers(!LO.CPlusPlus11), TerseOutput(false),
7677
PolishForDeclaration(false), Half(LO.Half),
7778
MSWChar(LO.MicrosoftExt && !LO.WChar), IncludeNewlines(true),
7879
MSVCFormatting(false), ConstantsAsWritten(false),
7980
SuppressImplicitBase(false), FullyQualifiedName(false),
80-
SuppressDefinition(false), SuppressDefaultTemplateArguments(false),
81-
PrintCanonicalTypes(false),
81+
EnforceScopeForElaboratedTypes(false), SuppressDefinition(false),
82+
SuppressDefaultTemplateArguments(false), PrintCanonicalTypes(false),
8283
SkipCanonicalizationOfTemplateTypeParms(false),
8384
PrintInjectedClassNameWithArguments(true), UsePreferredNames(true),
8485
AlwaysIncludeTypeForTemplateArgument(false),
@@ -241,6 +242,11 @@ struct PrintingPolicy {
241242
LLVM_PREFERRED_TYPE(bool)
242243
unsigned SuppressDefaultTemplateArgs : 1;
243244

245+
/// When true, print template arguments that match the default argument for
246+
/// the parameter, even if they're not specified in the source.
247+
LLVM_PREFERRED_TYPE(bool)
248+
unsigned EnforceDefaultTemplateArgs : 1;
249+
244250
/// Whether we can use 'bool' rather than '_Bool' (even if the language
245251
/// doesn't actually have 'bool', because, e.g., it is defined as a macro).
246252
LLVM_PREFERRED_TYPE(bool)
@@ -339,6 +345,10 @@ struct PrintingPolicy {
339345
LLVM_PREFERRED_TYPE(bool)
340346
unsigned FullyQualifiedName : 1;
341347

348+
/// Enforce fully qualified name printing for elaborated types.
349+
LLVM_PREFERRED_TYPE(bool)
350+
unsigned EnforceScopeForElaboratedTypes : 1;
351+
342352
/// When true does not print definition of a type. E.g.
343353
/// \code
344354
/// template<typename T> class C0 : public C1 {...}

clang/lib/AST/TypePrinter.cpp

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class ElaboratedTypePolicyRAII {
101101
SuppressTagKeyword = Policy.SuppressTagKeyword;
102102
SuppressScope = Policy.SuppressScope;
103103
Policy.SuppressTagKeyword = true;
104-
Policy.SuppressScope = true;
104+
Policy.SuppressScope = !Policy.EnforceScopeForElaboratedTypes;
105105
}
106106

107107
~ElaboratedTypePolicyRAII() {
@@ -1728,8 +1728,10 @@ void TypePrinter::printElaboratedBefore(const ElaboratedType *T,
17281728
Policy.SuppressScope = OldSupressScope;
17291729
return;
17301730
}
1731-
if (Qualifier && !(Policy.SuppressTypedefs &&
1732-
T->getNamedType()->getTypeClass() == Type::Typedef))
1731+
if (Qualifier &&
1732+
!(Policy.SuppressTypedefs &&
1733+
T->getNamedType()->getTypeClass() == Type::Typedef) &&
1734+
!Policy.EnforceScopeForElaboratedTypes)
17331735
Qualifier->print(OS, Policy);
17341736
}
17351737

@@ -2220,15 +2222,6 @@ static void printArgument(const TemplateArgument &A, const PrintingPolicy &PP,
22202222
A.print(PP, OS, IncludeType);
22212223
}
22222224

2223-
static void printArgument(const TemplateArgumentLoc &A,
2224-
const PrintingPolicy &PP, llvm::raw_ostream &OS,
2225-
bool IncludeType) {
2226-
const TemplateArgument::ArgKind &Kind = A.getArgument().getKind();
2227-
if (Kind == TemplateArgument::ArgKind::Type)
2228-
return A.getTypeSourceInfo()->getType().print(OS, PP);
2229-
return A.getArgument().print(PP, OS, IncludeType);
2230-
}
2231-
22322225
static bool isSubstitutedTemplateArgument(ASTContext &Ctx, TemplateArgument Arg,
22332226
TemplateArgument Pattern,
22342227
ArrayRef<TemplateArgument> Args,
@@ -2399,15 +2392,40 @@ template <typename TA>
23992392
static void
24002393
printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,
24012394
const TemplateParameterList *TPL, bool IsPack, unsigned ParmIndex) {
2402-
// Drop trailing template arguments that match default arguments.
2403-
if (TPL && Policy.SuppressDefaultTemplateArgs &&
2404-
!Policy.PrintCanonicalTypes && !Args.empty() && !IsPack &&
2395+
llvm::SmallVector<TemplateArgument, 8> ArgsToPrint;
2396+
for (const TA &A : Args)
2397+
ArgsToPrint.push_back(getArgument(A));
2398+
if (TPL && !Policy.PrintCanonicalTypes && !IsPack &&
24052399
Args.size() <= TPL->size()) {
2406-
llvm::SmallVector<TemplateArgument, 8> OrigArgs;
2407-
for (const TA &A : Args)
2408-
OrigArgs.push_back(getArgument(A));
2409-
while (!Args.empty() && getArgument(Args.back()).getIsDefaulted())
2410-
Args = Args.drop_back();
2400+
// Drop trailing template arguments that match default arguments.
2401+
if (Policy.SuppressDefaultTemplateArgs) {
2402+
while (!ArgsToPrint.empty() &&
2403+
getArgument(ArgsToPrint.back()).getIsDefaulted())
2404+
ArgsToPrint.pop_back();
2405+
} else if (Policy.EnforceDefaultTemplateArgs) {
2406+
for (unsigned I = Args.size(); I < TPL->size(); ++I) {
2407+
auto Param = TPL->getParam(I);
2408+
if (auto *TTPD = dyn_cast<TemplateTypeParmDecl>(Param)) {
2409+
// If we met a non default-argument past provided list of arguments,
2410+
// it is either a pack which must be the last arguments, or provided
2411+
// argument list was problematic. Bail out either way. Do the same
2412+
// for each kind of template argument.
2413+
if (!TTPD->hasDefaultArgument())
2414+
break;
2415+
ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument()));
2416+
} else if (auto *TTPD = dyn_cast<TemplateTemplateParmDecl>(Param)) {
2417+
if (!TTPD->hasDefaultArgument())
2418+
break;
2419+
ArgsToPrint.push_back(getArgument(TTPD->getDefaultArgument()));
2420+
} else if (auto *NTTPD = dyn_cast<NonTypeTemplateParmDecl>(Param)) {
2421+
if (!NTTPD->hasDefaultArgument())
2422+
break;
2423+
ArgsToPrint.push_back(getArgument(NTTPD->getDefaultArgument()));
2424+
} else {
2425+
llvm_unreachable("unexpected template parameter");
2426+
}
2427+
}
2428+
}
24112429
}
24122430

24132431
const char *Comma = Policy.MSVCFormatting ? "," : ", ";
@@ -2416,7 +2434,7 @@ printTo(raw_ostream &OS, ArrayRef<TA> Args, const PrintingPolicy &Policy,
24162434

24172435
bool NeedSpace = false;
24182436
bool FirstArg = true;
2419-
for (const auto &Arg : Args) {
2437+
for (const auto &Arg : ArgsToPrint) {
24202438
// Print the argument into a string.
24212439
SmallString<128> Buf;
24222440
llvm::raw_svector_ostream ArgOS(Buf);

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6509,16 +6509,46 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
65096509
O << "extern \"C\" ";
65106510
std::string ParmList;
65116511
bool FirstParam = true;
6512+
Policy.SuppressDefaultTemplateArgs = false;
65126513
for (ParmVarDecl *Param : K.SyclKernel->parameters()) {
65136514
if (FirstParam)
65146515
FirstParam = false;
65156516
else
65166517
ParmList += ", ";
6517-
ParmList += Param->getType().getCanonicalType().getAsString();
6518+
ParmList += Param->getType().getCanonicalType().getAsString(Policy);
65186519
}
65196520
FunctionTemplateDecl *FTD = K.SyclKernel->getPrimaryTemplate();
65206521
Policy.SuppressDefinition = true;
65216522
Policy.PolishForDeclaration = true;
6523+
Policy.FullyQualifiedName = true;
6524+
Policy.EnforceScopeForElaboratedTypes = true;
6525+
6526+
// Now we need to print the declaration of the kernel itself.
6527+
// Example:
6528+
// template <typename T, typename = int> struct Arg {
6529+
// T val;
6530+
// };
6531+
// For the following free function kernel:
6532+
// template <typename = T>
6533+
// SYCL_EXT_ONEAPI_FUNCTION_PROPERTY(
6534+
// (ext::oneapi::experimental::nd_range_kernel<1>))
6535+
// void foo(Arg<int> arg) {}
6536+
// Integration header must contain the following declaration:
6537+
// template <typename>
6538+
// void foo(Arg<int, int> arg);
6539+
// SuppressDefaultTemplateArguments is a downstream addition that suppresses
6540+
// default template arguments in the function declaration. It should be set
6541+
// to true to emit function declaration that won't cause any compilation
6542+
// errors when present in the integration header.
6543+
// To print Arg<int, int> in the function declaration and shim functions we
6544+
// need to disable default arguments printing suppression via community flag
6545+
// SuppressDefaultTemplateArgs, otherwise they will be suppressed even for
6546+
// canonical types or if even written in the original source code.
6547+
Policy.SuppressDefaultTemplateArguments = true;
6548+
// EnforceDefaultTemplateArgs is a downstream addition that forces printing
6549+
// template arguments that match default template arguments while printing
6550+
// template-ids, even if the source code doesn't reference them.
6551+
Policy.EnforceDefaultTemplateArgs = true;
65226552
if (FTD) {
65236553
FTD->print(O, Policy);
65246554
} else {
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -triple spir64-unknown-unknown -sycl-std=2020 -fsycl-int-header=%t.h %s
2+
// RUN: FileCheck -input-file=%t.h %s
3+
4+
// This test checks integration header contents for free functions kernels with
5+
// parameter types that have default template arguments.
6+
7+
#include "mock_properties.hpp"
8+
#include "sycl.hpp"
9+
10+
namespace ns {
11+
12+
struct notatuple {
13+
int a;
14+
};
15+
16+
namespace ns1 {
17+
template <typename A = notatuple>
18+
class hasDefaultArg {
19+
20+
};
21+
}
22+
23+
template <typename T, typename = int, int a = 12, typename = notatuple, typename ...TS> struct Arg {
24+
T val;
25+
};
26+
27+
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel",
28+
2)]] void
29+
simple(Arg<char>){
30+
}
31+
32+
}
33+
34+
[[__sycl_detail__::add_ir_attributes_function("sycl-single-task-kernel",
35+
2)]] void
36+
simple1(ns::Arg<ns::ns1::hasDefaultArg<>>){
37+
}
38+
39+
40+
template <typename T>
41+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
42+
templated(ns::Arg<T, float, 3>, T end) {
43+
}
44+
45+
template void templated(ns::Arg<int, float, 3>, int);
46+
47+
using namespace ns;
48+
49+
template <typename T>
50+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
51+
templated2(Arg<T, notatuple>, T end) {
52+
}
53+
54+
template void templated2(Arg<int, notatuple>, int);
55+
56+
template <typename T, int a = 3>
57+
[[__sycl_detail__::add_ir_attributes_function("sycl-nd-range-kernel", 2)]] void
58+
templated3(Arg<T, notatuple, a, ns1::hasDefaultArg<>, int, int>, T end) {
59+
}
60+
61+
template void templated3(Arg<int, notatuple, 3, ns1::hasDefaultArg<>, int, int>, int);
62+
63+
// CHECK: Forward declarations of kernel and its argument types:
64+
// CHECK-NEXT: namespace ns {
65+
// CHECK-NEXT: struct notatuple;
66+
// CHECK-NEXT: }
67+
// CHECK-NEXT: namespace ns {
68+
// CHECK-NEXT: template <typename T, typename, int a, typename, typename ...TS> struct Arg;
69+
// CHECK-NEXT: }
70+
71+
// CHECK: void ns::simple(ns::Arg<char, int, 12, ns::notatuple>);
72+
// CHECK-NEXT: static constexpr auto __sycl_shim1() {
73+
// CHECK-NEXT: return (void (*)(struct ns::Arg<char, int, 12, struct ns::notatuple>))simple;
74+
// CHECK-NEXT: }
75+
76+
// CHECK: Forward declarations of kernel and its argument types:
77+
// CHECK: namespace ns {
78+
// CHECK: namespace ns1 {
79+
// CHECK-NEXT: template <typename A> class hasDefaultArg;
80+
// CHECK-NEXT: }
81+
82+
// CHECK: void simple1(ns::Arg<ns::ns1::hasDefaultArg<ns::notatuple>, int, 12, ns::notatuple>);
83+
// CHECK-NEXT: static constexpr auto __sycl_shim2() {
84+
// CHECK-NEXT: return (void (*)(struct ns::Arg<class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, 12, struct ns::notatuple>))simple1;
85+
// CHECK-NEXT: }
86+
87+
// CHECK: template <typename T> void templated(ns::Arg<T, float, 3, ns::notatuple>, T end);
88+
// CHECK-NEXT: static constexpr auto __sycl_shim3() {
89+
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, float, 3, struct ns::notatuple>, int))templated<int>;
90+
// CHECK-NEXT: }
91+
92+
// CHECK: template <typename T> void templated2(ns::Arg<T, ns::notatuple, 12, ns::notatuple>, T end);
93+
// CHECK-NEXT: static constexpr auto __sycl_shim4() {
94+
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 12, struct ns::notatuple>, int))templated2<int>;
95+
// CHECK-NEXT: }
96+
97+
// CHECK: template <typename T, int a> void templated3(ns::Arg<T, ns::notatuple, a, ns::ns1::hasDefaultArg<ns::notatuple>, int, int>, T end);
98+
// CHECK-NEXT: static constexpr auto __sycl_shim5() {
99+
// CHECK-NEXT: return (void (*)(struct ns::Arg<int, struct ns::notatuple, 3, class ns::ns1::hasDefaultArg<struct ns::notatuple>, int, int>, int))templated3<int, 3>;
100+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)