diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td index dea26b8dda62a..d3d96ea5a65a4 100644 --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -174,4 +174,32 @@ def TestOptionallyImplementedTypeInterface }]; } +// Dummy type interface "A" that requires type interface "B" to be complete. +def TestCyclicTypeInterfaceA : TypeInterface<"TestCyclicTypeInterfaceA"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"", + "::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceB>", + /*methodName=*/"returnB", + (ins), + /*methodBody=*/"", + /*defaultImpl=*/"return mlir::failure();" + >, + ]; +} + +// Dummy type interface "B" that requires type interface "A" to be complete. +def TestCyclicTypeInterfaceB : TypeInterface<"TestCyclicTypeInterfaceB"> { + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod<"", + "::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceA>", + /*methodName=*/"returnA", + (ins), + /*methodBody=*/"", + /*defaultImpl=*/"return mlir::failure();" + >, + ]; +} + #endif // MLIR_TEST_DIALECT_TEST_INTERFACES diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td index 17bd631fe250d..aa71baddf58cd 100644 --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -31,11 +31,6 @@ def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> { // DECL-NEXT: return (*this).someOtherMethod(); // DECL-NEXT: } -// DECL: struct ExtraShardDeclsInterfaceTrait -// DECL: bool sharedMethodDeclaration() { -// DECL-NEXT: return (*static_cast(this)).someOtherMethod(); -// DECL-NEXT: } - def TestInheritanceMultiBaseInterface : OpInterface<"TestInheritanceMultiBaseInterface"> { let methods = [ InterfaceMethod< @@ -71,7 +66,7 @@ def TestInheritanceMiddleBaseInterface def TestInheritanceZDerivedInterface : OpInterface<"TestInheritanceZDerivedInterface", [TestInheritanceMiddleBaseInterface]>; -// DECL: class TestInheritanceZDerivedInterface +// DECL: struct TestInheritanceZDerivedInterfaceInterfaceTraits // DECL: struct Concept { // DECL: const TestInheritanceMultiBaseInterface::Concept *implTestInheritanceMultiBaseInterface = nullptr; // DECL-NOT: const TestInheritanceMultiBaseInterface::Concept @@ -173,10 +168,16 @@ def DeclareMethodsWithDefaultOp : Op(this)).someOtherMethod(); +// DECL-NEXT: } + +// DECL-LABEL: struct TestOpInterfaceVerifyTrait : public // DECL: verifyTrait -// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait +// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait : public // DECL: verifyRegionTrait // Method implementations come last, after all class definitions. diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 4dfa1908b3267..3cc1636ac3317 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -96,9 +96,9 @@ class InterfaceGenerator { void emitConceptDecl(const Interface &interface); void emitModelDecl(const Interface &interface); void emitModelMethodsDef(const Interface &interface); - void emitTraitDecl(const Interface &interface, StringRef interfaceName, - StringRef interfaceTraitsName); + void forwardDeclareInterface(const Interface &interface); void emitInterfaceDecl(const Interface &interface); + void emitInterfaceTraitDecl(const Interface &interface); /// The set of interface records to emit. std::vector defs; @@ -445,9 +445,16 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { os << "} // namespace " << ns << "\n"; } -void InterfaceGenerator::emitTraitDecl(const Interface &interface, - StringRef interfaceName, - StringRef interfaceTraitsName) { +void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { + llvm::SmallVector namespaces; + llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + + os << "namespace detail {\n"; + + StringRef interfaceName = interface.getName(); + auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); os << llvm::formatv(" template \n" " struct {0}Trait : public ::mlir::{2}<{0}," " detail::{1}>::Trait<{3}> {{\n", @@ -494,6 +501,10 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface, os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; + os << "}// namespace detail\n"; + + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; } static void emitInterfaceDeclMethods(const Interface &interface, @@ -517,6 +528,27 @@ static void emitInterfaceDeclMethods(const Interface &interface, os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n"; } +void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { + llvm::SmallVector namespaces; + llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + + // Emit a forward declaration of the interface class so that it becomes usable + // in the signature of its methods. + std::string comments = tblgen::emitSummaryAndDescComments( + "", interface.getDescription().value_or("")); + if (!comments.empty()) { + os << comments << "\n"; + } + + StringRef interfaceName = interface.getName(); + os << "class " << interfaceName << ";\n"; + + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; +} + void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { llvm::SmallVector namespaces; llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); @@ -533,7 +565,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { if (!comments.empty()) { os << comments << "\n"; } - os << "class " << interfaceName << ";\n"; // Emit the traits struct containing the concept and model declarations. os << "namespace detail {\n" @@ -603,10 +634,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { os << "};\n"; - os << "namespace detail {\n"; - emitTraitDecl(interface, interfaceName, interfaceTraitsName); - os << "}// namespace detail\n"; - for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; } @@ -619,10 +646,15 @@ bool InterfaceGenerator::emitInterfaceDecls() { llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { return lhs->getID() < rhs->getID(); }); + for (const Record *def : sortedDefs) + forwardDeclareInterface(Interface(def)); for (const Record *def : sortedDefs) emitInterfaceDecl(Interface(def)); + for (const Record *def : sortedDefs) + emitInterfaceTraitDecl(Interface(def)); for (const Record *def : sortedDefs) emitModelMethodsDef(Interface(def)); + return false; }