From bd5ff42ba97c826be8769eb8f942c173e4589ee0 Mon Sep 17 00:00:00 2001 From: sengthai Date: Mon, 16 Jun 2025 09:48:41 -0400 Subject: [PATCH 01/10] Enhance function to support dynamic arguments --- frontend/catalyst/from_plxpr.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index fa468a20a..4daae5c55 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -45,6 +45,7 @@ extract_backend_info, get_device_capabilities, ) +from catalyst.tracing.type_signatures import filter_static_args from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( AbstractQbit, @@ -722,7 +723,9 @@ def handle_measure_in_basis(self, angle, wire, plane, reset, postselect): # pylint: disable=too-many-positional-arguments -def trace_from_pennylane(fn, static_argnums, abstracted_axes, sig, kwargs, debug_info=None): +def trace_from_pennylane( + fn, static_argnums, dynamic_argnums, abstracted_axes, sig, kwargs, debug_info=None +): """Capture the JAX program representation (JAXPR) of the wrapped function, using PL capure module. @@ -746,6 +749,6 @@ def trace_from_pennylane(fn, static_argnums, abstracted_axes, sig, kwargs, debug args = sig plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs) - jaxpr = from_plxpr(plxpr)(*args, **kwargs) + jaxpr = from_plxpr(plxpr)(*dynamic_argnums, **kwargs) return jaxpr, out_type, out_treedef, sig From 6d86929247f7790d44105045da4b5fac8dee006a Mon Sep 17 00:00:00 2001 From: sengthai Date: Mon, 16 Jun 2025 09:55:08 -0400 Subject: [PATCH 02/10] Apply changed --- frontend/catalyst/from_plxpr.py | 1 - frontend/catalyst/jit.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 4daae5c55..3287bafe1 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -45,7 +45,6 @@ extract_backend_info, get_device_capabilities, ) -from catalyst.tracing.type_signatures import filter_static_args from catalyst.jax_extras import jaxpr_pad_consts, make_jaxpr2, transient_jax_config from catalyst.jax_primitives import ( AbstractQbit, diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 7c260ceac..db5b51904 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -725,6 +725,7 @@ def capture(self, args, **kwargs): return trace_from_pennylane( self.user_function, static_argnums, + dynamic_args, abstracted_axes, full_sig, kwargs, From df483ab831a6898b7f8b49e2fa6ec03dccbe1d28 Mon Sep 17 00:00:00 2001 From: sengthai Date: Mon, 16 Jun 2025 12:09:16 -0400 Subject: [PATCH 03/10] Added test --- .../test/pytest/test_capture_integration.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/frontend/test/pytest/test_capture_integration.py b/frontend/test/pytest/test_capture_integration.py index 430d2ea4b..248216283 100644 --- a/frontend/test/pytest/test_capture_integration.py +++ b/frontend/test/pytest/test_capture_integration.py @@ -1500,3 +1500,36 @@ def loop_0(i): return qml.sample() assert jnp.allclose(circuit(), capture_result) + + def test_static_variable_qnode(self, backend): + """Test the integration for a circuit with a static variable.""" + + qml.capture.enable() + + # Capture in qnode level + @qjit(static_argnums=(0,)) + @qml.qnode(qml.device(backend, wires=1)) + def captured_circuit_1(x, y): + qml.RX(x, wires=0) + qml.RY(y, wires=0) + return qml.expval(qml.PauliZ(0)) + + # Ignore static_argnums in the qnode + @qjit(static_argnums=1) + @qml.qnode(qml.device(backend, wires=1), static_argnums=0) + def captured_circuit_2(x, y): + qml.RX(x, wires=0) + qml.RY(y, wires=0) + return qml.expval(qml.PauliZ(0)) + + result_1 = captured_circuit_1(1.5, 2.0) + assert "stablehlo.constant dense<1.500000e+00>" in captured_circuit_1.mlir + assert "stablehlo.constant dense<2.000000e+00>" not in captured_circuit_1.mlir + + result_2 = captured_circuit_2(1.5, 2.0) + assert "stablehlo.constant dense<1.500000e+00>" not in captured_circuit_2.mlir + assert "stablehlo.constant dense<2.000000e+00>" in captured_circuit_2.mlir + + assert result_1 == result_2 + + qml.capture.disable() From 9ae5eb5a18e84b0a6906042df7000d8f3c6e5966 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 20 Jun 2025 13:44:57 -0400 Subject: [PATCH 04/10] cache mlir string in test to avoid recompilation --- frontend/test/pytest/test_capture_integration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/frontend/test/pytest/test_capture_integration.py b/frontend/test/pytest/test_capture_integration.py index 248216283..b6ee40ebc 100644 --- a/frontend/test/pytest/test_capture_integration.py +++ b/frontend/test/pytest/test_capture_integration.py @@ -1523,12 +1523,14 @@ def captured_circuit_2(x, y): return qml.expval(qml.PauliZ(0)) result_1 = captured_circuit_1(1.5, 2.0) - assert "stablehlo.constant dense<1.500000e+00>" in captured_circuit_1.mlir - assert "stablehlo.constant dense<2.000000e+00>" not in captured_circuit_1.mlir + captured_circuit_1_mlir = captured_circuit_1.mlir + assert "stablehlo.constant dense<1.500000e+00>" in captured_circuit_1_mlir + assert "stablehlo.constant dense<2.000000e+00>" not in captured_circuit_1_mlir result_2 = captured_circuit_2(1.5, 2.0) - assert "stablehlo.constant dense<1.500000e+00>" not in captured_circuit_2.mlir - assert "stablehlo.constant dense<2.000000e+00>" in captured_circuit_2.mlir + captured_circuit_2_mlir = captured_circuit_2.mlir + assert "stablehlo.constant dense<1.500000e+00>" not in captured_circuit_2_mlir + assert "stablehlo.constant dense<2.000000e+00>" in captured_circuit_2_mlir assert result_1 == result_2 From b9ed4bf7ce1d93af75723b2665f54837fe4e4f4b Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 20 Jun 2025 14:05:47 -0400 Subject: [PATCH 05/10] changelog --- doc/releases/changelog-dev.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6ee1b7d50..197b2f28d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -127,6 +127,9 @@ performance by eliminating indirect conversion. [(#1738)](https://github.com/PennyLaneAI/catalyst/pull/1738) +* `static_argnums` on `qjit` can now be specified with program capture through PLxPR. + [(#1810)](https://github.com/PennyLaneAI/catalyst/pull/1810) +

Breaking changes 💔

* (Device Developers Only) The `QuantumDevice` interface in the Catalyst Runtime plugin system From bc910248769ac7d3ebd16ff6d534308c852bc305 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 20 Jun 2025 14:23:38 -0400 Subject: [PATCH 06/10] docstring --- frontend/catalyst/from_plxpr/from_plxpr.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index dcedfb3b1..c519afb58 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -499,10 +499,27 @@ def trace_from_pennylane( PL capure module. Args: - args (Iterable): arguments to use for program capture + fn(Callable): the user function to be traced + static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the + positions of static arguments. + dynamic_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the + positions of dynamic arguments. + abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]): + An experimental option to specify dynamic tensor shapes. + This option affects the compilation of the annotated function. + Function arguments with ``abstracted_axes`` specified will be compiled to ranked tensors + with dynamic shapes. For more details, please see the Dynamically-shaped Arrays section + below. + sig(Sequence[Any]): a tuple indicating the argument signature of the function. Static arguments + are indicated with their literal values, and dynamic arguments are indicated by abstract + values. + kwargs(Dict[str, Any]): keyword argumemts to the function. + debug_info(jax.api_util.debug_info): a source debug information object required by jaxprs. Returns: ClosedJaxpr: captured JAXPR + Tuple[Tuple[ShapedArray, bool]]: the return type of the captured JAXPR. + The boolean indicates whether each result is a value returned by the user function. PyTreeDef: PyTree metadata of the function output Tuple[Any]: the dynamic argument signature """ From d44c60903736defcc1e24ec0dff76569cf1a72a6 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 20 Jun 2025 16:36:08 -0400 Subject: [PATCH 07/10] overwrite qnode static argnum with qjit static argnum --- frontend/catalyst/from_plxpr/from_plxpr.py | 15 +++++++++---- .../test/pytest/test_capture_integration.py | 22 ++++++++++--------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index c519afb58..d9fe44054 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -493,7 +493,7 @@ def handle_measure_in_basis(self, angle, wire, plane, reset, postselect): # pylint: disable=too-many-positional-arguments def trace_from_pennylane( - fn, static_argnums, dynamic_argnums, abstracted_axes, sig, kwargs, debug_info=None + fn, static_argnums, dynamic_args, abstracted_axes, sig, kwargs, debug_info=None ): """Capture the JAX program representation (JAXPR) of the wrapped function, using PL capure module. @@ -502,8 +502,7 @@ def trace_from_pennylane( fn(Callable): the user function to be traced static_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the positions of static arguments. - dynamic_argnums(int or Seqence[Int]): an index or a sequence of indices that specifies the - positions of dynamic arguments. + dynamic_args(Seqence[Any]): the abstract values of the dynamic arguments. abstracted_axes (Sequence[Sequence[str]] or Dict[int, str] or Sequence[Dict[int, str]]): An experimental option to specify dynamic tensor shapes. This option affects the compilation of the annotated function. @@ -534,7 +533,15 @@ def trace_from_pennylane( args = sig + if isinstance(fn, qml.QNode) and static_argnums: + # `make_jaxpr2` sees the qnode + # The static_argnum on the wrapped function takes precedence over the + # one in `make_jaxpr` + # https://github.com/jax-ml/jax/blob/636691bba40b936b8b64a4792c1d2158296e9dd4/jax/_src/linear_util.py#L231 + # Therefore we need to coordinate them manually + fn.static_argnums = static_argnums + plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs) - jaxpr = from_plxpr(plxpr)(*dynamic_argnums, **kwargs) + jaxpr = from_plxpr(plxpr)(*dynamic_args, **kwargs) return jaxpr, out_type, out_treedef, sig diff --git a/frontend/test/pytest/test_capture_integration.py b/frontend/test/pytest/test_capture_integration.py index b6ee40ebc..6945c9f5a 100644 --- a/frontend/test/pytest/test_capture_integration.py +++ b/frontend/test/pytest/test_capture_integration.py @@ -1506,7 +1506,7 @@ def test_static_variable_qnode(self, backend): qml.capture.enable() - # Capture in qnode level + # Basic test @qjit(static_argnums=(0,)) @qml.qnode(qml.device(backend, wires=1)) def captured_circuit_1(x, y): @@ -1514,23 +1514,25 @@ def captured_circuit_1(x, y): qml.RY(y, wires=0) return qml.expval(qml.PauliZ(0)) - # Ignore static_argnums in the qnode + result_1 = captured_circuit_1(1.5, 2.0) + captured_circuit_1_mlir = captured_circuit_1.mlir + assert "%cst = arith.constant 1.5" in captured_circuit_1_mlir + assert 'quantum.custom "RX"(%cst)' in captured_circuit_1_mlir + assert "%cst = arith.constant 2.0" not in captured_circuit_1_mlir + + # Test that qjit static_argnums takes precedence over the one on the qnode @qjit(static_argnums=1) - @qml.qnode(qml.device(backend, wires=1), static_argnums=0) + @qml.qnode(qml.device(backend, wires=1), static_argnums=0) # should be ignored def captured_circuit_2(x, y): qml.RX(x, wires=0) qml.RY(y, wires=0) return qml.expval(qml.PauliZ(0)) - result_1 = captured_circuit_1(1.5, 2.0) - captured_circuit_1_mlir = captured_circuit_1.mlir - assert "stablehlo.constant dense<1.500000e+00>" in captured_circuit_1_mlir - assert "stablehlo.constant dense<2.000000e+00>" not in captured_circuit_1_mlir - result_2 = captured_circuit_2(1.5, 2.0) captured_circuit_2_mlir = captured_circuit_2.mlir - assert "stablehlo.constant dense<1.500000e+00>" not in captured_circuit_2_mlir - assert "stablehlo.constant dense<2.000000e+00>" in captured_circuit_2_mlir + assert "%cst = arith.constant 2.0" in captured_circuit_2_mlir + assert 'quantum.custom "RY"(%cst)' in captured_circuit_2_mlir + assert "%cst = arith.constant 1.5" not in captured_circuit_1_mlir assert result_1 == result_2 From 8d4786675dafbfda8b35a6d0e053ffa6da20d47c Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 20 Jun 2025 16:37:14 -0400 Subject: [PATCH 08/10] typo --- frontend/test/pytest/test_capture_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_capture_integration.py b/frontend/test/pytest/test_capture_integration.py index 6945c9f5a..6ec3e7fc2 100644 --- a/frontend/test/pytest/test_capture_integration.py +++ b/frontend/test/pytest/test_capture_integration.py @@ -1532,7 +1532,7 @@ def captured_circuit_2(x, y): captured_circuit_2_mlir = captured_circuit_2.mlir assert "%cst = arith.constant 2.0" in captured_circuit_2_mlir assert 'quantum.custom "RY"(%cst)' in captured_circuit_2_mlir - assert "%cst = arith.constant 1.5" not in captured_circuit_1_mlir + assert "%cst = arith.constant 1.5" not in captured_circuit_2_mlir assert result_1 == result_2 From fd20f16ced755e8be6ed601ae2468db9b353d020 Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 20 Jun 2025 17:01:12 -0400 Subject: [PATCH 09/10] add test for non-qnode function --- frontend/test/pytest/test_capture_integration.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/frontend/test/pytest/test_capture_integration.py b/frontend/test/pytest/test_capture_integration.py index 6ec3e7fc2..38e573b1f 100644 --- a/frontend/test/pytest/test_capture_integration.py +++ b/frontend/test/pytest/test_capture_integration.py @@ -1536,4 +1536,20 @@ def captured_circuit_2(x, y): assert result_1 == result_2 + # Test under a non qnode workflow function + @qjit(static_argnums=(0,)) + def workflow(x, y): + @qml.qnode(qml.device(backend, wires=1)) + def c(): + qml.RX(x, wires=0) + qml.RY(y, wires=0) + return qml.expval(qml.PauliZ(0)) + + return c() + + result_3 = workflow(1.5, 2.0) + captured_circuit_3_mlir = workflow.mlir + assert "%cst = arith.constant 1.5" in captured_circuit_3_mlir + assert 'quantum.custom "RX"(%cst)' in captured_circuit_3_mlir + qml.capture.disable() From 1cd4888d2158083fe94bbda144fcff0a1512d46f Mon Sep 17 00:00:00 2001 From: paul0403 Date: Fri, 20 Jun 2025 17:10:07 -0400 Subject: [PATCH 10/10] codefactor --- frontend/test/pytest/test_capture_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_capture_integration.py b/frontend/test/pytest/test_capture_integration.py index 38e573b1f..35f9f1f58 100644 --- a/frontend/test/pytest/test_capture_integration.py +++ b/frontend/test/pytest/test_capture_integration.py @@ -1547,7 +1547,7 @@ def c(): return c() - result_3 = workflow(1.5, 2.0) + _ = workflow(1.5, 2.0) captured_circuit_3_mlir = workflow.mlir assert "%cst = arith.constant 1.5" in captured_circuit_3_mlir assert 'quantum.custom "RX"(%cst)' in captured_circuit_3_mlir