-
Notifications
You must be signed in to change notification settings - Fork 47
Improve keep_intermediate to enable keeping IR after each pass #1791
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
ca8a36a
9344ac1
3869101
77bd5e9
cb9192b
5efe481
03c32d7
2b5a5af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
from typing import List, Optional | ||
|
||
from catalyst.logging import debug_logger, debug_logger_init | ||
from catalyst.pipelines import CompileOptions | ||
from catalyst.pipelines import CompileOptions, KeepIntermediateLevel | ||
from catalyst.utils.exceptions import CompileError | ||
from catalyst.utils.filesystem import Directory | ||
from catalyst.utils.runtime_environment import get_cli_path, get_lib_path | ||
|
@@ -342,7 +342,7 @@ def _options_to_cli_flags(options): | |
if not options.lower_to_llvm: | ||
extra_args += [("--tool", "opt")] | ||
|
||
if options.keep_intermediate: | ||
if options.keep_intermediate >= KeepIntermediateLevel.PIPELINE: | ||
extra_args += ["--keep-intermediate"] | ||
|
||
if options.verbose: | ||
|
@@ -351,6 +351,9 @@ def _options_to_cli_flags(options): | |
if options.async_qnodes: # pragma: nocover | ||
extra_args += ["--async-qnodes"] | ||
|
||
if options.keep_intermediate >= KeepIntermediateLevel.PASS: | ||
extra_args += ["--save-ir-after-each=pass"] | ||
|
||
Comment on lines
+354
to
+356
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it makes more sense to put the two |
||
return extra_args | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -118,10 +118,15 @@ def qjit( | |
async_qnodes (bool): Experimental support for automatically executing | ||
QNodes asynchronously, if supported by the device runtime. | ||
target (str): the compilation target | ||
keep_intermediate (bool): Whether or not to store the intermediate files throughout the | ||
compilation. If ``True``, intermediate representations are available via the | ||
:attr:`~.QJIT.mlir`, :attr:`~.QJIT.mlir_opt`, :attr:`~.QJIT.jaxpr`, | ||
and :attr:`~.QJIT.qir`, representing different stages in the optimization process. | ||
keep_intermediate (Union[str, int, bool]): Level controlling intermediate file generation. | ||
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, is this the string or the python |
||
- ``True`` or ``1`` or ``"pipeline"``: Intermediate files are saved after each pipeline. | ||
- ``2`` or ``"pass"``: Intermediate files are saved after each pass. | ||
If enabled, intermediate representations are available via the following attributes: | ||
- :attr:`~.QJIT.jaxpr`: JAX program representation | ||
- :attr:`~.QJIT.mlir`: MLIR representation after canonicalization | ||
- :attr:`~.QJIT.mlir_opt`: MLIR representation after optimization | ||
- :attr:`~.QJIT.qir`: QIR in LLVM IR form | ||
verbose (bool): If ``True``, the tools and flags used by Catalyst behind the scenes are | ||
printed out. | ||
logfile (Optional[TextIOWrapper]): File object to write verbose messages to (default - | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,8 +30,9 @@ | |
import pytest | ||
|
||
from catalyst import qjit | ||
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver | ||
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver, _options_to_cli_flags | ||
from catalyst.debug import instrumentation | ||
from catalyst.pipelines import KeepIntermediateLevel | ||
from catalyst.utils.exceptions import CompileError | ||
from catalyst.utils.filesystem import Directory | ||
|
||
|
@@ -92,6 +93,74 @@ def circuit(): | |
capture = capture_result.out + capture_result.err | ||
assert "[DIAGNOSTICS]" in capture | ||
|
||
@pytest.mark.parametrize( | ||
"input_value, expected_level", | ||
[ | ||
(None, KeepIntermediateLevel.NONE), | ||
(False, KeepIntermediateLevel.NONE), | ||
(True, KeepIntermediateLevel.PIPELINE), | ||
(0, KeepIntermediateLevel.NONE), | ||
(1, KeepIntermediateLevel.PIPELINE), | ||
(2, KeepIntermediateLevel.PASS), | ||
("none", KeepIntermediateLevel.NONE), | ||
("NONE", KeepIntermediateLevel.NONE), | ||
("pipeline", KeepIntermediateLevel.PIPELINE), | ||
("PIPELINE", KeepIntermediateLevel.PIPELINE), | ||
("pass", KeepIntermediateLevel.PASS), | ||
("PASS", KeepIntermediateLevel.PASS), | ||
(KeepIntermediateLevel.NONE, KeepIntermediateLevel.NONE), | ||
(KeepIntermediateLevel.PIPELINE, KeepIntermediateLevel.PIPELINE), | ||
(KeepIntermediateLevel.PASS, KeepIntermediateLevel.PASS), | ||
], | ||
) | ||
def test_keep_intermediate_levels_conversion(self, input_value, expected_level): | ||
"""Test that various inputs for keep_intermediate are correctly converted to Enum.""" | ||
options = CompileOptions(keep_intermediate=input_value) | ||
assert options.keep_intermediate == expected_level | ||
|
||
@pytest.mark.parametrize( | ||
"invalid_input, error_type, error_match", | ||
[ | ||
(3, ValueError, "Invalid int for keep_intermediate: 3. Valid integers are 0, 1, 2."), | ||
(-1, ValueError, "Invalid int for keep_intermediate: -1. Valid integers are 0, 1, 2."), | ||
( | ||
"invalid_string", | ||
ValueError, | ||
"Invalid string for keep_intermediate: invalid_string. Valid strings are 'none'," | ||
" 'pipeline', 'pass'.", | ||
), | ||
(3.0, TypeError, "Invalid type for keep_intermediate: <class 'float'>."), | ||
([], TypeError, "Invalid type for keep_intermediate: <class 'list'>."), | ||
], | ||
) | ||
def test_keep_intermediate_invalid_inputs(self, invalid_input, error_type, error_match): | ||
"""Test that invalid inputs for keep_intermediate raise appropriate errors.""" | ||
with pytest.raises(error_type, match=error_match): | ||
CompileOptions(keep_intermediate=invalid_input) | ||
|
||
def test_options_to_cli_flags_keep_intermediate_none(self): | ||
"""Test _options_to_cli_flags with KeepIntermediateLevel.NONE.""" | ||
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.NONE) | ||
flags = _options_to_cli_flags(options) | ||
assert "--keep-intermediate" not in flags | ||
assert "--save-ir-after-each=pass" not in flags | ||
|
||
def test_options_to_cli_flags_keep_intermediate_basic(self): | ||
"""Test _options_to_cli_flags with KeepIntermediateLevel.PIPELINE.""" | ||
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.PIPELINE) | ||
flags = _options_to_cli_flags(options) | ||
assert "--keep-intermediate" in flags | ||
assert "--save-ir-after-each=pass" not in flags | ||
assert flags.count("--save-ir-after-each=pass") <= 1 | ||
Comment on lines
+132
to
+133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why check this twice? 😅 |
||
|
||
def test_options_to_cli_flags_keep_intermediate_debug(self): | ||
"""Test _options_to_cli_flags with KeepIntermediateLevel.PASS.""" | ||
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.PASS) | ||
flags = _options_to_cli_flags(options) | ||
assert "--keep-intermediate" in flags | ||
assert "--save-ir-after-each=pass" in flags | ||
assert flags.count("--save-ir-after-each=pass") == 1 | ||
Comment on lines
+140
to
+141
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also here |
||
|
||
|
||
class TestCompilerWarnings: | ||
"""Test compiler's warning messages.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, is it the string
"none"
or the pythonNone
?