Skip to content

Commit e0f3b63

Browse files
committed
Add an MLIR "pattern catalog" generator.
This PR adds a new cmake build option MLIR_ENABLE_CATALOG_GENERATOR. When enabled, it attaches a listener to all RewritePatterns that emits the name of the operation being modified to a special file. When the MLIR test suite is run, these files can be combined into an index linking operations to the patterns that insert, modify, or replace them. This index is intended to be used to create a website that allows one to look up patterns from an operation name.
1 parent 4ad230b commit e0f3b63

File tree

5 files changed

+142
-2
lines changed

5 files changed

+142
-2
lines changed

mlir/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,24 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
202202
mlir_configure_python_dev_packages()
203203
endif()
204204

205+
#-------------------------------------------------------------------------------
206+
# MLIR Pattern Catalog Generator Configuration
207+
# Requires:
208+
# RTTI to be enabled (set with -DLLVM_ENABLE_RTTI=ON)
209+
# When enabled, causes all rewriter patterns to dump their type names and the
210+
# names of affected operations, which can be used to build a search index
211+
# mapping operations to patterns.
212+
#-------------------------------------------------------------------------------
213+
214+
set(MLIR_ENABLE_CATALOG_GENERATOR 0 CACHE BOOL
215+
"Enables construction of a catalog of rewrite patterns.")
216+
217+
if (MLIR_ENABLE_CATALOG_GENERATOR)
218+
message(STATUS "Enabling MLIR pattern catalog generator")
219+
add_definitions(-DMLIR_ENABLE_CATALOG_GENERATOR)
220+
add_definitions(-DLLVM_ENABLE_RTTI)
221+
endif()
222+
205223
set(CMAKE_INCLUDE_CURRENT_DIR ON)
206224

207225
include_directories(BEFORE
@@ -322,3 +340,4 @@ endif()
322340
if(MLIR_STANDALONE_BUILD)
323341
llvm_distribution_add_targets()
324342
endif()
343+

mlir/cmake/modules/MLIRConfig.cmake.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ set(MLIR_IRDL_TO_CPP_EXE "@MLIR_CONFIG_IRDL_TO_CPP_EXE@")
1616
set(MLIR_INSTALL_AGGREGATE_OBJECTS "@MLIR_INSTALL_AGGREGATE_OBJECTS@")
1717
set(MLIR_ENABLE_BINDINGS_PYTHON "@MLIR_ENABLE_BINDINGS_PYTHON@")
1818
set(MLIR_ENABLE_EXECUTION_ENGINE "@MLIR_ENABLE_EXECUTION_ENGINE@")
19+
set(MLIR_ENABLE_CATALOG_GENERATOR "@MLIR_ENABLE_CATALOG_GENERATOR@")
1920

2021
set_property(GLOBAL PROPERTY MLIR_ALL_LIBS "@MLIR_ALL_LIBS@")
2122
set_property(GLOBAL PROPERTY MLIR_DIALECT_LIBS "@MLIR_DIALECT_LIBS@")

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,80 @@ class RewriterBase : public OpBuilder {
475475
RewriterBase::Listener *rewriteListener;
476476
};
477477

478+
struct CatalogingListener : public RewriterBase::ForwardingListener {
479+
CatalogingListener(OpBuilder::Listener *listener,
480+
const std::string &patternName, raw_ostream &os,
481+
std::mutex &writeMutex)
482+
: RewriterBase::ForwardingListener(listener), patternName(patternName),
483+
os(os), writeMutex(writeMutex) {}
484+
485+
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
486+
{
487+
std::lock_guard<std::mutex> lock(writeMutex);
488+
os << patternName << " | notifyOperationInserted"
489+
<< " | " << op->getName() << "\n";
490+
os.flush();
491+
}
492+
ForwardingListener::notifyOperationInserted(op, previous);
493+
}
494+
495+
void notifyOperationModified(Operation *op) override {
496+
{
497+
std::lock_guard<std::mutex> lock(writeMutex);
498+
os << patternName << " | notifyOperationModified"
499+
<< " | " << op->getName() << "\n";
500+
os.flush();
501+
}
502+
ForwardingListener::notifyOperationModified(op);
503+
}
504+
505+
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
506+
{
507+
std::lock_guard<std::mutex> lock(writeMutex);
508+
os << patternName << " | notifyOperationReplaced (with op)"
509+
<< " | " << op->getName() << " | " << newOp->getName() << "\n";
510+
os.flush();
511+
}
512+
ForwardingListener::notifyOperationReplaced(op, newOp);
513+
}
514+
515+
void notifyOperationReplaced(Operation *op,
516+
ValueRange replacement) override {
517+
{
518+
std::lock_guard<std::mutex> lock(writeMutex);
519+
os << patternName << " | notifyOperationReplaced (with values)"
520+
<< " | " << op->getName() << "\n";
521+
os.flush();
522+
}
523+
ForwardingListener::notifyOperationReplaced(op, replacement);
524+
}
525+
526+
void notifyOperationErased(Operation *op) override {
527+
{
528+
std::lock_guard<std::mutex> lock(writeMutex);
529+
os << patternName << " | notifyOperationErased"
530+
<< " | " << op->getName() << "\n";
531+
os.flush();
532+
}
533+
ForwardingListener::notifyOperationErased(op);
534+
}
535+
536+
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
537+
{
538+
std::lock_guard<std::mutex> lock(writeMutex);
539+
os << patternName << " | notifyPatternBegin"
540+
<< " | " << op->getName() << "\n";
541+
os.flush();
542+
}
543+
ForwardingListener::notifyPatternBegin(pattern, op);
544+
}
545+
546+
private:
547+
const std::string &patternName;
548+
raw_ostream &os;
549+
std::mutex &writeMutex;
550+
};
551+
478552
/// Move the blocks that belong to "region" before the given position in
479553
/// another region "parent". The two regions must be different. The caller
480554
/// is responsible for creating or updating the operation transferring flow

