Skip to content

Commit c8b837a

Browse files
authored
[MLIR][Python] Add the --mlir-print-ir-tree-dir to the C and Python API (llvm#117339)
1 parent e2519b6 commit c8b837a

File tree

4 files changed

+69
-9
lines changed

4 files changed

+69
-9
lines changed

mlir/include/mlir-c/Pass.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult
7575
mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op);
7676

7777
/// Enable IR printing.
78+
/// The treePrintingPath argument is an optional path to a directory
79+
/// where the dumps will be produced. If it isn't provided then dumps
80+
/// are produced to stderr.
7881
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(
7982
MlirPassManager passManager, bool printBeforeAll, bool printAfterAll,
8083
bool printModuleScope, bool printAfterOnlyOnChange,
81-
bool printAfterOnlyOnFailure);
84+
bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath);
8285

8386
/// Enable / disable verify-each.
8487
MLIR_CAPI_EXPORTED void

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,21 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
7676
"enable_ir_printing",
7777
[](PyPassManager &passManager, bool printBeforeAll,
7878
bool printAfterAll, bool printModuleScope, bool printAfterChange,
79-
bool printAfterFailure) {
79+
bool printAfterFailure,
80+
std::optional<std::string> optionalTreePrintingPath) {
81+
std::string treePrintingPath = "";
82+
if (optionalTreePrintingPath.has_value())
83+
treePrintingPath = optionalTreePrintingPath.value();
8084
mlirPassManagerEnableIRPrinting(
8185
passManager.get(), printBeforeAll, printAfterAll,
82-
printModuleScope, printAfterChange, printAfterFailure);
86+
printModuleScope, printAfterChange, printAfterFailure,
87+
mlirStringRefCreate(treePrintingPath.data(),
88+
treePrintingPath.size()));
8389
},
8490
"print_before_all"_a = false, "print_after_all"_a = true,
8591
"print_module_scope"_a = false, "print_after_change"_a = false,
8692
"print_after_failure"_a = false,
93+
"tree_printing_dir_path"_a = py::none(),
8794
"Enable IR printing, default as mlir-print-ir-after-all.")
8895
.def(
8996
"enable_verifier",

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
4848
bool printBeforeAll, bool printAfterAll,
4949
bool printModuleScope,
5050
bool printAfterOnlyOnChange,
51-
bool printAfterOnlyOnFailure) {
51+
bool printAfterOnlyOnFailure,
52+
MlirStringRef treePrintingPath) {
5253
auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
5354
return printBeforeAll;
5455
};
5556
auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
5657
return printAfterAll;
5758
};
58-
return unwrap(passManager)
59-
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
60-
printModuleScope, printAfterOnlyOnChange,
61-
printAfterOnlyOnFailure);
59+
if (unwrap(treePrintingPath).empty())
60+
return unwrap(passManager)
61+
->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
62+
printModuleScope, printAfterOnlyOnChange,
63+
printAfterOnlyOnFailure);
64+
65+
unwrap(passManager)
66+
->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
67+
printModuleScope, printAfterOnlyOnChange,
68+
printAfterOnlyOnFailure,
69+
unwrap(treePrintingPath));
6270
}
6371

6472
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {

mlir/test/python/pass_manager.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# RUN: %PYTHON %s 2>&1 | FileCheck %s
22

3-
import gc, sys
3+
import gc, os, sys, tempfile
44
from mlir.ir import *
55
from mlir.passmanager import *
66
from mlir.dialects.func import FuncOp
@@ -340,3 +340,45 @@ def testPrintIrBeforeAndAfterAll():
340340
# CHECK: }
341341
# CHECK: }
342342
pm.run(module)
343+
344+
345+
# CHECK-LABEL: TEST: testPrintIrTree
346+
@run
347+
def testPrintIrTree():
348+
with Context() as ctx:
349+
module = ModuleOp.parse(
350+
"""
351+
module {
352+
func.func @main() {
353+
%0 = arith.constant 10
354+
return
355+
}
356+
}
357+
"""
358+
)
359+
pm = PassManager.parse("builtin.module(canonicalize)")
360+
ctx.enable_multithreading(False)
361+
pm.enable_ir_printing()
362+
# CHECK-LABEL: // Tree printing begin
363+
# CHECK: \-- builtin_module_no-symbol-name
364+
# CHECK: \-- 0_canonicalize.mlir
365+
# CHECK-LABEL: // Tree printing end
366+
pm.run(module)
367+
log("// Tree printing begin")
368+
with tempfile.TemporaryDirectory() as temp_dir:
369+
pm.enable_ir_printing(tree_printing_dir_path=temp_dir)
370+
pm.run(module)
371+
372+
def print_file_tree(directory, prefix=""):
373+
entries = sorted(os.listdir(directory))
374+
for i, entry in enumerate(entries):
375+
path = os.path.join(directory, entry)
376+
connector = "\-- " if i == len(entries) - 1 else "|-- "
377+
log(f"{prefix}{connector}{entry}")
378+
if os.path.isdir(path):
379+
print_file_tree(
380+
path, prefix + (" " if i == len(entries) - 1 else "│ ")
381+
)
382+
383+
print_file_tree(temp_dir)
384+
log("// Tree printing end")

0 commit comments

Comments
 (0)