Skip to content

[mlir][TableGen] Emit interface traits after all interfaces #147699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions mlir/test/lib/Dialect/Test/TestInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,32 @@ def TestOptionallyImplementedTypeInterface
}];
}

// Dummy type interface "A" that requires type interface "B" to be complete.
def TestCyclicTypeInterfaceA : TypeInterface<"TestCyclicTypeInterfaceA"> {
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<"",
"::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceB>",
/*methodName=*/"returnB",
(ins),
/*methodBody=*/"",
/*defaultImpl=*/"return mlir::failure();"
>,
];
}

// Dummy type interface "B" that requires type interface "A" to be complete.
def TestCyclicTypeInterfaceB : TypeInterface<"TestCyclicTypeInterfaceB"> {
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<"",
"::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceA>",
/*methodName=*/"returnA",
(ins),
/*methodBody=*/"",
/*defaultImpl=*/"return mlir::failure();"
>,
];
}

#endif // MLIR_TEST_DIALECT_TEST_INTERFACES
17 changes: 9 additions & 8 deletions mlir/test/mlir-tblgen/op-interface.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
// DECL-NEXT: return (*this).someOtherMethod();
// DECL-NEXT: }

// DECL: struct ExtraShardDeclsInterfaceTrait
// DECL: bool sharedMethodDeclaration() {
// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
// DECL-NEXT: }

def TestInheritanceMultiBaseInterface : OpInterface<"TestInheritanceMultiBaseInterface"> {
let methods = [
InterfaceMethod<
Expand Down Expand Up @@ -71,7 +66,7 @@ def TestInheritanceMiddleBaseInterface
def TestInheritanceZDerivedInterface
: OpInterface<"TestInheritanceZDerivedInterface", [TestInheritanceMiddleBaseInterface]>;

// DECL: class TestInheritanceZDerivedInterface
// DECL: struct TestInheritanceZDerivedInterfaceInterfaceTraits
// DECL: struct Concept {
// DECL: const TestInheritanceMultiBaseInterface::Concept *implTestInheritanceMultiBaseInterface = nullptr;
// DECL-NOT: const TestInheritanceMultiBaseInterface::Concept
Expand Down Expand Up @@ -173,10 +168,16 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
// DECL: /// some function comment
// DECL: int foo(int input);

// DECL-LABEL: struct TestOpInterfaceVerifyTrait
// Trait declarations / definitions come after interface definitions.
// DECL: struct ExtraShardDeclsInterfaceTrait : public
// DECL: bool sharedMethodDeclaration() {
// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
// DECL-NEXT: }

// DECL-LABEL: struct TestOpInterfaceVerifyTrait : public
// DECL: verifyTrait

// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait
// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait : public
// DECL: verifyRegionTrait

// Method implementations come last, after all class definitions.
Expand Down
52 changes: 42 additions & 10 deletions mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class InterfaceGenerator {
void emitConceptDecl(const Interface &interface);
void emitModelDecl(const Interface &interface);
void emitModelMethodsDef(const Interface &interface);
void emitTraitDecl(const Interface &interface, StringRef interfaceName,
StringRef interfaceTraitsName);
void forwardDeclareInterface(const Interface &interface);
void emitInterfaceDecl(const Interface &interface);
void emitInterfaceTraitDecl(const Interface &interface);

/// The set of interface records to emit.
std::vector<const Record *> defs;
Expand Down Expand Up @@ -445,9 +445,16 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
os << "} // namespace " << ns << "\n";
}

void InterfaceGenerator::emitTraitDecl(const Interface &interface,
StringRef interfaceName,
StringRef interfaceTraitsName) {
void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
for (StringRef ns : namespaces)
os << "namespace " << ns << " {\n";

os << "namespace detail {\n";

StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
os << llvm::formatv(" template <typename {3}>\n"
" struct {0}Trait : public ::mlir::{2}<{0},"
" detail::{1}>::Trait<{3}> {{\n",
Expand Down Expand Up @@ -494,6 +501,10 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";

os << " };\n";
os << "}// namespace detail\n";

for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}

static void emitInterfaceDeclMethods(const Interface &interface,
Expand All @@ -517,6 +528,27 @@ static void emitInterfaceDeclMethods(const Interface &interface,
os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
}

void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
for (StringRef ns : namespaces)
os << "namespace " << ns << " {\n";

// Emit a forward declaration of the interface class so that it becomes usable
// in the signature of its methods.
std::string comments = tblgen::emitSummaryAndDescComments(
"", interface.getDescription().value_or(""));
if (!comments.empty()) {
os << comments << "\n";
}

StringRef interfaceName = interface.getName();
os << "class " << interfaceName << ";\n";

for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}

void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
Expand All @@ -533,7 +565,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
if (!comments.empty()) {
os << comments << "\n";
}
os << "class " << interfaceName << ";\n";

// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
Expand Down Expand Up @@ -603,10 +634,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {

os << "};\n";

os << "namespace detail {\n";
emitTraitDecl(interface, interfaceName, interfaceTraitsName);
os << "}// namespace detail\n";

for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}
Expand All @@ -619,10 +646,15 @@ bool InterfaceGenerator::emitInterfaceDecls() {
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID();
});
for (const Record *def : sortedDefs)
forwardDeclareInterface(Interface(def));
for (const Record *def : sortedDefs)
emitInterfaceDecl(Interface(def));
for (const Record *def : sortedDefs)
emitInterfaceTraitDecl(Interface(def));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi: Interface ctor can be rather expensive (it has a vector of base interfaces that it populates + vector of methods).
As we re-create the objects 4 times here, does it make sense to refactor? (I failed to do it immediately so it'd probably go into a separate patch).

FWIW tablegen is part of the build system so if it is slow, the whole build is slow.

for (const Record *def : sortedDefs)
emitModelMethodsDef(Interface(def));

return false;
}

Expand Down