Skip to content

Support Static and Dynamic Variables in PLxPR Programs with QJIT #1810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 20, 2025
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@
performance by eliminating indirect conversion.
[(#1738)](https://github.com/PennyLaneAI/catalyst/pull/1738)

* `static_argnums` on `qjit` can now be specified with program capture through PLxPR.
[(#1810)](https://github.com/PennyLaneAI/catalyst/pull/1810)

<h3>Breaking changes 💔</h3>

* (Device Developers Only) The `QuantumDevice` interface in the Catalyst Runtime plugin system
Expand Down
32 changes: 29 additions & 3 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,15 +492,33 @@ def handle_measure_in_basis(self, angle, wire, plane, reset, postselect):


# pylint: disable=too-many-positional-arguments
def trace_from_pennylane(fn, static_argnums, abstracted_axes, sig, kwargs, debug_info=None):
def trace_from_pennylane(
fn, static_argnums, dynamic_args, abstracted_axes, sig, kwargs, debug_info=None
):
"""Capture the JAX program representation (JAXPR) of the wrapped function, using
PL capure module.

Args:
args (Iterable): arguments to use for program capture
fn(Callable): the user function to be traced
static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the
positions of static arguments.
dynamic_args(Seqence[Any]): the abstract values of the dynamic arguments.
abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]):
An experimental option to specify dynamic tensor shapes.
This option affects the compilation of the annotated function.
Function arguments with ``abstracted_axes`` specified will be compiled to ranked tensors
with dynamic shapes. For more details, please see the Dynamically-shaped Arrays section
below.
sig(Sequence[Any]): a tuple indicating the argument signature of the function. Static arguments
are indicated with their literal values, and dynamic arguments are indicated by abstract
values.
kwargs(Dict[str, Any]): keyword argumemts to the function.
debug_info(jax.api_util.debug_info): a source debug information object required by jaxprs.

Returns:
ClosedJaxpr: captured JAXPR
Tuple[Tuple[ShapedArray, bool]]: the return type of the captured JAXPR.
The boolean indicates whether each result is a value returned by the user function.
PyTreeDef: PyTree metadata of the function output
Tuple[Any]: the dynamic argument signature
"""
Expand All @@ -515,7 +533,15 @@ def trace_from_pennylane(fn, static_argnums, abstracted_axes, sig, kwargs, debug

args = sig

if isinstance(fn, qml.QNode) and static_argnums:
# `make_jaxpr2` sees the qnode
# The static_argnum on the wrapped function takes precedence over the
# one in `make_jaxpr`
# https://github.com/jax-ml/jax/blob/636691bba40b936b8b64a4792c1d2158296e9dd4/jax/_src/linear_util.py#L231
# Therefore we need to coordinate them manually
fn.static_argnums = static_argnums

plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)
jaxpr = from_plxpr(plxpr)(*args, **kwargs)
jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs)

return jaxpr, out_type, out_treedef, sig
1 change: 1 addition & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,7 @@ def capture(self, args, **kwargs):
return trace_from_pennylane(
self.user_function,
static_argnums,
dynamic_args,
abstracted_axes,
full_sig,
kwargs,
Expand Down
53 changes: 53 additions & 0 deletions frontend/test/pytest/test_capture_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,3 +1500,56 @@ def loop_0(i):
return qml.sample()

assert jnp.allclose(circuit(), capture_result)

def test_static_variable_qnode(self, backend):
"""Test the integration for a circuit with a static variable."""

qml.capture.enable()

# Basic test
@qjit(static_argnums=(0,))
@qml.qnode(qml.device(backend, wires=1))
def captured_circuit_1(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=0)
return qml.expval(qml.PauliZ(0))

result_1 = captured_circuit_1(1.5, 2.0)
captured_circuit_1_mlir = captured_circuit_1.mlir
assert "%cst = arith.constant 1.5" in captured_circuit_1_mlir
assert 'quantum.custom "RX"(%cst)' in captured_circuit_1_mlir
assert "%cst = arith.constant 2.0" not in captured_circuit_1_mlir

# Test that qjit static_argnums takes precedence over the one on the qnode
@qjit(static_argnums=1)
@qml.qnode(qml.device(backend, wires=1), static_argnums=0) # should be ignored
def captured_circuit_2(x, y):
qml.RX(x, wires=0)
qml.RY(y, wires=0)
return qml.expval(qml.PauliZ(0))

result_2 = captured_circuit_2(1.5, 2.0)
captured_circuit_2_mlir = captured_circuit_2.mlir
assert "%cst = arith.constant 2.0" in captured_circuit_2_mlir
assert 'quantum.custom "RY"(%cst)' in captured_circuit_2_mlir
assert "%cst = arith.constant 1.5" not in captured_circuit_2_mlir

assert result_1 == result_2

# Test under a non qnode workflow function
@qjit(static_argnums=(0,))
def workflow(x, y):
@qml.qnode(qml.device(backend, wires=1))
def c():
qml.RX(x, wires=0)
qml.RY(y, wires=0)
return qml.expval(qml.PauliZ(0))

return c()

_ = workflow(1.5, 2.0)
captured_circuit_3_mlir = workflow.mlir
assert "%cst = arith.constant 1.5" in captured_circuit_3_mlir
assert 'quantum.custom "RX"(%cst)' in captured_circuit_3_mlir

qml.capture.disable()