diff --git a/.dep-versions b/.dep-versions index edc108195..6ed4ee1ae 100644 --- a/.dep-versions +++ b/.dep-versions @@ -16,7 +16,7 @@ enzyme=v0.0.149 # For a custom PL version, update the package version here and at # 'doc/requirements.txt -pennylane=0.42.0-dev33 +pennylane=0.42.0-dev45 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 71da0334e..9775d84b3 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -1328,6 +1328,31 @@ def trace_quantum_function( out_tree: PyTree shapen of the result """ + if qml.transforms.set_shots in qnode.transform_program: + # Then just apply it immediately + user_transform = qnode.transform_program + set_shots_transform = TransformProgram() + + # Find and extract the set_shots transform + for transform_container in user_transform: + if transform_container.transform == qml.set_shots.transform: + set_shots_transform = transform_container + break + + if set_shots_transform: + # Apply set_shots transform to update device shots + shots_value = set_shots_transform.kwargs.get("shots", None) + # Update device shots + if hasattr(device, "shots"): + if isinstance(device, qml.devices.LegacyDevice): + device._shots = shots_value + else: + device._shots = qml.measurements.Shots(shots_value) + + # # Remove set_shots from the transform program since we've applied it + # user_transform.remove(set_shots_transform) + # qnode_program = user_transform + with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: # (1) - Classical tracing quantum_tape = QuantumTape(shots=device.shots) diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index b8b760080..bc695d233 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -13,6 +13,7 @@ # limitations under the License. """Test for the device API.""" import platform +from functools import partial import pennylane as qml import pytest @@ -137,6 +138,21 @@ def circuit(): assert circuit.mlir +def test_simple_circuit_set_shots(): + """Test that a circuit with the new device API is compiling to MLIR.""" + dev = NullQubit(wires=2) + + @qjit(target="mlir") + @partial(qml.set_shots, shots=2048) + @qml.qnode(device=dev) + def circuit(): + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + return qml.expval(qml.PauliZ(wires=0)) + + assert circuit.mlir + + def test_track_resources(): """Test that resource tracking settings get passed to the device.""" dev = NullQubit(wires=2)