-
Notifications
You must be signed in to change notification settings - Fork 47
Improve keep_intermediate to enable keeping IR after each pass #1791
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ca8a36a
9344ac1
3869101
77bd5e9
cb9192b
5efe481
03c32d7
2b5a5af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
from typing import List, Optional | ||
|
||
from catalyst.logging import debug_logger, debug_logger_init | ||
from catalyst.pipelines import CompileOptions | ||
from catalyst.pipelines import CompileOptions, KeepIntermediateLevel | ||
from catalyst.utils.exceptions import CompileError | ||
from catalyst.utils.filesystem import Directory | ||
from catalyst.utils.runtime_environment import get_cli_path, get_lib_path | ||
|
@@ -342,7 +342,7 @@ def _options_to_cli_flags(options): | |
if not options.lower_to_llvm: | ||
extra_args += [("--tool", "opt")] | ||
|
||
if options.keep_intermediate: | ||
if options.keep_intermediate >= KeepIntermediateLevel.PIPELINE: | ||
extra_args += ["--keep-intermediate"] | ||
|
||
if options.verbose: | ||
|
@@ -351,6 +351,9 @@ def _options_to_cli_flags(options): | |
if options.async_qnodes: # pragma: nocover | ||
extra_args += ["--async-qnodes"] | ||
|
||
if options.keep_intermediate >= KeepIntermediateLevel.PASS: | ||
extra_args += ["--save-ir-after-each=pass"] | ||
|
||
Comment on lines
+354
to
+356
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it makes more sense to put the two |
||
return extra_args | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -118,10 +118,15 @@ def qjit( | |
async_qnodes (bool): Experimental support for automatically executing | ||
QNodes asynchronously, if supported by the device runtime. | ||
target (str): the compilation target | ||
keep_intermediate (bool): Whether or not to store the intermediate files throughout the | ||
compilation. If ``True``, intermediate representations are available via the | ||
:attr:`~.QJIT.mlir`, :attr:`~.QJIT.mlir_opt`, :attr:`~.QJIT.jaxpr`, | ||
and :attr:`~.QJIT.qir`, representing different stages in the optimization process. | ||
keep_intermediate (Union[str, int, bool]): Level controlling intermediate file generation. | ||
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, is this the string or the python |
||
- ``True`` or ``1`` or ``"pipeline"``: Intermediate files are saved after each pipeline. | ||
- ``2`` or ``"pass"``: Intermediate files are saved after each pass. | ||
If enabled, intermediate representations are available via the following attributes: | ||
- :attr:`~.QJIT.jaxpr`: JAX program representation | ||
- :attr:`~.QJIT.mlir`: MLIR representation after canonicalization | ||
- :attr:`~.QJIT.mlir_opt`: MLIR representation after optimization | ||
- :attr:`~.QJIT.qir`: QIR in LLVM IR form | ||
verbose (bool): If ``True``, the tools and flags used by Catalyst behind the scenes are | ||
printed out. | ||
logfile (Optional[TextIOWrapper]): File object to write verbose messages to (default - | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ | |
|
||
""" | ||
|
||
import enum | ||
import sys | ||
from copy import deepcopy | ||
from dataclasses import dataclass | ||
|
@@ -37,6 +38,34 @@ | |
from catalyst.utils.exceptions import CompileError | ||
|
||
|
||
class KeepIntermediateLevel(enum.IntEnum): | ||
"""Enum to control the level of intermediate file keeping.""" | ||
|
||
NONE = 0 # No intermediate files are kept. | ||
PIPELINE = 1 # Intermediate files are saved after each pipeline. | ||
PASS = 2 # Intermediate files are saved after each pass. | ||
|
||
|
||
def _parse_keep_intermediate( | ||
level: Union[str, int, bool], | ||
) -> KeepIntermediateLevel: | ||
"""Parse the keep_intermediate value into a KeepIntermediateLevel enum.""" | ||
match level: | ||
case 0 | 1 | 2: | ||
return KeepIntermediateLevel(level) | ||
case "none": | ||
return KeepIntermediateLevel.NONE | ||
case "pipeline": | ||
return KeepIntermediateLevel.PIPELINE | ||
case "pass": | ||
return KeepIntermediateLevel.PASS | ||
case _: | ||
raise ValueError( | ||
f"Invalid value for keep_intermediate: {level}. " | ||
"Valid values are True, False, 0, 1, 2, 'none', 'pipeline', 'pass'." | ||
) | ||
Comment on lines
+62
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, does the python There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also are booleans parsed here? 🤔 |
||
|
||
|
||
# pylint: disable=too-many-instance-attributes | ||
@dataclass | ||
class CompileOptions: | ||
|
@@ -47,8 +76,11 @@ class CompileOptions: | |
Default is ``False`` | ||
logfile (Optional[TextIOWrapper]): the logfile to write output to. | ||
Default is ``sys.stderr`` | ||
keep_intermediate (Optional[bool]): flag indicating whether to keep intermediate results. | ||
Default is ``False`` | ||
keep_intermediate (Optional[Union[str, int, bool]]): Level controlling intermediate file | ||
generation. | ||
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept. | ||
- ``True`` or ``1`` or ``"pipeline"``: Intermediate files are saved after each pipeline. | ||
- ``2`` or ``"pass"``: Intermediate files are saved after each pass. | ||
pipelines (Optional[List[Tuple[str,List[str]]]]): A list of tuples. The first entry of the | ||
tuple corresponds to the name of a pipeline. The second entry of the tuple corresponds | ||
to a list of MLIR passes. | ||
|
@@ -79,7 +111,7 @@ class CompileOptions: | |
verbose: Optional[bool] = False | ||
logfile: Optional[TextIOWrapper] = sys.stderr | ||
target: Optional[str] = "binary" | ||
keep_intermediate: Optional[bool] = False | ||
keep_intermediate: Optional[Union[str, int, bool, KeepIntermediateLevel]] = False | ||
pipelines: Optional[List[Any]] = None | ||
autograph: Optional[bool] = False | ||
autograph_include: Optional[Iterable[str]] = () | ||
|
@@ -96,6 +128,9 @@ class CompileOptions: | |
dialect_plugins: Optional[Set[Path]] = None | ||
|
||
def __post_init__(self): | ||
# Convert keep_intermediate to Enum | ||
self.keep_intermediate = _parse_keep_intermediate(self.keep_intermediate) | ||
|
||
# Check that async runs must not be seeded | ||
if self.async_qnodes and self.seed is not None: | ||
raise CompileError( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,8 +30,9 @@ | |
import pytest | ||
|
||
from catalyst import qjit | ||
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver | ||
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver, _options_to_cli_flags | ||
from catalyst.debug import instrumentation | ||
from catalyst.pipelines import KeepIntermediateLevel | ||
from catalyst.utils.exceptions import CompileError | ||
from catalyst.utils.filesystem import Directory | ||
|
||
|
@@ -92,6 +93,53 @@ def circuit(): | |
capture = capture_result.out + capture_result.err | ||
assert "[DIAGNOSTICS]" in capture | ||
|
||
@pytest.mark.parametrize( | ||
"input_value, expected_level", | ||
[ | ||
(False, KeepIntermediateLevel.NONE), | ||
(True, KeepIntermediateLevel.PIPELINE), | ||
(0, KeepIntermediateLevel.NONE), | ||
(1, KeepIntermediateLevel.PIPELINE), | ||
(2, KeepIntermediateLevel.PASS), | ||
("none", KeepIntermediateLevel.NONE), | ||
("pipeline", KeepIntermediateLevel.PIPELINE), | ||
("pass", KeepIntermediateLevel.PASS), | ||
], | ||
) | ||
def test_keep_intermediate_levels_conversion(self, input_value, expected_level): | ||
"""Test that various inputs for keep_intermediate are correctly converted to Enum.""" | ||
options = CompileOptions(keep_intermediate=input_value) | ||
assert options.keep_intermediate == expected_level | ||
|
||
@pytest.mark.parametrize("invalid_input", [3, -1, "invalid_string", 3.0, []]) | ||
def test_keep_intermediate_invalid_inputs(self, invalid_input): | ||
"""Test that invalid inputs for keep_intermediate raise appropriate errors.""" | ||
with pytest.raises(ValueError, match="Invalid value for keep_intermediate:"): | ||
CompileOptions(keep_intermediate=invalid_input) | ||
|
||
def test_options_to_cli_flags_keep_intermediate_none(self): | ||
"""Test _options_to_cli_flags with KeepIntermediateLevel.NONE.""" | ||
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.NONE) | ||
flags = _options_to_cli_flags(options) | ||
assert "--keep-intermediate" not in flags | ||
assert "--save-ir-after-each=pass" not in flags | ||
|
||
def test_options_to_cli_flags_keep_intermediate_basic(self): | ||
"""Test _options_to_cli_flags with KeepIntermediateLevel.PIPELINE.""" | ||
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.PIPELINE) | ||
flags = _options_to_cli_flags(options) | ||
assert "--keep-intermediate" in flags | ||
assert "--save-ir-after-each=pass" not in flags | ||
assert flags.count("--save-ir-after-each=pass") <= 1 | ||
Comment on lines
+132
to
+133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why check this twice? 😅 |
||
|
||
def test_options_to_cli_flags_keep_intermediate_debug(self): | ||
"""Test _options_to_cli_flags with KeepIntermediateLevel.PASS.""" | ||
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.PASS) | ||
flags = _options_to_cli_flags(options) | ||
assert "--keep-intermediate" in flags | ||
assert "--save-ir-after-each=pass" in flags | ||
assert flags.count("--save-ir-after-each=pass") == 1 | ||
Comment on lines
+140
to
+141
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also here |
||
|
||
|
||
class TestCompilerWarnings: | ||
"""Test compiler's warning messages.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, is it the string
"none"
or the pythonNone
?