From 6f0f1dc4d3bf28d7da4f5df4741631fa14db8887 Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Wed, 9 Jul 2025 10:25:59 +0000 Subject: [PATCH 1/2] [mlir][TableGen][NFC] Emit interface traits after all interfaces Interface traits may provide default implementation of methods. When this happens, the implementation may rely on another interface that is not yet defined meaning that one gets "incomplete type" error during C++ compilation. In pseudo-code, the problem is the following: ``` InterfaceA has methodB() { return InterfaceB(); } InterfaceB defined later // What's generated is: class InterfaceA { ... } class InterfaceATrait { // error: InterfaceB is an incomplete type InterfaceB methodB() { return InterfaceB(); } } class InterfaceB { ... } // defined here ``` The two more "advanced" cases are: * Cyclic dependency (A requires B and B requires A) * Type-traited usage of an incomplete type (e.g. `FailureOr`) It seems reasonable to emit interface traits *after* all of the interfaces have been defined to avoid the problem altogether. As a drive by, make forward declarations of the interfaces early so that user code does not need to forward declare. --- mlir/test/lib/Dialect/Test/TestInterfaces.td | 28 +++++++++++++ mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 44 +++++++++++++++----- 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td index dea26b8dda62a..d3d96ea5a65a4 100644 --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -174,4 +174,32 @@ def TestOptionallyImplementedTypeInterface }]; } +// Dummy type interface "A" that requires type interface "B" to be complete. +def TestCyclicTypeInterfaceA : TypeInterface<"TestCyclicTypeInterfaceA"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"", + "::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceB>", + /*methodName=*/"returnB", + (ins), + /*methodBody=*/"", + /*defaultImpl=*/"return mlir::failure();" + >, + ]; +} + +// Dummy type interface "B" that requires type interface "A" to be complete. +def TestCyclicTypeInterfaceB : TypeInterface<"TestCyclicTypeInterfaceB"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"", + "::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceA>", + /*methodName=*/"returnA", + (ins), + /*methodBody=*/"", + /*defaultImpl=*/"return mlir::failure();" + >, + ]; +} + #endif // MLIR_TEST_DIALECT_TEST_INTERFACES diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 4dfa1908b3267..ba1396e7d12be 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -96,9 +96,9 @@ class InterfaceGenerator { void emitConceptDecl(const Interface &interface); void emitModelDecl(const Interface &interface); void emitModelMethodsDef(const Interface &interface); - void emitTraitDecl(const Interface &interface, StringRef interfaceName, - StringRef interfaceTraitsName); + void forwardDeclareInterface(const Interface &interface); void emitInterfaceDecl(const Interface &interface); + void emitInterfaceTraitDecl(const Interface &interface); /// The set of interface records to emit. std::vector defs; @@ -445,9 +445,16 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { os << "} // namespace " << ns << "\n"; } -void InterfaceGenerator::emitTraitDecl(const Interface &interface, - StringRef interfaceName, - StringRef interfaceTraitsName) { +void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { + llvm::SmallVector namespaces; + llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + + os << "namespace detail {\n"; + + StringRef interfaceName = interface.getName(); + auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); os << llvm::formatv(" template \n" " struct {0}Trait : public ::mlir::{2}<{0}," " detail::{1}>::Trait<{3}> {{\n", @@ -494,6 +501,10 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface, os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; + os << "}// namespace detail\n"; + + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; } static void emitInterfaceDeclMethods(const Interface &interface, @@ -517,6 +528,19 @@ static void emitInterfaceDeclMethods(const Interface &interface, os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n"; } +void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { + llvm::SmallVector namespaces; + llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + + StringRef interfaceName = interface.getName(); + os << "class " << interfaceName << ";\n"; + + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; +} + void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { llvm::SmallVector namespaces; llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); @@ -533,7 +557,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { if (!comments.empty()) { os << comments << "\n"; } - os << "class " << interfaceName << ";\n"; // Emit the traits struct containing the concept and model declarations. os << "namespace detail {\n" @@ -603,10 +626,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { os << "};\n"; - os << "namespace detail {\n"; - emitTraitDecl(interface, interfaceName, interfaceTraitsName); - os << "}// namespace detail\n"; - for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; } @@ -619,10 +638,15 @@ bool InterfaceGenerator::emitInterfaceDecls() { llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { return lhs->getID() < rhs->getID(); }); + for (const Record *def : sortedDefs) + forwardDeclareInterface(Interface(def)); for (const Record *def : sortedDefs) emitInterfaceDecl(Interface(def)); for (const Record *def : sortedDefs) emitModelMethodsDef(Interface(def)); + for (const Record *def : sortedDefs) + emitInterfaceTraitDecl(Interface(def)); + return false; } From db672eda80fa4b34740f99b04440cf58a0d858b6 Mon Sep 17 00:00:00 2001 From: "Golubev, Andrey" Date: Wed, 9 Jul 2025 13:16:03 +0000 Subject: [PATCH 2/2] Adapt tests to new changes --- mlir/test/mlir-tblgen/op-interface.td | 17 +++++++++-------- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 12 ++++++++++-- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td index 17bd631fe250d..aa71baddf58cd 100644 --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -31,11 +31,6 @@ def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> { // DECL-NEXT: return (*this).someOtherMethod(); // DECL-NEXT: } -// DECL: struct ExtraShardDeclsInterfaceTrait -// DECL: bool sharedMethodDeclaration() { -// DECL-NEXT: return (*static_cast(this)).someOtherMethod(); -// DECL-NEXT: } - def TestInheritanceMultiBaseInterface : OpInterface<"TestInheritanceMultiBaseInterface"> { let methods = [ InterfaceMethod< @@ -71,7 +66,7 @@ def TestInheritanceMiddleBaseInterface def TestInheritanceZDerivedInterface : OpInterface<"TestInheritanceZDerivedInterface", [TestInheritanceMiddleBaseInterface]>; -// DECL: class TestInheritanceZDerivedInterface +// DECL: struct TestInheritanceZDerivedInterfaceInterfaceTraits // DECL: struct Concept { // DECL: const TestInheritanceMultiBaseInterface::Concept *implTestInheritanceMultiBaseInterface = nullptr; // DECL-NOT: const TestInheritanceMultiBaseInterface::Concept @@ -173,10 +168,16 @@ def DeclareMethodsWithDefaultOp : Op(this)).someOtherMethod(); +// DECL-NEXT: } + +// DECL-LABEL: struct TestOpInterfaceVerifyTrait : public // DECL: verifyTrait -// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait +// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait : public // DECL: verifyRegionTrait // Method implementations come last, after all class definitions. diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index ba1396e7d12be..3cc1636ac3317 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -534,6 +534,14 @@ void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { for (StringRef ns : namespaces) os << "namespace " << ns << " {\n"; + // Emit a forward declaration of the interface class so that it becomes usable + // in the signature of its methods. + std::string comments = tblgen::emitSummaryAndDescComments( + "", interface.getDescription().value_or("")); + if (!comments.empty()) { + os << comments << "\n"; + } + StringRef interfaceName = interface.getName(); os << "class " << interfaceName << ";\n"; @@ -642,10 +650,10 @@ bool InterfaceGenerator::emitInterfaceDecls() { forwardDeclareInterface(Interface(def)); for (const Record *def : sortedDefs) emitInterfaceDecl(Interface(def)); - for (const Record *def : sortedDefs) - emitModelMethodsDef(Interface(def)); for (const Record *def : sortedDefs) emitInterfaceTraitDecl(Interface(def)); + for (const Record *def : sortedDefs) + emitModelMethodsDef(Interface(def)); return false; }