From f0227a56d6e8ce6f13a09a117ecbf00d2f3d746e Mon Sep 17 00:00:00 2001 From: JerryChen97 Date: Tue, 3 Jun 2025 17:28:51 -0400 Subject: [PATCH 1/6] test --- frontend/catalyst/jax_tracer.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 71da0334e..e1796c452 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1327,6 +1327,31 @@ def trace_quantum_function( out_type: JAXPR output type (list of abstract values with explicitness flags). out_tree: PyTree shapen of the result """ + + if qml.transforms.set_shots in qnode.transform_program: + # Then just apply it immediately + user_transform = qnode.transform_program.copy() + set_shots_transform = TransformProgram() + + # Find and extract the set_shots transform + for transform_container in user_transform: + if transform_container.transform == qml.set_shots.transform: + set_shots_transform = transform_container + break + + if set_shots_transform: + # Apply set_shots transform to update device shots + shots_value = set_shots_transform.kwargs["shots"] if hasattr(set_shots_transform, "kwargs") else None + # Update device shots + if hasattr(device, 'shots'): + if isinstance(device, qml.devices.LegacyDevice): + device._shots = shots_value + else: + device._shots = qml.measurements.Shots(shots_value) + + # Remove set_shots from the transform program since we've applied it + user_transform.remove(set_shots_transform) + qnode_program = user_transform with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: # (1) - Classical tracing From 8f4ef37c6a39244c5a74ed45f829ff583cc1811a Mon Sep 17 00:00:00 2001 From: JerryChen97 Date: Wed, 4 Jun 2025 10:14:53 -0400 Subject: [PATCH 2/6] concise --- frontend/catalyst/jax_tracer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index e1796c452..98184b4de 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1341,7 +1341,7 @@ def trace_quantum_function( if set_shots_transform: # Apply set_shots transform to update device shots - shots_value = set_shots_transform.kwargs["shots"] if hasattr(set_shots_transform, "kwargs") else None + shots_value = set_shots_transform.kwargs.get('shots', None) # Update device shots if hasattr(device, 'shots'): if isinstance(device, qml.devices.LegacyDevice): From e2a51ce5038cc624bfed3901a5ea0ca9382dc96c Mon Sep 17 00:00:00 2001 From: JerryChen97 Date: Wed, 4 Jun 2025 15:31:40 -0400 Subject: [PATCH 3/6] black --- frontend/catalyst/jax_tracer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 98184b4de..df6732f44 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1327,28 +1327,28 @@ def trace_quantum_function( out_type: JAXPR output type (list of abstract values with explicitness flags). out_tree: PyTree shapen of the result """ - + if qml.transforms.set_shots in qnode.transform_program: # Then just apply it immediately user_transform = qnode.transform_program.copy() set_shots_transform = TransformProgram() - + # Find and extract the set_shots transform for transform_container in user_transform: if transform_container.transform == qml.set_shots.transform: set_shots_transform = transform_container break - + if set_shots_transform: # Apply set_shots transform to update device shots - shots_value = set_shots_transform.kwargs.get('shots', None) + shots_value = set_shots_transform.kwargs.get("shots", None) # Update device shots - if hasattr(device, 'shots'): + if hasattr(device, "shots"): if isinstance(device, qml.devices.LegacyDevice): device._shots = shots_value else: device._shots = qml.measurements.Shots(shots_value) - + # Remove set_shots from the transform program since we've applied it user_transform.remove(set_shots_transform) qnode_program = user_transform From 6d99edede37ba1ec88d35aa965723414b1134bc9 Mon Sep 17 00:00:00 2001 From: JerryChen97 Date: Wed, 4 Jun 2025 16:08:19 -0400 Subject: [PATCH 4/6] add a test --- frontend/test/pytest/test_device_api.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index b8b760080..bc695d233 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -13,6 +13,7 @@ # limitations under the License. """Test for the device API.""" import platform +from functools import partial import pennylane as qml import pytest @@ -137,6 +138,21 @@ def circuit(): assert circuit.mlir +def test_simple_circuit_set_shots(): + """Test that a circuit with the new device API is compiling to MLIR.""" + dev = NullQubit(wires=2) + + @qjit(target="mlir") + @partial(qml.set_shots, shots=2048) + @qml.qnode(device=dev) + def circuit(): + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(wires=0)) + + assert circuit.mlir + + def test_track_resources(): """Test that resource tracking settings get passed to the device.""" dev = NullQubit(wires=2) From d241e2e5ca0ab3e39ecc58a5455542269d3c34ba Mon Sep 17 00:00:00 2001 From: JerryChen97 Date: Wed, 4 Jun 2025 16:55:24 -0400 Subject: [PATCH 5/6] bump dev --- .dep-versions | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.dep-versions b/.dep-versions index edc108195..6ed4ee1ae 100644 --- a/.dep-versions +++ b/.dep-versions @@ -16,7 +16,7 @@ enzyme=v0.0.149 # For a custom PL version, update the package version here and at # 'doc/requirements.txt -pennylane=0.42.0-dev33 +pennylane=0.42.0-dev45 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' From 947ae5fefcf53e45dc1dc9d0f6cc6cb20e1e3ccf Mon Sep 17 00:00:00 2001 From: JerryChen97 Date: Wed, 4 Jun 2025 17:22:03 -0400 Subject: [PATCH 6/6] debug --- frontend/catalyst/jax_tracer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index df6732f44..9775d84b3 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1330,7 +1330,7 @@ def trace_quantum_function( if qml.transforms.set_shots in qnode.transform_program: # Then just apply it immediately - user_transform = qnode.transform_program.copy() + user_transform = qnode.transform_program set_shots_transform = TransformProgram() # Find and extract the set_shots transform @@ -1349,9 +1349,9 @@ def trace_quantum_function( else: device._shots = qml.measurements.Shots(shots_value) - # Remove set_shots from the transform program since we've applied it - user_transform.remove(set_shots_transform) - qnode_program = user_transform + # # Remove set_shots from the transform program since we've applied it + # user_transform.remove(set_shots_transform) + # qnode_program = user_transform with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: # (1) - Classical tracing