Skip to content

Improve keep_intermediate to enable keeping IR after each pass #1791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@
performance by eliminating indirect conversion.
[(#1738)](https://github.com/PennyLaneAI/catalyst/pull/1738)

* The `keep_intermediate` argument in the `qjit` decorator now accepts a new value that allows for
saving intermediate files after each pass. The updated possible options for this argument are:
* `False` or `0` or `"none"`: No intermediate files are kept.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, is it the string "none" or the python None?

* `True` or `1` or `"pipeline"`: Intermediate files are saved after each pipeline.
* `2` or `"pass"`: Intermediate files are saved after each pass.
The default value is `False`.
[(#1791)](https://github.com/PennyLaneAI/catalyst/pull/1791)

<h3>Breaking changes 💔</h3>

* (Device Developers Only) The `QuantumDevice` interface in the Catalyst Runtime plugin system
Expand Down
7 changes: 5 additions & 2 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import List, Optional

from catalyst.logging import debug_logger, debug_logger_init
from catalyst.pipelines import CompileOptions
from catalyst.pipelines import CompileOptions, KeepIntermediateLevel
from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import Directory
from catalyst.utils.runtime_environment import get_cli_path, get_lib_path
Expand Down Expand Up @@ -342,7 +342,7 @@ def _options_to_cli_flags(options):
if not options.lower_to_llvm:
extra_args += [("--tool", "opt")]

if options.keep_intermediate:
if options.keep_intermediate >= KeepIntermediateLevel.PIPELINE:
extra_args += ["--keep-intermediate"]

if options.verbose:
Expand All @@ -351,6 +351,9 @@ def _options_to_cli_flags(options):
if options.async_qnodes: # pragma: nocover
extra_args += ["--async-qnodes"]

if options.keep_intermediate >= KeepIntermediateLevel.PASS:
extra_args += ["--save-ir-after-each=pass"]

Comment on lines +354 to +356
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes more sense to put the two if options.keep_intermediate blocks together

return extra_args


Expand Down
13 changes: 9 additions & 4 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,15 @@ def qjit(
async_qnodes (bool): Experimental support for automatically executing
QNodes asynchronously, if supported by the device runtime.
target (str): the compilation target
keep_intermediate (bool): Whether or not to store the intermediate files throughout the
compilation. If ``True``, intermediate representations are available via the
:attr:`~.QJIT.mlir`, :attr:`~.QJIT.mlir_opt`, :attr:`~.QJIT.jaxpr`,
and :attr:`~.QJIT.qir`, representing different stages in the optimization process.
keep_intermediate (Union[str, int, bool]): Level controlling intermediate file generation.
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, is this the string or the python None?

- ``True`` or ``1`` or ``"pipeline"``: Intermediate files are saved after each pipeline.
- ``2`` or ``"pass"``: Intermediate files are saved after each pass.
If enabled, intermediate representations are available via the following attributes:
- :attr:`~.QJIT.jaxpr`: JAX program representation
- :attr:`~.QJIT.mlir`: MLIR representation after canonicalization
- :attr:`~.QJIT.mlir_opt`: MLIR representation after optimization
- :attr:`~.QJIT.qir`: QIR in LLVM IR form
verbose (bool): If ``True``, the tools and flags used by Catalyst behind the scenes are
printed out.
logfile (Optional[TextIOWrapper]): File object to write verbose messages to (default -
Expand Down
41 changes: 38 additions & 3 deletions frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

"""

import enum
import sys
from copy import deepcopy
from dataclasses import dataclass
Expand All @@ -37,6 +38,34 @@
from catalyst.utils.exceptions import CompileError


class KeepIntermediateLevel(enum.IntEnum):
"""Enum to control the level of intermediate file keeping."""

NONE = 0 # No intermediate files are kept.
PIPELINE = 1 # Intermediate files are saved after each pipeline.
PASS = 2 # Intermediate files are saved after each pass.


def _parse_keep_intermediate(
level: Union[str, int, bool],
) -> KeepIntermediateLevel:
"""Parse the keep_intermediate value into a KeepIntermediateLevel enum."""
match level:
case 0 | 1 | 2:
return KeepIntermediateLevel(level)
case "none":
return KeepIntermediateLevel.NONE
case "pipeline":
return KeepIntermediateLevel.PIPELINE
case "pass":
return KeepIntermediateLevel.PASS
case _:
raise ValueError(
f"Invalid value for keep_intermediate: {level}. "
"Valid values are True, False, 0, 1, 2, 'none', 'pipeline', 'pass'."
)
Comment on lines +62 to +66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, does the python None come to this error case? Since the above is matching for the string

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also are booleans parsed here? 🤔



# pylint: disable=too-many-instance-attributes
@dataclass
class CompileOptions:
Expand All @@ -47,8 +76,11 @@ class CompileOptions:
Default is ``False``
logfile (Optional[TextIOWrapper]): the logfile to write output to.
Default is ``sys.stderr``
keep_intermediate (Optional[bool]): flag indicating whether to keep intermediate results.
Default is ``False``
keep_intermediate (Optional[Union[str, int, bool]]): Level controlling intermediate file
generation.
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept.
- ``True`` or ``1`` or ``"pipeline"``: Intermediate files are saved after each pipeline.
- ``2`` or ``"pass"``: Intermediate files are saved after each pass.
pipelines (Optional[List[Tuple[str,List[str]]]]): A list of tuples. The first entry of the
tuple corresponds to the name of a pipeline. The second entry of the tuple corresponds
to a list of MLIR passes.
Expand Down Expand Up @@ -79,7 +111,7 @@ class CompileOptions:
verbose: Optional[bool] = False
logfile: Optional[TextIOWrapper] = sys.stderr
target: Optional[str] = "binary"
keep_intermediate: Optional[bool] = False
keep_intermediate: Optional[Union[str, int, bool, KeepIntermediateLevel]] = False
pipelines: Optional[List[Any]] = None
autograph: Optional[bool] = False
autograph_include: Optional[Iterable[str]] = ()
Expand All @@ -96,6 +128,9 @@ class CompileOptions:
dialect_plugins: Optional[Set[Path]] = None

def __post_init__(self):
# Convert keep_intermediate to Enum
self.keep_intermediate = _parse_keep_intermediate(self.keep_intermediate)

# Check that async runs must not be seeded
if self.async_qnodes and self.seed is not None:
raise CompileError(
Expand Down
50 changes: 49 additions & 1 deletion frontend/test/pytest/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
import pytest

from catalyst import qjit
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver, _options_to_cli_flags
from catalyst.debug import instrumentation
from catalyst.pipelines import KeepIntermediateLevel
from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import Directory

Expand Down Expand Up @@ -92,6 +93,53 @@ def circuit():
capture = capture_result.out + capture_result.err
assert "[DIAGNOSTICS]" in capture

@pytest.mark.parametrize(
"input_value, expected_level",
[
(False, KeepIntermediateLevel.NONE),
(True, KeepIntermediateLevel.PIPELINE),
(0, KeepIntermediateLevel.NONE),
(1, KeepIntermediateLevel.PIPELINE),
(2, KeepIntermediateLevel.PASS),
("none", KeepIntermediateLevel.NONE),
("pipeline", KeepIntermediateLevel.PIPELINE),
("pass", KeepIntermediateLevel.PASS),
],
)
def test_keep_intermediate_levels_conversion(self, input_value, expected_level):
"""Test that various inputs for keep_intermediate are correctly converted to Enum."""
options = CompileOptions(keep_intermediate=input_value)
assert options.keep_intermediate == expected_level

@pytest.mark.parametrize("invalid_input", [3, -1, "invalid_string", 3.0, []])
def test_keep_intermediate_invalid_inputs(self, invalid_input):
"""Test that invalid inputs for keep_intermediate raise appropriate errors."""
with pytest.raises(ValueError, match="Invalid value for keep_intermediate:"):
CompileOptions(keep_intermediate=invalid_input)

def test_options_to_cli_flags_keep_intermediate_none(self):
"""Test _options_to_cli_flags with KeepIntermediateLevel.NONE."""
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.NONE)
flags = _options_to_cli_flags(options)
assert "--keep-intermediate" not in flags
assert "--save-ir-after-each=pass" not in flags

def test_options_to_cli_flags_keep_intermediate_basic(self):
"""Test _options_to_cli_flags with KeepIntermediateLevel.PIPELINE."""
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.PIPELINE)
flags = _options_to_cli_flags(options)
assert "--keep-intermediate" in flags
assert "--save-ir-after-each=pass" not in flags
assert flags.count("--save-ir-after-each=pass") <= 1
Comment on lines +132 to +133
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why check this twice? 😅


def test_options_to_cli_flags_keep_intermediate_debug(self):
"""Test _options_to_cli_flags with KeepIntermediateLevel.PASS."""
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.PASS)
flags = _options_to_cli_flags(options)
assert "--keep-intermediate" in flags
assert "--save-ir-after-each=pass" in flags
assert flags.count("--save-ir-after-each=pass") == 1
Comment on lines +140 to +141
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here



class TestCompilerWarnings:
"""Test compiler's warning messages."""
Expand Down