mlir/lib/Rewrite/PatternApplicator.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,19 @@
1515
#include "ByteCode.h"
1616
#include "llvm/Support/Debug.h"
1717

18+
#ifdef MLIR_ENABLE_CATALOG_GENERATOR
19+
#include "llvm/Support/FileSystem.h"
20+
#include "llvm/Support/raw_ostream.h"
21+
#include <cxxabi.h>
22+
#include <mutex>
23+
#endif
24+
1825
#define DEBUG_TYPE "pattern-application"
1926

27+
#ifdef MLIR_ENABLE_CATALOG_GENERATOR
28+
static std::mutex catalogWriteMutex;
29+
#endif
30+
2031
using namespace mlir;
2132
using namespace mlir::detail;
2233

@@ -152,6 +163,16 @@ LogicalResult PatternApplicator::matchAndRewrite(
152163
unsigned anyIt = 0, anyE = anyOpPatterns.size();
153164
unsigned pdlIt = 0, pdlE = pdlMatches.size();
154165
LogicalResult result = failure();
166+
#ifdef MLIR_ENABLE_CATALOG_GENERATOR
167+
std::error_code ec;
168+
llvm::raw_fd_ostream catalogOs("pattern_catalog.txt", ec,
169+
llvm::sys::fs::OF_Append);
170+
if (ec) {
171+
op->emitError("Failed to open pattern catalog file: " + ec.message());
172+
return failure();
173+
}
174+
#endif
175+
155176
do {
156177
// Find the next pattern with the highest benefit.
157178
const Pattern *bestPattern = nullptr;
@@ -206,14 +227,38 @@ LogicalResult PatternApplicator::matchAndRewrite(
206227
} else {
207228
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
208229
<< bestPattern->getDebugName() << "\"\n");
209-
210230
const auto *pattern =
211231
static_cast<const RewritePattern *>(bestPattern);
212-
result = pattern->matchAndRewrite(op, rewriter);
213232

233+
#ifdef MLIR_ENABLE_CATALOG_GENERATOR
234+
OpBuilder::Listener *oldListener = rewriter.getListener();
235+
int status;
236+
const char *mangledPatternName = typeid(*pattern).name();
237+
char *demangled = abi::__cxa_demangle(mangledPatternName, nullptr,
238+
nullptr, &status);
239+
std::string demangledPatternName;
240+
if (status == 0 && demangled) {
241+
demangledPatternName = demangled;
242+
free(demangled);
243+
} else {
244+
// Fallback in case demangling fails.
245+
demangledPatternName = mangledPatternName;
246+
}
247+
248+
RewriterBase::CatalogingListener *catalogingListener =
249+
new RewriterBase::CatalogingListener(
250+
oldListener, demangledPatternName, catalogOs,
251+
catalogWriteMutex);
252+
rewriter.setListener(catalogingListener);
253+
#endif
254+
result = pattern->matchAndRewrite(op, rewriter);
214255
LLVM_DEBUG(llvm::dbgs()
215256
<< "\"" << bestPattern->getDebugName() << "\" result "
216257
<< succeeded(result) << "\n");
258+
#ifdef MLIR_ENABLE_CATALOG_GENERATOR
259+
rewriter.setListener(oldListener);
260+
delete catalogingListener;
261+
#endif
217262
}
218263

219264
// Process the result of the pattern application.

utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ expand_template(
4444
"@MLIR_ENABLE_SPIRV_CPU_RUNNER@": "0",
4545
"@MLIR_ENABLE_VULKAN_RUNNER@": "0",
4646
"@MLIR_ENABLE_BINDINGS_PYTHON@": "0",
47+
"@MLIR_ENABLE_CATALOG_GENERATOR@": "0",
4748
"@MLIR_RUN_AMX_TESTS@": "0",
4849
"@MLIR_RUN_ARM_SVE_TESTS@": "0",
4950
"@MLIR_RUN_ARM_SME_TESTS@": "0",

0 commit comments

Comments
 (0)