Skip to content

Improve keep_intermediate to enable keeping IR after each pass #1791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import List, Optional

from catalyst.logging import debug_logger, debug_logger_init
from catalyst.pipelines import CompileOptions
from catalyst.pipelines import CompileOptions, KeepIntermediateLevel
from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import Directory
from catalyst.utils.runtime_environment import get_cli_path, get_lib_path
Expand Down Expand Up @@ -342,7 +342,7 @@ def _options_to_cli_flags(options):
if not options.lower_to_llvm:
extra_args += [("--tool", "opt")]

if options.keep_intermediate:
if options.keep_intermediate >= KeepIntermediateLevel.BASIC:
extra_args += ["--keep-intermediate"]

if options.verbose:
Expand All @@ -351,6 +351,9 @@ def _options_to_cli_flags(options):
if options.async_qnodes: # pragma: nocover
extra_args += ["--async-qnodes"]

if options.keep_intermediate >= KeepIntermediateLevel.DEBUG:
extra_args += ["--save-ir-after-each=pass"]

return extra_args


Expand Down
14 changes: 10 additions & 4 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,16 @@ def qjit(
async_qnodes (bool): Experimental support for automatically executing
QNodes asynchronously, if supported by the device runtime.
target (str): the compilation target
keep_intermediate (bool): Whether or not to store the intermediate files throughout the
compilation. If ``True``, intermediate representations are available via the
:attr:`~.QJIT.mlir`, :attr:`~.QJIT.mlir_opt`, :attr:`~.QJIT.jaxpr`,
and :attr:`~.QJIT.qir`, representing different stages in the optimization process.
keep_intermediate (Union[str, int, bool]): Level controlling intermediate file generation.
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, is this the string or the python None?

- ``True`` or ``1`` or ``"basic"``: Standard intermediate files are kept.
- ``2`` or ``"debug"``: Standard intermediate files are kept, and IR is saved after
each pass.
If enabled, intermediate representations are available via the following attributes:
- :attr:`~.QJIT.jaxpr`: JAX program representation
- :attr:`~.QJIT.mlir`: MLIR representation after canonicalization
- :attr:`~.QJIT.mlir_opt`: MLIR representation after optimization
- :attr:`~.QJIT.qir`: QIR in LLVM IR form
verbose (bool): If ``True``, the tools and flags used by Catalyst behind the scenes are
printed out.
logfile (Optional[TextIOWrapper]): File object to write verbose messages to (default -
Expand Down
47 changes: 44 additions & 3 deletions frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

"""

import enum
import sys
from copy import deepcopy
from dataclasses import dataclass
Expand All @@ -37,6 +38,14 @@
from catalyst.utils.exceptions import CompileError


class KeepIntermediateLevel(enum.IntEnum):
"""Enum to control the level of intermediate file keeping."""

NONE = 0 # No intermediate files are kept.
BASIC = 1 # Standard intermediate files are kept.
DEBUG = 2 # Standard intermediate files are kept, and IR is saved after each pass.


# pylint: disable=too-many-instance-attributes
@dataclass
class CompileOptions:
Expand All @@ -47,8 +56,12 @@
Default is ``False``
logfile (Optional[TextIOWrapper]): the logfile to write output to.
Default is ``sys.stderr``
keep_intermediate (Optional[bool]): flag indicating whether to keep intermediate results.
Default is ``False``
keep_intermediate (Optional[Union[str, int, bool]]): Level controlling intermediate file
generation.
- ``False`` or ``0`` or ``"none"`` (default): No intermediate files are kept.
- ``True`` or ``1`` or ``"basic"``: Standard intermediate files are kept.
- ``2`` or ``"debug"``: Standard intermediate files are kept, and IR is saved after
each pass.
pipelines (Optional[List[Tuple[str,List[str]]]]): A list of tuples. The first entry of the
tuple corresponds to the name of a pipeline. The second entry of the tuple corresponds
to a list of MLIR passes.
Expand Down Expand Up @@ -79,7 +92,7 @@
verbose: Optional[bool] = False
logfile: Optional[TextIOWrapper] = sys.stderr
target: Optional[str] = "binary"
keep_intermediate: Optional[bool] = False
keep_intermediate: Optional[Union[str, int, bool, KeepIntermediateLevel]] = False
pipelines: Optional[List[Any]] = None
autograph: Optional[bool] = False
autograph_include: Optional[Iterable[str]] = ()
Expand All @@ -93,37 +106,65 @@
seed: Optional[int] = None
circuit_transform_pipeline: Optional[dict[str, dict[str, str]]] = None
pass_plugins: Optional[Set[Path]] = None
dialect_plugins: Optional[Set[Path]] = None

Check notice on line 109 in frontend/catalyst/pipelines.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/pipelines.py#L109

Too many branches (14/12) (too-many-branches)

def __post_init__(self):
# Convert keep_intermediate to Enum
if self.keep_intermediate is None:
self.keep_intermediate = KeepIntermediateLevel.NONE
elif isinstance(self.keep_intermediate, bool):
self.keep_intermediate = (
KeepIntermediateLevel.BASIC
if self.keep_intermediate
else KeepIntermediateLevel.NONE
)
elif isinstance(self.keep_intermediate, int):
try:
self.keep_intermediate = KeepIntermediateLevel(self.keep_intermediate)
except ValueError as e:
raise ValueError(
f"Invalid int for keep_intermediate: {self.keep_intermediate}. "
"Valid integers are 0, 1, 2."
) from e
elif isinstance(self.keep_intermediate, str):
try:
self.keep_intermediate = KeepIntermediateLevel[self.keep_intermediate.upper()]
except KeyError as e:
raise ValueError(
f"Invalid string for keep_intermediate: {self.keep_intermediate}. "
"Valid strings are 'none', 'basic', 'debug'."
) from e
elif not isinstance(self.keep_intermediate, KeepIntermediateLevel):
raise TypeError(f"Invalid type for keep_intermediate: {type(self.keep_intermediate)}.")

# Check that async runs must not be seeded
if self.async_qnodes and self.seed is not None:
raise CompileError(
"""
Seeding has no effect on asynchronous QNodes,
as the execution order of parallel runs is not guaranteed.
As such, seeding an asynchronous run is not supported.
"""
)

# Check that seed is 32-bit unsigned int
if (self.seed is not None) and (self.seed < 0 or self.seed > 2**32 - 1):
raise ValueError(
"""
Seed must be an unsigned 32-bit integer!
"""
)

# Make the format of static_argnums easier to handle.
static_argnums = self.static_argnums
if static_argnums is None:
self.static_argnums = ()
elif isinstance(static_argnums, int):
self.static_argnums = (static_argnums,)
elif isinstance(static_argnums, Iterable):
self.static_argnums = tuple(static_argnums)
if self.pass_plugins is None:
self.pass_plugins = set()

Check notice on line 167 in frontend/catalyst/pipelines.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/pipelines.py#L109-L167

Complex Method
if self.dialect_plugins is None:
self.dialect_plugins = set()

Expand Down
70 changes: 69 additions & 1 deletion frontend/test/pytest/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
import pytest

from catalyst import qjit
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver
from catalyst.compiler import CompileOptions, Compiler, LinkerDriver, _options_to_cli_flags
from catalyst.debug import instrumentation
from catalyst.pipelines import KeepIntermediateLevel
from catalyst.utils.exceptions import CompileError
from catalyst.utils.filesystem import Directory

Expand Down Expand Up @@ -92,6 +93,73 @@
capture = capture_result.out + capture_result.err
assert "[DIAGNOSTICS]" in capture

@pytest.mark.parametrize(
"input_value, expected_level",
[
(None, KeepIntermediateLevel.NONE),
(False, KeepIntermediateLevel.NONE),
(True, KeepIntermediateLevel.BASIC),
(0, KeepIntermediateLevel.NONE),
(1, KeepIntermediateLevel.BASIC),
(2, KeepIntermediateLevel.DEBUG),
("none", KeepIntermediateLevel.NONE),
("NONE", KeepIntermediateLevel.NONE),
("basic", KeepIntermediateLevel.BASIC),
("BASIC", KeepIntermediateLevel.BASIC),
("debug", KeepIntermediateLevel.DEBUG),
("DEBUG", KeepIntermediateLevel.DEBUG),
(KeepIntermediateLevel.NONE, KeepIntermediateLevel.NONE),
(KeepIntermediateLevel.BASIC, KeepIntermediateLevel.BASIC),
(KeepIntermediateLevel.DEBUG, KeepIntermediateLevel.DEBUG),
],
)
def test_keep_intermediate_levels_conversion(self, input_value, expected_level):
"""Test that various inputs for keep_intermediate are correctly converted to Enum."""
options = CompileOptions(keep_intermediate=input_value)
assert options.keep_intermediate == expected_level

@pytest.mark.parametrize(
"invalid_input, error_type, error_match",
[
(3, ValueError, "Invalid int for keep_intermediate: 3. Valid integers are 0, 1, 2."),
(-1, ValueError, "Invalid int for keep_intermediate: -1. Valid integers are 0, 1, 2."),
(
"invalid_string",
ValueError,
"Invalid string for keep_intermediate: invalid_string. Valid strings are 'none', 'basic', 'debug'.",

Check notice on line 129 in frontend/test/pytest/test_compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_compiler.py#L129

Line too long (116/100) (line-too-long)
),
(3.0, TypeError, "Invalid type for keep_intermediate: <class 'float'>."),
([], TypeError, "Invalid type for keep_intermediate: <class 'list'>."),
],
)
def test_keep_intermediate_invalid_inputs(self, invalid_input, error_type, error_match):
"""Test that invalid inputs for keep_intermediate raise appropriate errors."""
with pytest.raises(error_type, match=error_match):
CompileOptions(keep_intermediate=invalid_input)

def test_options_to_cli_flags_keep_intermediate_none(self):
"""Test _options_to_cli_flags with KeepIntermediateLevel.NONE."""
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.NONE)
flags = _options_to_cli_flags(options)
assert "--keep-intermediate" not in flags
assert "--save-ir-after-each=pass" not in flags

def test_options_to_cli_flags_keep_intermediate_basic(self):
"""Test _options_to_cli_flags with KeepIntermediateLevel.BASIC."""
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.BASIC)
flags = _options_to_cli_flags(options)
assert "--keep-intermediate" in flags
assert "--save-ir-after-each=pass" not in flags
assert flags.count("--save-ir-after-each=pass") <= 1
Comment on lines +132 to +133
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why check this twice? 😅


def test_options_to_cli_flags_keep_intermediate_debug(self):
"""Test _options_to_cli_flags with KeepIntermediateLevel.DEBUG."""
options = CompileOptions(keep_intermediate=KeepIntermediateLevel.DEBUG)
flags = _options_to_cli_flags(options)
assert "--keep-intermediate" in flags
assert "--save-ir-after-each=pass" in flags
assert flags.count("--save-ir-after-each=pass") == 1
Comment on lines +140 to +141
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here



class TestCompilerWarnings:
"""Test compiler's warning messages."""
Expand Down
Loading