Skip to content

Commit 44354f0

Browse files
authored
[NFC][SYCL] Use visitors to print kernel name type (#2660)
This is the part of integration header generator refactoring to make it use clang visitors and simplify the code.
1 parent ac7255d commit 44354f0

File tree

1 file changed

+95
-97
lines changed

1 file changed

+95
-97
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 95 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -3562,122 +3562,118 @@ static void emitCPPTypeString(raw_ostream &OS, QualType Ty) {
35623562
emitWithoutAnonNamespaces(OS, Ty.getAsString(P));
35633563
}
35643564

3565-
static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
3566-
ArrayRef<TemplateArgument> Args,
3567-
const PrintingPolicy &P);
3568-
3569-
static void emitKernelNameType(QualType T, ASTContext &Ctx, raw_ostream &OS,
3570-
const PrintingPolicy &TypePolicy);
3571-
3572-
static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
3573-
TemplateArgument Arg, const PrintingPolicy &P) {
3574-
switch (Arg.getKind()) {
3575-
case TemplateArgument::ArgKind::Pack: {
3576-
printArguments(Ctx, ArgOS, Arg.getPackAsArray(), P);
3577-
break;
3578-
}
3579-
case TemplateArgument::ArgKind::Integral: {
3580-
QualType T = Arg.getIntegralType();
3581-
const EnumType *ET = T->getAs<EnumType>();
3582-
3583-
if (ET) {
3584-
const llvm::APSInt &Val = Arg.getAsIntegral();
3585-
ArgOS << "static_cast<"
3586-
<< ET->getDecl()->getQualifiedNameAsString(
3587-
/*WithGlobalNsPrefix*/ true)
3588-
<< ">"
3589-
<< "(" << Val << ")";
3590-
} else {
3591-
Arg.print(P, ArgOS);
3565+
class SYCLKernelNameTypePrinter
3566+
: public TypeVisitor<SYCLKernelNameTypePrinter>,
3567+
public ConstTemplateArgumentVisitor<SYCLKernelNameTypePrinter> {
3568+
using InnerTypeVisitor = TypeVisitor<SYCLKernelNameTypePrinter>;
3569+
using InnerTemplArgVisitor =
3570+
ConstTemplateArgumentVisitor<SYCLKernelNameTypePrinter>;
3571+
raw_ostream &OS;
3572+
PrintingPolicy &Policy;
3573+
3574+
void printTemplateArgs(ArrayRef<TemplateArgument> Args) {
3575+
for (size_t I = 0, E = Args.size(); I < E; ++I) {
3576+
const TemplateArgument &Arg = Args[I];
3577+
// If argument is an empty pack argument, skip printing comma and
3578+
// argument.
3579+
if (Arg.getKind() == TemplateArgument::ArgKind::Pack && !Arg.pack_size())
3580+
continue;
3581+
3582+
if (I)
3583+
OS << ", ";
3584+
3585+
Visit(Arg);
35923586
}
3593-
break;
35943587
}
3595-
case TemplateArgument::ArgKind::Type: {
3596-
LangOptions LO;
3597-
PrintingPolicy TypePolicy(LO);
3598-
TypePolicy.SuppressTypedefs = true;
3599-
TypePolicy.SuppressTagKeyword = true;
3600-
QualType T = Arg.getAsType();
36013588

3602-
emitKernelNameType(T, Ctx, ArgOS, TypePolicy);
3603-
break;
3604-
}
3605-
case TemplateArgument::ArgKind::Template: {
3606-
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
3607-
ArgOS << TD->getQualifiedNameAsString();
3608-
break;
3589+
void VisitQualifiers(Qualifiers Quals) {
3590+
Quals.print(OS, Policy, /*appendSpaceIfNotEmpty*/ true);
36093591
}
3610-
default:
3611-
Arg.print(P, ArgOS);
3612-
}
3613-
}
36143592

3615-
static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
3616-
ArrayRef<TemplateArgument> Args,
3617-
const PrintingPolicy &P) {
3618-
for (unsigned I = 0; I < Args.size(); I++) {
3619-
const TemplateArgument &Arg = Args[I];
3593+
public:
3594+
SYCLKernelNameTypePrinter(raw_ostream &OS, PrintingPolicy &Policy)
3595+
: OS(OS), Policy(Policy) {}
36203596

3621-
// If argument is an empty pack argument, skip printing comma and argument.
3622-
if (Arg.getKind() == TemplateArgument::ArgKind::Pack && !Arg.pack_size())
3623-
continue;
3597+
void Visit(QualType T) {
3598+
if (T.isNull())
3599+
return;
36243600

3625-
if (I != 0)
3626-
ArgOS << ", ";
3601+
QualType CT = T.getCanonicalType();
3602+
VisitQualifiers(CT.getQualifiers());
36273603

3628-
printArgument(Ctx, ArgOS, Arg, P);
3604+
InnerTypeVisitor::Visit(CT.getTypePtr());
36293605
}
3630-
}
36313606

3632-
static void printTemplateArguments(ASTContext &Ctx, raw_ostream &ArgOS,
3633-
ArrayRef<TemplateArgument> Args,
3634-
const PrintingPolicy &P) {
3635-
ArgOS << "<";
3636-
printArguments(Ctx, ArgOS, Args, P);
3637-
ArgOS << ">";
3638-
}
3607+
void VisitType(const Type *T) {
3608+
OS << QualType::getAsString(T, Qualifiers(), Policy);
3609+
}
36393610

3640-
static void emitRecordType(raw_ostream &OS, QualType T, const CXXRecordDecl *RD,
3641-
const PrintingPolicy &TypePolicy) {
3642-
SmallString<64> Buf;
3643-
llvm::raw_svector_ostream RecOS(Buf);
3644-
T.getCanonicalType().getQualifiers().print(RecOS, TypePolicy,
3645-
/*appendSpaceIfNotEmpty*/ true);
3646-
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3611+
void Visit(const TemplateArgument &TA) {
3612+
if (TA.isNull())
3613+
return;
3614+
InnerTemplArgVisitor::Visit(TA);
3615+
}
36473616

3648-
// Print template class name
3649-
TSD->printQualifiedName(RecOS, TypePolicy, /*WithGlobalNsPrefix*/ true);
3617+
void VisitTagType(const TagType *T) {
3618+
TagDecl *RD = T->getDecl();
3619+
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
36503620

3651-
// Print template arguments substituting enumerators
3652-
ASTContext &Ctx = RD->getASTContext();
3653-
const TemplateArgumentList &Args = TSD->getTemplateArgs();
3654-
printTemplateArguments(Ctx, RecOS, Args.asArray(), TypePolicy);
3621+
// Print template class name
3622+
TSD->printQualifiedName(OS, Policy, /*WithGlobalNsPrefix*/ true);
36553623

3656-
emitWithoutAnonNamespaces(OS, RecOS.str());
3657-
return;
3624+
ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs().asArray();
3625+
OS << "<";
3626+
printTemplateArgs(Args);
3627+
OS << ">";
3628+
3629+
return;
3630+
}
3631+
// TODO: Next part of code results in printing of "class" keyword before
3632+
// class name in case if kernel name doesn't belong to some namespace. It
3633+
// seems if we don't print it, the integration header still represents valid
3634+
// c++ code. Probably we don't need to print it at all.
3635+
if (RD->getDeclContext()->isFunctionOrMethod()) {
3636+
OS << QualType::getAsString(T, Qualifiers(), Policy);
3637+
return;
3638+
}
3639+
3640+
const NamespaceDecl *NS = dyn_cast<NamespaceDecl>(RD->getDeclContext());
3641+
RD->printQualifiedName(OS, Policy, !(NS && NS->isAnonymousNamespace()));
36583642
}
3659-
if (RD->getDeclContext()->isFunctionOrMethod()) {
3660-
emitWithoutAnonNamespaces(OS, T.getCanonicalType().getAsString(TypePolicy));
3661-
return;
3643+
3644+
void VisitTemplateArgument(const TemplateArgument &TA) {
3645+
TA.print(Policy, OS);
36623646
}
36633647

3664-
const NamespaceDecl *NS = dyn_cast<NamespaceDecl>(RD->getDeclContext());
3665-
RD->printQualifiedName(RecOS, TypePolicy,
3666-
!(NS && NS->isAnonymousNamespace()));
3667-
emitWithoutAnonNamespaces(OS, RecOS.str());
3668-
}
3648+
void VisitTypeTemplateArgument(const TemplateArgument &TA) {
3649+
Policy.SuppressTagKeyword = true;
3650+
QualType T = TA.getAsType();
3651+
Visit(T);
3652+
Policy.SuppressTagKeyword = false;
3653+
}
36693654

3670-
static void emitKernelNameType(QualType T, ASTContext &Ctx, raw_ostream &OS,
3671-
const PrintingPolicy &TypePolicy) {
3672-
if (T->isRecordType()) {
3673-
emitRecordType(OS, T, T->getAsCXXRecordDecl(), TypePolicy);
3674-
return;
3655+
void VisitIntegralTemplateArgument(const TemplateArgument &TA) {
3656+
QualType T = TA.getIntegralType();
3657+
if (const EnumType *ET = T->getAs<EnumType>()) {
3658+
const llvm::APSInt &Val = TA.getAsIntegral();
3659+
OS << "static_cast<";
3660+
ET->getDecl()->printQualifiedName(OS, Policy,
3661+
/*WithGlobalNsPrefix*/ true);
3662+
OS << ">(" << Val << ")";
3663+
} else {
3664+
TA.print(Policy, OS);
3665+
}
36753666
}
36763667

3677-
if (T->isEnumeralType())
3678-
OS << "::";
3679-
emitWithoutAnonNamespaces(OS, T.getCanonicalType().getAsString(TypePolicy));
3680-
}
3668+
void VisitTemplateTemplateArgument(const TemplateArgument &TA) {
3669+
TemplateDecl *TD = TA.getAsTemplate().getAsTemplateDecl();
3670+
TD->printQualifiedName(OS, Policy);
3671+
}
3672+
3673+
void VisitPackTemplateArgument(const TemplateArgument &TA) {
3674+
printTemplateArgs(TA.getPackAsArray());
3675+
}
3676+
};
36813677

36823678
void SYCLIntegrationHeader::emit(raw_ostream &O) {
36833679
O << "// This is auto-generated SYCL integration header.\n";
@@ -3776,8 +3772,10 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37763772
LangOptions LO;
37773773
PrintingPolicy P(LO);
37783774
P.SuppressTypedefs = true;
3775+
P.SuppressUnwrittenScope = true;
37793776
O << "template <> struct KernelInfo<";
3780-
emitKernelNameType(K.NameType, S.getASTContext(), O, P);
3777+
SYCLKernelNameTypePrinter Printer(O, P);
3778+
Printer.Visit(K.NameType);
37813779
O << "> {\n";
37823780
}
37833781
O << " DLL_LOCAL\n";

0 commit comments

Comments
 (0)