From a161bb7c0128b32cbd7b2cad3512e123b27e90fd Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Fri, 30 May 2025 16:25:20 +0000 Subject: [PATCH 01/16] make aot compilation more general --- doc/releases/changelog-dev.md | 3 +++ frontend/catalyst/tracing/type_signatures.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f486fb6824..f4ea8bbe4a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -249,6 +249,9 @@ This runtime stub is currently for mock execution only and should be treated as a placeholder operation. Internally, it functions just as a computational-basis measurement instruction. +* Quantum subroutine stub. + [(#1774)](https://github.com/PennyLaneAI/catalyst/pull/1774) + * PennyLane's arbitrary-basis measurement operations, such as :func:`qml.ftqc.measure_arbitrary_basis() `, are now QJIT-compatible with program capture enabled. diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 137154a9e6..b4090001c4 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -25,6 +25,7 @@ from jax._src.core import shaped_abstractify from jax._src.interpreters.partial_eval import infer_lambda_input_type from jax._src.pjit import _flat_axes_specs +from jax.core import AbstractValue from jax.tree_util import tree_flatten, tree_unflatten from catalyst.jax_extras import get_aval2 @@ -56,7 +57,7 @@ def params_are_annotated(fn: Callable): are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations) if not are_annotated: return False - return all(isinstance(annotation, (type, jax.core.ShapedArray)) for annotation in annotations) + return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations) def get_type_annotations(fn: Callable): From 37e7c8772bf93812ee9e9b10b50ee3da63cb542e Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Fri, 30 May 2025 16:44:50 +0000 Subject: [PATCH 02/16] Add pytype-aval mapping and fix mlir-lowering --- frontend/catalyst/jax_primitives.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 31bc822203..ab1d001bd8 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -138,13 +138,13 @@ def __hash__(self): # pragma: nocover return self.hash_value -class ConcreteQbit(AbstractQbit): +class ConcreteQbit: """Concrete Qbit.""" def _qbit_lowering(aval): assert isinstance(aval, AbstractQbit) - return (ir.OpaqueType.get("quantum", "bit"),) + return ir.OpaqueType.get("quantum", "bit") # @@ -162,13 +162,13 @@ def __hash__(self): return self.hash_value -class ConcreteQreg(AbstractQreg): +class ConcreteQreg: """Concrete quantum register.""" def _qreg_lowering(aval): assert isinstance(aval, AbstractQreg) - return (ir.OpaqueType.get("quantum", "reg"),) + return ir.OpaqueType.get("quantum", "reg") # @@ -2357,3 +2357,5 @@ def _scalar_abstractify(t): pytype_aval_mappings[type] = _scalar_abstractify pytype_aval_mappings[jax._src.numpy.scalar_types._ScalarMeta] = _scalar_abstractify +pytype_aval_mappings[ConcreteQbit] = lambda _: AbstractQbit() +pytype_aval_mappings[ConcreteQreg] = lambda _: AbstractQreg() From b3d5c43e4eaa5b90a47d6ef7c46e0128201dbce1 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Mon, 2 Jun 2025 13:01:36 +0000 Subject: [PATCH 03/16] wip --- frontend/catalyst/from_plxpr.py | 6 ++++++ frontend/catalyst/jax_extras/lowering.py | 1 + frontend/catalyst/jax_primitives.py | 10 ++-------- frontend/catalyst/jax_tracer.py | 22 +++++++++++++++++----- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index ae3c67f730..d44904bd17 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -424,6 +424,12 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 +from jax._src.pjit import pjit_p + +@QFuncPlxprInterpreter.register_primitive(pjit_p) +def handle_pjit_p(self, *args, **kwargs): + breakpoint() + @QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive) def handle_qubit_unitary(self, *invals, n_wires): diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index d2bfbc74bb..6796cca787 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -140,6 +140,7 @@ def custom_lower_jaxpr_to_module( # XLA computation preserves the module name. module_name = _module_name_regex.sub("_", module_name) ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) + breakpoint() lower_jaxpr_to_fun( ctx, func_name, diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index ab1d001bd8..e37aee5faa 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -153,14 +153,8 @@ def _qbit_lowering(aval): class AbstractQreg(AbstractValue): """Abstract quantum register.""" - hash_value = hash("AbstractQreg") - - def __eq__(self, other): - return isinstance(other, AbstractQreg) - - def __hash__(self): - return self.hash_value - + def _add(self, left, right): + return AbstractQreg() class ConcreteQreg: """Concrete quantum register.""" diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 71da0334ee..739a0ca5a6 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -425,6 +425,7 @@ class HybridOpRegion: cached_vars = weakref.WeakKeyDictionary() + class HybridOp(Operator): """A base class for operations carrying nested regions. The class stores the information obtained in the process of classical tracing and required for the completion of the quantum @@ -1302,6 +1303,16 @@ def trace_function( return res_expanded_tracers, in_sig, out_sig +GLOBAL_QREG = None + +def get_qreg(): + global GLOBAL_QREG + return GLOBAL_QREG + +def set_qreg(qreg): + global GLOBAL_QREG + GLOBAL_QREG = qreg + @debug_logger def trace_quantum_function( @@ -1371,6 +1382,7 @@ def is_leaf(obj): return_values_flat, ) + global GLOBAL_QREG # (2) - Quantum tracing transformed_results = [] with EvaluationContext.frame_tracing_context(trace): @@ -1392,9 +1404,9 @@ def is_leaf(obj): if catalyst.device.qjit_device.is_dynamic_wires(device.wires): # When device has dynamic wires, the device.wires iterable object # has a single value, which is the tracer for the number of wires - qreg_in = qalloc_p.bind(device.wires[0]) + GLOBAL_QREG = qalloc_p.bind(device.wires[0]) else: - qreg_in = qalloc_p.bind(len(device.wires)) + GLOBAL_QREG = qalloc_p.bind(len(device.wires)) # If the program is batched, that means that it was transformed. # If it was transformed, that means that the program might have @@ -1413,10 +1425,10 @@ def is_leaf(obj): ) snapshot_results = [] qrp_out = trace_quantum_operations( - tape, device, qreg_in, ctx, trace, mcm_config, snapshot_results + tape, device, GLOBAL_QREG, ctx, trace, mcm_config, snapshot_results ) meas, meas_trees = trace_quantum_measurements(device, qrp_out, output, trees) - qreg_out = qrp_out.actualize() + GLOBAL_QREG = qrp_out.actualize() # Check if the measurements are nested then apply the to_jaxpr_tracer def check_full_raise(arr, func): @@ -1448,7 +1460,7 @@ def check_full_raise(arr, func): transformed_results.append(meas_results) # Deallocate the register and release the device after the current tape is finished. - qdealloc_p.bind(qreg_out) + qdealloc_p.bind(GLOBAL_QREG) device_release_p.bind() closed_jaxpr, out_type, out_tree = trace_post_processing( From 21c4ce2f82e533a42105ce19d6c4567ef94e58b8 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Mon, 2 Jun 2025 15:40:48 +0000 Subject: [PATCH 04/16] Revert "wip" This reverts commit 32ee218632266034ab44da26034c7d86e1ab1bbc. --- frontend/catalyst/from_plxpr.py | 6 ------ frontend/catalyst/jax_extras/lowering.py | 1 - frontend/catalyst/jax_primitives.py | 10 ++++++++-- frontend/catalyst/jax_tracer.py | 22 +++++----------------- 4 files changed, 13 insertions(+), 26 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index d44904bd17..ae3c67f730 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -424,12 +424,6 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 -from jax._src.pjit import pjit_p - -@QFuncPlxprInterpreter.register_primitive(pjit_p) -def handle_pjit_p(self, *args, **kwargs): - breakpoint() - @QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive) def handle_qubit_unitary(self, *invals, n_wires): diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index 6796cca787..d2bfbc74bb 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -140,7 +140,6 @@ def custom_lower_jaxpr_to_module( # XLA computation preserves the module name. module_name = _module_name_regex.sub("_", module_name) ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) - breakpoint() lower_jaxpr_to_fun( ctx, func_name, diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index e37aee5faa..ab1d001bd8 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -153,8 +153,14 @@ def _qbit_lowering(aval): class AbstractQreg(AbstractValue): """Abstract quantum register.""" - def _add(self, left, right): - return AbstractQreg() + hash_value = hash("AbstractQreg") + + def __eq__(self, other): + return isinstance(other, AbstractQreg) + + def __hash__(self): + return self.hash_value + class ConcreteQreg: """Concrete quantum register.""" diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 739a0ca5a6..71da0334ee 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -425,7 +425,6 @@ class HybridOpRegion: cached_vars = weakref.WeakKeyDictionary() - class HybridOp(Operator): """A base class for operations carrying nested regions. The class stores the information obtained in the process of classical tracing and required for the completion of the quantum @@ -1303,16 +1302,6 @@ def trace_function( return res_expanded_tracers, in_sig, out_sig -GLOBAL_QREG = None - -def get_qreg(): - global GLOBAL_QREG - return GLOBAL_QREG - -def set_qreg(qreg): - global GLOBAL_QREG - GLOBAL_QREG = qreg - @debug_logger def trace_quantum_function( @@ -1382,7 +1371,6 @@ def is_leaf(obj): return_values_flat, ) - global GLOBAL_QREG # (2) - Quantum tracing transformed_results = [] with EvaluationContext.frame_tracing_context(trace): @@ -1404,9 +1392,9 @@ def is_leaf(obj): if catalyst.device.qjit_device.is_dynamic_wires(device.wires): # When device has dynamic wires, the device.wires iterable object # has a single value, which is the tracer for the number of wires - GLOBAL_QREG = qalloc_p.bind(device.wires[0]) + qreg_in = qalloc_p.bind(device.wires[0]) else: - GLOBAL_QREG = qalloc_p.bind(len(device.wires)) + qreg_in = qalloc_p.bind(len(device.wires)) # If the program is batched, that means that it was transformed. # If it was transformed, that means that the program might have @@ -1425,10 +1413,10 @@ def is_leaf(obj): ) snapshot_results = [] qrp_out = trace_quantum_operations( - tape, device, GLOBAL_QREG, ctx, trace, mcm_config, snapshot_results + tape, device, qreg_in, ctx, trace, mcm_config, snapshot_results ) meas, meas_trees = trace_quantum_measurements(device, qrp_out, output, trees) - GLOBAL_QREG = qrp_out.actualize() + qreg_out = qrp_out.actualize() # Check if the measurements are nested then apply the to_jaxpr_tracer def check_full_raise(arr, func): @@ -1460,7 +1448,7 @@ def check_full_raise(arr, func): transformed_results.append(meas_results) # Deallocate the register and release the device after the current tape is finished. - qdealloc_p.bind(GLOBAL_QREG) + qdealloc_p.bind(qreg_out) device_release_p.bind() closed_jaxpr, out_type, out_tree = trace_post_processing( From ffdd11a97b93380e3566e8f377a0c643e4ac6663 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Mon, 2 Jun 2025 18:37:28 +0000 Subject: [PATCH 05/16] I can add qreg as a parameter and return it --- frontend/catalyst/from_plxpr.py | 45 +++++++++++++++++++++++++++++ frontend/catalyst/jax_primitives.py | 27 +++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index ae3c67f730..6e7a95f571 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -68,6 +68,7 @@ qinsert_p, qinst_p, quantum_kernel_p, + quantum_subroutine_p, sample_p, set_basis_state_p, set_state_p, @@ -425,6 +426,50 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 + +@QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) +def handle_subroutine(self, *args, **kwargs): + # We need to pass the wire as an argument... + # And we somehow need to start another interpreter + # but only in case it is not yet already available... + from jax.experimental.pjit import pjit_p + from catalyst.jax_primitives import AbstractQreg + from jax._src.core import jaxpr_as_fun + + backup = {orig_wire: wire for orig_wire, wire in self.wire_map.items()} + self.actualize_qreg() + + def wrapper(qreg, *args): + retval = jaxpr_as_fun(kwargs["jaxpr"], *args)() + return qreg, retval + + jaxpr = jax.make_jaxpr(wrapper)(AbstractQreg(), *args) + # So, what I need to do here is transform this jaxpr + # With `from_plxpr` to but we need to make sure that + # the first argument is treated as the qreg... + + + vals_out = quantum_subroutine_p.bind(self.qreg, *args, + jaxpr=jaxpr, + in_shardings=kwargs["in_shardings"], + out_shardings=kwargs["out_shardings"], + in_layouts=kwargs["in_layouts"], + out_layouts=kwargs["out_layouts"], + donated_invars=kwargs["donated_invars"], + ctx_mesh=kwargs["ctx_mesh"], + name=kwargs["name"], + keep_unused=kwargs["keep_unused"], + inline=kwargs["inline"], + compiler_options_kvs=kwargs["compiler_options_kvs"]) + + self.qjit = vals_out[0] + vals_out = vals_out[1:] + + for orig_wire, _ in backup: + self.wire_map[orig_wire] = qextract_p.bind(self.qreg, orig_wire) + + return vals_out + @QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive) def handle_qubit_unitary(self, *invals, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the QubitUnitary primitive""" diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index ab1d001bd8..44edad6eb8 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -18,6 +18,7 @@ import sys from dataclasses import dataclass from enum import Enum +import functools from itertools import chain from typing import Iterable, List, Union @@ -305,6 +306,27 @@ class MeasurementPlane(Enum): measure_in_basis_p = Primitive("measure_in_basis") measure_in_basis_p.multiple_results = True +from jax.experimental.pjit import pjit_p +from jax._src.pjit import _pjit_lowering +import copy +from catalyst.utils.patching import Patcher +quantum_subroutine_p = copy.deepcopy(pjit_p) +quantum_subroutine_p.name = "quantum_subroutine_p" + +def subroutine(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with Patcher( + ( + jax._src.pjit, + "pjit_p", + quantum_subroutine_p, + ), + ): + return jax.jit(func)(*args, **kwargs) + return wrapper + def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr): assert len(jaxpr.consts) == 0, ( @@ -2303,6 +2325,10 @@ def _cos_lowering2(ctx, x, accuracy): """Use hlo.cosine lowering instead of the new cosine lowering from jax 0.4.28""" return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) +def subroutine_lowering(*args, **kwargs): + breakpoint() + retval = _pjit_lowering(*args, **kwargs) + return retval CUSTOM_LOWERING_RULES = ( (zne_p, _zne_lowering), @@ -2344,6 +2370,7 @@ def _cos_lowering2(ctx, x, accuracy): (sin_p, _sin_lowering2), (cos_p, _cos_lowering2), (quantum_kernel_p, _quantum_kernel_lowering), + (quantum_subroutine_p, subroutine_lowering), (measure_in_basis_p, _measure_in_basis_lowering), ) From 817cf1461a67ff3a7083153f0095c41610549467 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Mon, 2 Jun 2025 20:12:09 +0000 Subject: [PATCH 06/16] wip --- frontend/catalyst/from_plxpr.py | 74 ++++++++++++++++------------- frontend/catalyst/jax_primitives.py | 1 - 2 files changed, 42 insertions(+), 33 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 6e7a95f571..5f38de6324 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -167,6 +167,9 @@ def f(x): """ return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) +def from_subroutine(jaxpr): + return jax.make_jaxpr(partial(SubroutineInterpreter().eval, jaxpr.jaxpr, jaxpr.consts)) + class WorkflowInterpreter(PlxprInterpreter): """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant.""" @@ -267,19 +270,8 @@ def wrapper(*args): for pl_transform, (pass_name, decomposition) in transforms_to_passes.items(): register_transform(pl_transform, pass_name, decomposition) - -class QFuncPlxprInterpreter(PlxprInterpreter): - """An interpreter that converts plxpr into catalyst-variant jaxpr. - - Args: - device (qml.devices.Device) - shots (qml.measurements.Shots) - - """ - - def __init__(self, device, shots: qml.measurements.Shots | int): - self._device = device - self._shots = self._extract_shots_value(shots) +class SubroutineInterpreter(PlxprInterpreter): + def __init__(self): self.stateref = None self.actualized = False super().__init__() @@ -299,25 +291,6 @@ def __setattr__(self, __name: str, __value) -> None: else: super().__setattr__(__name, __value) - def setup(self): - """Initialize the stateref and bind the device.""" - if self.stateref is None: - device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) - self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} - - # pylint: disable=attribute-defined-outside-init - def cleanup(self): - """Perform any final steps after processing the plxpr. - - For conversion to calayst, this reinserts extracted qubits and - deallocates the register, and releases the device. - """ - if not self.actualized: - self.actualize_qreg() - qdealloc_p.bind(self.qreg) - device_release_p.bind() - self.stateref = None - def get_wire(self, wire_value) -> AbstractQbit: """Get the ``AbstractQbit`` corresponding to a wire value.""" if wire_value in self.wire_map: @@ -426,12 +399,46 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 +class QFuncPlxprInterpreter(SubroutineInterpreter): + """An interpreter that converts plxpr into catalyst-variant jaxpr. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) + + """ + + def __init__(self, device, shots: qml.measurements.Shots | int): + self._device = device + self._shots = self._extract_shots_value(shots) + super().__init__() + + def setup(self): + """Initialize the stateref and bind the device.""" + if self.stateref is None: + device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} + + # pylint: disable=attribute-defined-outside-init + def cleanup(self): + """Perform any final steps after processing the plxpr. + + For conversion to calayst, this reinserts extracted qubits and + deallocates the register, and releases the device. + """ + if not self.actualized: + self.actualize_qreg() + qdealloc_p.bind(self.qreg) + device_release_p.bind() + self.stateref = None + @QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) def handle_subroutine(self, *args, **kwargs): # We need to pass the wire as an argument... # And we somehow need to start another interpreter # but only in case it is not yet already available... + raise NotImplementedError() from jax.experimental.pjit import pjit_p from catalyst.jax_primitives import AbstractQreg from jax._src.core import jaxpr_as_fun @@ -444,11 +451,14 @@ def wrapper(qreg, *args): return qreg, retval jaxpr = jax.make_jaxpr(wrapper)(AbstractQreg(), *args) + #jaxpr = from_plxpr(jaxpr)(AbstractQreg(), *args) # So, what I need to do here is transform this jaxpr # With `from_plxpr` to but we need to make sure that # the first argument is treated as the qreg... + # quantum_subroutine_p.bind + # is just pjit_p with a different name. vals_out = quantum_subroutine_p.bind(self.qreg, *args, jaxpr=jaxpr, in_shardings=kwargs["in_shardings"], diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 44edad6eb8..4095199738 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -2326,7 +2326,6 @@ def _cos_lowering2(ctx, x, accuracy): return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) def subroutine_lowering(*args, **kwargs): - breakpoint() retval = _pjit_lowering(*args, **kwargs) return retval From 42c062d6a23d9fe6254af36347ad5dca779af3c2 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Tue, 3 Jun 2025 18:44:20 +0000 Subject: [PATCH 07/16] wip --- frontend/catalyst/from_plxpr.py | 124 ++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 55 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 5f38de6324..0ed5c2a51b 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -188,7 +188,7 @@ def handle_qnode( consts = args[:n_consts] non_const_args = args[n_consts:] - f = partial(QFuncPlxprInterpreter(device, shots).eval, qfunc_jaxpr, consts) + f = partial(QFuncPlxprInterpreter(device, shots, None, True).eval, qfunc_jaxpr, consts) return quantum_kernel_p.bind( wrap_init(f, debug_info=qfunc_jaxpr.debug_info), @@ -271,9 +271,11 @@ def wrapper(*args): register_transform(pl_transform, pass_name, decomposition) class SubroutineInterpreter(PlxprInterpreter): - def __init__(self): - self.stateref = None - self.actualized = False + def __init__(self, device, shots, stateref, actualized): + self._device = device + self._shots = self._extract_shots_value(shots) + self.stateref = stateref + self.actualized = actualized super().__init__() def __getattr__(self, key): @@ -363,9 +365,11 @@ def interpret_measurement(self, measurement): else: obs = self._compbasis_obs(*measurement.wires) + shots = self._device.shots.total_shots + shape, dtype = measurement._abstract_eval( n_wires=len(measurement.wires), - shots=self._device.shots.total_shots, + shots=shots, num_device_wires=len(self._device.wires), ) @@ -399,46 +403,12 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 -class QFuncPlxprInterpreter(SubroutineInterpreter): - """An interpreter that converts plxpr into catalyst-variant jaxpr. - - Args: - device (qml.devices.Device) - shots (qml.measurements.Shots) - - """ - - def __init__(self, device, shots: qml.measurements.Shots | int): - self._device = device - self._shots = self._extract_shots_value(shots) - super().__init__() - - def setup(self): - """Initialize the stateref and bind the device.""" - if self.stateref is None: - device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) - self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} - - # pylint: disable=attribute-defined-outside-init - def cleanup(self): - """Perform any final steps after processing the plxpr. - - For conversion to calayst, this reinserts extracted qubits and - deallocates the register, and releases the device. - """ - if not self.actualized: - self.actualize_qreg() - qdealloc_p.bind(self.qreg) - device_release_p.bind() - self.stateref = None - -@QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) +@SubroutineInterpreter.register_primitive(quantum_subroutine_p) def handle_subroutine(self, *args, **kwargs): # We need to pass the wire as an argument... # And we somehow need to start another interpreter # but only in case it is not yet already available... - raise NotImplementedError() from jax.experimental.pjit import pjit_p from catalyst.jax_primitives import AbstractQreg from jax._src.core import jaxpr_as_fun @@ -480,7 +450,7 @@ def wrapper(qreg, *args): return vals_out -@QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive) +@SubroutineInterpreter.register_primitive(qml.QubitUnitary._primitive) def handle_qubit_unitary(self, *invals, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the QubitUnitary primitive""" wires = [self.get_wire(w) for w in invals[1:]] @@ -490,13 +460,13 @@ def handle_qubit_unitary(self, *invals, n_wires): # pylint: disable=unused-argument -@QFuncPlxprInterpreter.register_primitive(qml.GlobalPhase._primitive) +@SubroutineInterpreter.register_primitive(qml.GlobalPhase._primitive) def handle_global_phase(self, phase, *wires, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the GlobalPhase primitive""" gphase_p.bind(phase, ctrl_len=0, adjoint=False) -@QFuncPlxprInterpreter.register_primitive(qml.BasisState._primitive) +@SubroutineInterpreter.register_primitive(qml.BasisState._primitive) def handle_basis_state(self, *invals, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the BasisState primitive""" state_inval = invals[0] @@ -511,7 +481,7 @@ def handle_basis_state(self, *invals, n_wires): # pylint: disable=unused-argument -@QFuncPlxprInterpreter.register_primitive(qml.StatePrep._primitive) +@SubroutineInterpreter.register_primitive(qml.StatePrep._primitive) def handle_state_prep(self, *invals, n_wires, **kwargs): """Handle the conversion from plxpr to Catalyst jaxpr for the StatePrep primitive""" state_inval = invals[0] @@ -528,7 +498,7 @@ def handle_state_prep(self, *invals, n_wires, **kwargs): # pylint: disable=unused-argument, too-many-arguments -@QFuncPlxprInterpreter.register_primitive(plxpr_cond_prim) +@SubroutineInterpreter.register_primitive(plxpr_cond_prim) def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive""" args = plxpr_invals[args_slice] @@ -550,8 +520,13 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): converted_jaxpr_branch = jax.make_jaxpr(lambda x: x)(AbstractQreg()).jaxpr else: # Convert branch from plxpr to Catalyst jaxpr + + class Mixin(self.__class__, BranchPlxprInterpreter): + def __init__(self, dev, shots, stateref, actualized): + super().__init__(dev, shots, stateref, actualized) + converted_func = partial( - BranchPlxprInterpreter(self._device, self._shots).eval, + Mixin(self._device, self._shots, self.stateref, self.actualized).eval, plxpr_branch, branch_consts, ) @@ -584,7 +559,7 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): # pylint: disable=unused-argument, too-many-arguments -@QFuncPlxprInterpreter.register_primitive(plxpr_for_loop_prim) +@SubroutineInterpreter.register_primitive(plxpr_for_loop_prim) def handle_for_loop( self, start, @@ -610,8 +585,11 @@ def handle_for_loop( consts = plxpr_invals[consts_slice] # Convert for loop body from plxpr to Catalyst jaxpr + class Mixin(self.__class__, BranchPlxprInterpreter): + def __init__(self, dev, shots, stateref, actualized): + super().__init__(dev, shots, stateref, actualized) converted_func = partial( - BranchPlxprInterpreter(self._device, self._shots).eval, + Mixin(self._device, self._shots, self.stateref, self.actualized).eval, jaxpr_body_fn, consts, ) @@ -643,7 +621,7 @@ def handle_for_loop( # pylint: disable=unused-argument, too-many-arguments -@QFuncPlxprInterpreter.register_primitive(plxpr_while_loop_prim) +@SubroutineInterpreter.register_primitive(plxpr_while_loop_prim) def handle_while_loop( self, *plxpr_invals, @@ -660,8 +638,11 @@ def handle_while_loop( args_plus_qreg = [*args, self.qreg] # Add the qreg to the args # Convert for while body from plxpr to Catalyst jaxpr + class Mixin(self.__class__, BranchPlxprInterpreter): + def __init__(self, dev, shots, stateref, actualized): + super().__init__(dev, shots, stateref, actualized) converted_body_func = partial( - BranchPlxprInterpreter(self._device, self._shots).eval, + Mixin(self.device, self.shots, self.stateref, self.actualized).eval, jaxpr_body_fn, consts_body, ) @@ -703,7 +684,7 @@ def handle_while_loop( return outvals -@QFuncPlxprInterpreter.register_primitive(plxpr_measure_prim) +@SubroutineInterpreter.register_primitive(plxpr_measure_prim) def handle_measure(self, wire, reset, postselect): """Handle the conversion from plxpr to Catalyst jaxpr for the mid-circuit measure primitive.""" @@ -728,7 +709,7 @@ def handle_measure(self, wire, reset, postselect): # pylint: disable=unused-argument, too-many-positional-arguments -@QFuncPlxprInterpreter.register_primitive(plxpr_measure_in_basis_prim) +@SubroutineInterpreter.register_primitive(plxpr_measure_in_basis_prim) def handle_measure_in_basis(self, angle, wire, plane, reset, postselect): """Handle the conversion from plxpr to Catalyst jaxpr for the measure_in_basis primitive""" _angle = jax.lax.convert_element_type(angle, jnp.dtype(jnp.float64)) @@ -753,7 +734,7 @@ def handle_measure_in_basis(self, angle, wire, plane, reset, postselect): # This is due to the registrations being done outside the parent class definition. -class BranchPlxprInterpreter(QFuncPlxprInterpreter): +class BranchPlxprInterpreter(SubroutineInterpreter): """An interpreter that converts a plxpr branch into catalyst-variant jaxpr branch. Args: @@ -761,9 +742,9 @@ class BranchPlxprInterpreter(QFuncPlxprInterpreter): shots (qml.measurements.Shots) """ - def __init__(self, device, shots: qml.measurements.Shots): + def __init__(self, device, shots, stateref, actualized): self._parent_qreg = None - super().__init__(device, shots) + super().__init__(device, shots, stateref, actualized) def setup(self): """Initialize the stateref.""" @@ -795,6 +776,7 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: self._parent_qreg = args[-1] + breakpoint() # Send the original args (without the qreg) outvals = super().eval(jaxpr, consts, *args[:-1]) @@ -805,6 +787,38 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: return outvals +class QFuncPlxprInterpreter(SubroutineInterpreter): + """An interpreter that converts plxpr into catalyst-variant jaxpr. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) + + """ + + def __init__(self, device, shots: qml.measurements.Shots | int, stateref, actualized): + super().__init__(device, shots, stateref, actualized) + + def setup(self): + """Initialize the stateref and bind the device.""" + if self.stateref is None: + device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} + + # pylint: disable=attribute-defined-outside-init + def cleanup(self): + """Perform any final steps after processing the plxpr. + + For conversion to calayst, this reinserts extracted qubits and + deallocates the register, and releases the device. + """ + if not self.actualized: + self.actualize_qreg() + qdealloc_p.bind(self.qreg) + device_release_p.bind() + self.stateref = None + + class PredicatePlxprInterpreter(PlxprInterpreter): """An interpreter that converts a plxpr predicate into catalyst-variant jaxpr branch.""" From b01c1d993fb29856e03402a809a7a758085ec1a5 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 18:19:32 +0000 Subject: [PATCH 08/16] wip3 --- frontend/catalyst/from_plxpr.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 0ed5c2a51b..1cd1d3d039 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -521,12 +521,8 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): else: # Convert branch from plxpr to Catalyst jaxpr - class Mixin(self.__class__, BranchPlxprInterpreter): - def __init__(self, dev, shots, stateref, actualized): - super().__init__(dev, shots, stateref, actualized) - converted_func = partial( - Mixin(self._device, self._shots, self.stateref, self.actualized).eval, + BranchPlxprInterpreter(self._device, self._shots, self.stateref, self.actualized).eval, plxpr_branch, branch_consts, ) @@ -585,16 +581,14 @@ def handle_for_loop( consts = plxpr_invals[consts_slice] # Convert for loop body from plxpr to Catalyst jaxpr - class Mixin(self.__class__, BranchPlxprInterpreter): - def __init__(self, dev, shots, stateref, actualized): - super().__init__(dev, shots, stateref, actualized) converted_func = partial( - Mixin(self._device, self._shots, self.stateref, self.actualized).eval, + BranchPlxprInterpreter(self._device, self._shots, self.stateref, self.actualized).eval, jaxpr_body_fn, consts, ) converted_jaxpr_branch = jax.make_jaxpr(converted_func)(*start_plus_args_plus_qreg).jaxpr converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ()) + breakpoint() # Build Catalyst compatible input values for_loop_invals = [*consts, start, stop, step, *start_plus_args_plus_qreg] @@ -638,11 +632,8 @@ def handle_while_loop( args_plus_qreg = [*args, self.qreg] # Add the qreg to the args # Convert for while body from plxpr to Catalyst jaxpr - class Mixin(self.__class__, BranchPlxprInterpreter): - def __init__(self, dev, shots, stateref, actualized): - super().__init__(dev, shots, stateref, actualized) converted_body_func = partial( - Mixin(self.device, self.shots, self.stateref, self.actualized).eval, + BranchPlxprInterpreter(self.device, self.shots, self.stateref, self.actualized).eval, jaxpr_body_fn, consts_body, ) @@ -774,9 +765,9 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: # We assume we have at least one argument (the qreg) assert len(args) > 0 + breakpoint() self._parent_qreg = args[-1] - breakpoint() # Send the original args (without the qreg) outvals = super().eval(jaxpr, consts, *args[:-1]) From d6138af5e40a358e33165542c43cbe6a52b56337 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 18:19:40 +0000 Subject: [PATCH 09/16] Revert "wip3" This reverts commit b01c1d993fb29856e03402a809a7a758085ec1a5. --- frontend/catalyst/from_plxpr.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 1cd1d3d039..0ed5c2a51b 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -521,8 +521,12 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): else: # Convert branch from plxpr to Catalyst jaxpr + class Mixin(self.__class__, BranchPlxprInterpreter): + def __init__(self, dev, shots, stateref, actualized): + super().__init__(dev, shots, stateref, actualized) + converted_func = partial( - BranchPlxprInterpreter(self._device, self._shots, self.stateref, self.actualized).eval, + Mixin(self._device, self._shots, self.stateref, self.actualized).eval, plxpr_branch, branch_consts, ) @@ -581,14 +585,16 @@ def handle_for_loop( consts = plxpr_invals[consts_slice] # Convert for loop body from plxpr to Catalyst jaxpr + class Mixin(self.__class__, BranchPlxprInterpreter): + def __init__(self, dev, shots, stateref, actualized): + super().__init__(dev, shots, stateref, actualized) converted_func = partial( - BranchPlxprInterpreter(self._device, self._shots, self.stateref, self.actualized).eval, + Mixin(self._device, self._shots, self.stateref, self.actualized).eval, jaxpr_body_fn, consts, ) converted_jaxpr_branch = jax.make_jaxpr(converted_func)(*start_plus_args_plus_qreg).jaxpr converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ()) - breakpoint() # Build Catalyst compatible input values for_loop_invals = [*consts, start, stop, step, *start_plus_args_plus_qreg] @@ -632,8 +638,11 @@ def handle_while_loop( args_plus_qreg = [*args, self.qreg] # Add the qreg to the args # Convert for while body from plxpr to Catalyst jaxpr + class Mixin(self.__class__, BranchPlxprInterpreter): + def __init__(self, dev, shots, stateref, actualized): + super().__init__(dev, shots, stateref, actualized) converted_body_func = partial( - BranchPlxprInterpreter(self.device, self.shots, self.stateref, self.actualized).eval, + Mixin(self.device, self.shots, self.stateref, self.actualized).eval, jaxpr_body_fn, consts_body, ) @@ -765,9 +774,9 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: # We assume we have at least one argument (the qreg) assert len(args) > 0 - breakpoint() self._parent_qreg = args[-1] + breakpoint() # Send the original args (without the qreg) outvals = super().eval(jaxpr, consts, *args[:-1]) From c52cd0eaf3aef569d26d71bcf9d0820ed96e3041 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 18:19:42 +0000 Subject: [PATCH 10/16] Revert "wip" This reverts commit 42c062d6a23d9fe6254af36347ad5dca779af3c2. --- frontend/catalyst/from_plxpr.py | 124 ++++++++++++++------------------ 1 file changed, 55 insertions(+), 69 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 0ed5c2a51b..5f38de6324 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -188,7 +188,7 @@ def handle_qnode( consts = args[:n_consts] non_const_args = args[n_consts:] - f = partial(QFuncPlxprInterpreter(device, shots, None, True).eval, qfunc_jaxpr, consts) + f = partial(QFuncPlxprInterpreter(device, shots).eval, qfunc_jaxpr, consts) return quantum_kernel_p.bind( wrap_init(f, debug_info=qfunc_jaxpr.debug_info), @@ -271,11 +271,9 @@ def wrapper(*args): register_transform(pl_transform, pass_name, decomposition) class SubroutineInterpreter(PlxprInterpreter): - def __init__(self, device, shots, stateref, actualized): - self._device = device - self._shots = self._extract_shots_value(shots) - self.stateref = stateref - self.actualized = actualized + def __init__(self): + self.stateref = None + self.actualized = False super().__init__() def __getattr__(self, key): @@ -365,11 +363,9 @@ def interpret_measurement(self, measurement): else: obs = self._compbasis_obs(*measurement.wires) - shots = self._device.shots.total_shots - shape, dtype = measurement._abstract_eval( n_wires=len(measurement.wires), - shots=shots, + shots=self._device.shots.total_shots, num_device_wires=len(self._device.wires), ) @@ -403,12 +399,46 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 +class QFuncPlxprInterpreter(SubroutineInterpreter): + """An interpreter that converts plxpr into catalyst-variant jaxpr. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) + + """ + + def __init__(self, device, shots: qml.measurements.Shots | int): + self._device = device + self._shots = self._extract_shots_value(shots) + super().__init__() + + def setup(self): + """Initialize the stateref and bind the device.""" + if self.stateref is None: + device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} + + # pylint: disable=attribute-defined-outside-init + def cleanup(self): + """Perform any final steps after processing the plxpr. + + For conversion to calayst, this reinserts extracted qubits and + deallocates the register, and releases the device. + """ + if not self.actualized: + self.actualize_qreg() + qdealloc_p.bind(self.qreg) + device_release_p.bind() + self.stateref = None + -@SubroutineInterpreter.register_primitive(quantum_subroutine_p) +@QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) def handle_subroutine(self, *args, **kwargs): # We need to pass the wire as an argument... # And we somehow need to start another interpreter # but only in case it is not yet already available... + raise NotImplementedError() from jax.experimental.pjit import pjit_p from catalyst.jax_primitives import AbstractQreg from jax._src.core import jaxpr_as_fun @@ -450,7 +480,7 @@ def wrapper(qreg, *args): return vals_out -@SubroutineInterpreter.register_primitive(qml.QubitUnitary._primitive) +@QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive) def handle_qubit_unitary(self, *invals, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the QubitUnitary primitive""" wires = [self.get_wire(w) for w in invals[1:]] @@ -460,13 +490,13 @@ def handle_qubit_unitary(self, *invals, n_wires): # pylint: disable=unused-argument -@SubroutineInterpreter.register_primitive(qml.GlobalPhase._primitive) +@QFuncPlxprInterpreter.register_primitive(qml.GlobalPhase._primitive) def handle_global_phase(self, phase, *wires, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the GlobalPhase primitive""" gphase_p.bind(phase, ctrl_len=0, adjoint=False) -@SubroutineInterpreter.register_primitive(qml.BasisState._primitive) +@QFuncPlxprInterpreter.register_primitive(qml.BasisState._primitive) def handle_basis_state(self, *invals, n_wires): """Handle the conversion from plxpr to Catalyst jaxpr for the BasisState primitive""" state_inval = invals[0] @@ -481,7 +511,7 @@ def handle_basis_state(self, *invals, n_wires): # pylint: disable=unused-argument -@SubroutineInterpreter.register_primitive(qml.StatePrep._primitive) +@QFuncPlxprInterpreter.register_primitive(qml.StatePrep._primitive) def handle_state_prep(self, *invals, n_wires, **kwargs): """Handle the conversion from plxpr to Catalyst jaxpr for the StatePrep primitive""" state_inval = invals[0] @@ -498,7 +528,7 @@ def handle_state_prep(self, *invals, n_wires, **kwargs): # pylint: disable=unused-argument, too-many-arguments -@SubroutineInterpreter.register_primitive(plxpr_cond_prim) +@QFuncPlxprInterpreter.register_primitive(plxpr_cond_prim) def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive""" args = plxpr_invals[args_slice] @@ -520,13 +550,8 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): converted_jaxpr_branch = jax.make_jaxpr(lambda x: x)(AbstractQreg()).jaxpr else: # Convert branch from plxpr to Catalyst jaxpr - - class Mixin(self.__class__, BranchPlxprInterpreter): - def __init__(self, dev, shots, stateref, actualized): - super().__init__(dev, shots, stateref, actualized) - converted_func = partial( - Mixin(self._device, self._shots, self.stateref, self.actualized).eval, + BranchPlxprInterpreter(self._device, self._shots).eval, plxpr_branch, branch_consts, ) @@ -559,7 +584,7 @@ def __init__(self, dev, shots, stateref, actualized): # pylint: disable=unused-argument, too-many-arguments -@SubroutineInterpreter.register_primitive(plxpr_for_loop_prim) +@QFuncPlxprInterpreter.register_primitive(plxpr_for_loop_prim) def handle_for_loop( self, start, @@ -585,11 +610,8 @@ def handle_for_loop( consts = plxpr_invals[consts_slice] # Convert for loop body from plxpr to Catalyst jaxpr - class Mixin(self.__class__, BranchPlxprInterpreter): - def __init__(self, dev, shots, stateref, actualized): - super().__init__(dev, shots, stateref, actualized) converted_func = partial( - Mixin(self._device, self._shots, self.stateref, self.actualized).eval, + BranchPlxprInterpreter(self._device, self._shots).eval, jaxpr_body_fn, consts, ) @@ -621,7 +643,7 @@ def __init__(self, dev, shots, stateref, actualized): # pylint: disable=unused-argument, too-many-arguments -@SubroutineInterpreter.register_primitive(plxpr_while_loop_prim) +@QFuncPlxprInterpreter.register_primitive(plxpr_while_loop_prim) def handle_while_loop( self, *plxpr_invals, @@ -638,11 +660,8 @@ def handle_while_loop( args_plus_qreg = [*args, self.qreg] # Add the qreg to the args # Convert for while body from plxpr to Catalyst jaxpr - class Mixin(self.__class__, BranchPlxprInterpreter): - def __init__(self, dev, shots, stateref, actualized): - super().__init__(dev, shots, stateref, actualized) converted_body_func = partial( - Mixin(self.device, self.shots, self.stateref, self.actualized).eval, + BranchPlxprInterpreter(self._device, self._shots).eval, jaxpr_body_fn, consts_body, ) @@ -684,7 +703,7 @@ def __init__(self, dev, shots, stateref, actualized): return outvals -@SubroutineInterpreter.register_primitive(plxpr_measure_prim) +@QFuncPlxprInterpreter.register_primitive(plxpr_measure_prim) def handle_measure(self, wire, reset, postselect): """Handle the conversion from plxpr to Catalyst jaxpr for the mid-circuit measure primitive.""" @@ -709,7 +728,7 @@ def handle_measure(self, wire, reset, postselect): # pylint: disable=unused-argument, too-many-positional-arguments -@SubroutineInterpreter.register_primitive(plxpr_measure_in_basis_prim) +@QFuncPlxprInterpreter.register_primitive(plxpr_measure_in_basis_prim) def handle_measure_in_basis(self, angle, wire, plane, reset, postselect): """Handle the conversion from plxpr to Catalyst jaxpr for the measure_in_basis primitive""" _angle = jax.lax.convert_element_type(angle, jnp.dtype(jnp.float64)) @@ -734,7 +753,7 @@ def handle_measure_in_basis(self, angle, wire, plane, reset, postselect): # This is due to the registrations being done outside the parent class definition. -class BranchPlxprInterpreter(SubroutineInterpreter): +class BranchPlxprInterpreter(QFuncPlxprInterpreter): """An interpreter that converts a plxpr branch into catalyst-variant jaxpr branch. Args: @@ -742,9 +761,9 @@ class BranchPlxprInterpreter(SubroutineInterpreter): shots (qml.measurements.Shots) """ - def __init__(self, device, shots, stateref, actualized): + def __init__(self, device, shots: qml.measurements.Shots): self._parent_qreg = None - super().__init__(device, shots, stateref, actualized) + super().__init__(device, shots) def setup(self): """Initialize the stateref.""" @@ -776,7 +795,6 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: self._parent_qreg = args[-1] - breakpoint() # Send the original args (without the qreg) outvals = super().eval(jaxpr, consts, *args[:-1]) @@ -787,38 +805,6 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: return outvals -class QFuncPlxprInterpreter(SubroutineInterpreter): - """An interpreter that converts plxpr into catalyst-variant jaxpr. - - Args: - device (qml.devices.Device) - shots (qml.measurements.Shots) - - """ - - def __init__(self, device, shots: qml.measurements.Shots | int, stateref, actualized): - super().__init__(device, shots, stateref, actualized) - - def setup(self): - """Initialize the stateref and bind the device.""" - if self.stateref is None: - device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) - self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} - - # pylint: disable=attribute-defined-outside-init - def cleanup(self): - """Perform any final steps after processing the plxpr. - - For conversion to calayst, this reinserts extracted qubits and - deallocates the register, and releases the device. - """ - if not self.actualized: - self.actualize_qreg() - qdealloc_p.bind(self.qreg) - device_release_p.bind() - self.stateref = None - - class PredicatePlxprInterpreter(PlxprInterpreter): """An interpreter that converts a plxpr predicate into catalyst-variant jaxpr branch.""" From 6b9ab88d75fcba75548f14ebb7f9cd12455cd765 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 18:19:43 +0000 Subject: [PATCH 11/16] Revert "wip" This reverts commit 817cf1461a67ff3a7083153f0095c41610549467. --- frontend/catalyst/from_plxpr.py | 74 +++++++++++++---------------- frontend/catalyst/jax_primitives.py | 1 + 2 files changed, 33 insertions(+), 42 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 5f38de6324..6e7a95f571 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -167,9 +167,6 @@ def f(x): """ return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) -def from_subroutine(jaxpr): - return jax.make_jaxpr(partial(SubroutineInterpreter().eval, jaxpr.jaxpr, jaxpr.consts)) - class WorkflowInterpreter(PlxprInterpreter): """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant.""" @@ -270,8 +267,19 @@ def wrapper(*args): for pl_transform, (pass_name, decomposition) in transforms_to_passes.items(): register_transform(pl_transform, pass_name, decomposition) -class SubroutineInterpreter(PlxprInterpreter): - def __init__(self): + +class QFuncPlxprInterpreter(PlxprInterpreter): + """An interpreter that converts plxpr into catalyst-variant jaxpr. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) + + """ + + def __init__(self, device, shots: qml.measurements.Shots | int): + self._device = device + self._shots = self._extract_shots_value(shots) self.stateref = None self.actualized = False super().__init__() @@ -291,6 +299,25 @@ def __setattr__(self, __name: str, __value) -> None: else: super().__setattr__(__name, __value) + def setup(self): + """Initialize the stateref and bind the device.""" + if self.stateref is None: + device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} + + # pylint: disable=attribute-defined-outside-init + def cleanup(self): + """Perform any final steps after processing the plxpr. + + For conversion to calayst, this reinserts extracted qubits and + deallocates the register, and releases the device. + """ + if not self.actualized: + self.actualize_qreg() + qdealloc_p.bind(self.qreg) + device_release_p.bind() + self.stateref = None + def get_wire(self, wire_value) -> AbstractQbit: """Get the ``AbstractQbit`` corresponding to a wire value.""" if wire_value in self.wire_map: @@ -399,46 +426,12 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 -class QFuncPlxprInterpreter(SubroutineInterpreter): - """An interpreter that converts plxpr into catalyst-variant jaxpr. - - Args: - device (qml.devices.Device) - shots (qml.measurements.Shots) - - """ - - def __init__(self, device, shots: qml.measurements.Shots | int): - self._device = device - self._shots = self._extract_shots_value(shots) - super().__init__() - - def setup(self): - """Initialize the stateref and bind the device.""" - if self.stateref is None: - device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) - self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} - - # pylint: disable=attribute-defined-outside-init - def cleanup(self): - """Perform any final steps after processing the plxpr. - - For conversion to calayst, this reinserts extracted qubits and - deallocates the register, and releases the device. - """ - if not self.actualized: - self.actualize_qreg() - qdealloc_p.bind(self.qreg) - device_release_p.bind() - self.stateref = None - @QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) def handle_subroutine(self, *args, **kwargs): # We need to pass the wire as an argument... # And we somehow need to start another interpreter # but only in case it is not yet already available... - raise NotImplementedError() from jax.experimental.pjit import pjit_p from catalyst.jax_primitives import AbstractQreg from jax._src.core import jaxpr_as_fun @@ -451,14 +444,11 @@ def wrapper(qreg, *args): return qreg, retval jaxpr = jax.make_jaxpr(wrapper)(AbstractQreg(), *args) - #jaxpr = from_plxpr(jaxpr)(AbstractQreg(), *args) # So, what I need to do here is transform this jaxpr # With `from_plxpr` to but we need to make sure that # the first argument is treated as the qreg... - # quantum_subroutine_p.bind - # is just pjit_p with a different name. vals_out = quantum_subroutine_p.bind(self.qreg, *args, jaxpr=jaxpr, in_shardings=kwargs["in_shardings"], diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 4095199738..44edad6eb8 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -2326,6 +2326,7 @@ def _cos_lowering2(ctx, x, accuracy): return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) def subroutine_lowering(*args, **kwargs): + breakpoint() retval = _pjit_lowering(*args, **kwargs) return retval From b2cf9923eb3d672d3e8d2391d67313e661379650 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 18:20:37 +0000 Subject: [PATCH 12/16] Reapply "wip" This reverts commit 6b9ab88d75fcba75548f14ebb7f9cd12455cd765. --- frontend/catalyst/from_plxpr.py | 74 ++++++++++++++++------------- frontend/catalyst/jax_primitives.py | 1 - 2 files changed, 42 insertions(+), 33 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 6e7a95f571..5f38de6324 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -167,6 +167,9 @@ def f(x): """ return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) +def from_subroutine(jaxpr): + return jax.make_jaxpr(partial(SubroutineInterpreter().eval, jaxpr.jaxpr, jaxpr.consts)) + class WorkflowInterpreter(PlxprInterpreter): """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant.""" @@ -267,19 +270,8 @@ def wrapper(*args): for pl_transform, (pass_name, decomposition) in transforms_to_passes.items(): register_transform(pl_transform, pass_name, decomposition) - -class QFuncPlxprInterpreter(PlxprInterpreter): - """An interpreter that converts plxpr into catalyst-variant jaxpr. - - Args: - device (qml.devices.Device) - shots (qml.measurements.Shots) - - """ - - def __init__(self, device, shots: qml.measurements.Shots | int): - self._device = device - self._shots = self._extract_shots_value(shots) +class SubroutineInterpreter(PlxprInterpreter): + def __init__(self): self.stateref = None self.actualized = False super().__init__() @@ -299,25 +291,6 @@ def __setattr__(self, __name: str, __value) -> None: else: super().__setattr__(__name, __value) - def setup(self): - """Initialize the stateref and bind the device.""" - if self.stateref is None: - device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) - self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} - - # pylint: disable=attribute-defined-outside-init - def cleanup(self): - """Perform any final steps after processing the plxpr. - - For conversion to calayst, this reinserts extracted qubits and - deallocates the register, and releases the device. - """ - if not self.actualized: - self.actualize_qreg() - qdealloc_p.bind(self.qreg) - device_release_p.bind() - self.stateref = None - def get_wire(self, wire_value) -> AbstractQbit: """Get the ``AbstractQbit`` corresponding to a wire value.""" if wire_value in self.wire_map: @@ -426,12 +399,46 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 +class QFuncPlxprInterpreter(SubroutineInterpreter): + """An interpreter that converts plxpr into catalyst-variant jaxpr. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) + + """ + + def __init__(self, device, shots: qml.measurements.Shots | int): + self._device = device + self._shots = self._extract_shots_value(shots) + super().__init__() + + def setup(self): + """Initialize the stateref and bind the device.""" + if self.stateref is None: + device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} + + # pylint: disable=attribute-defined-outside-init + def cleanup(self): + """Perform any final steps after processing the plxpr. + + For conversion to calayst, this reinserts extracted qubits and + deallocates the register, and releases the device. + """ + if not self.actualized: + self.actualize_qreg() + qdealloc_p.bind(self.qreg) + device_release_p.bind() + self.stateref = None + @QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) def handle_subroutine(self, *args, **kwargs): # We need to pass the wire as an argument... # And we somehow need to start another interpreter # but only in case it is not yet already available... + raise NotImplementedError() from jax.experimental.pjit import pjit_p from catalyst.jax_primitives import AbstractQreg from jax._src.core import jaxpr_as_fun @@ -444,11 +451,14 @@ def wrapper(qreg, *args): return qreg, retval jaxpr = jax.make_jaxpr(wrapper)(AbstractQreg(), *args) + #jaxpr = from_plxpr(jaxpr)(AbstractQreg(), *args) # So, what I need to do here is transform this jaxpr # With `from_plxpr` to but we need to make sure that # the first argument is treated as the qreg... + # quantum_subroutine_p.bind + # is just pjit_p with a different name. vals_out = quantum_subroutine_p.bind(self.qreg, *args, jaxpr=jaxpr, in_shardings=kwargs["in_shardings"], diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 44edad6eb8..4095199738 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -2326,7 +2326,6 @@ def _cos_lowering2(ctx, x, accuracy): return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) def subroutine_lowering(*args, **kwargs): - breakpoint() retval = _pjit_lowering(*args, **kwargs) return retval From df7a8d35a80d7cbaa69ec7fee78b9ce38ed3f3a7 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 19:13:37 +0000 Subject: [PATCH 13/16] ok, but drop vars are dropped --- frontend/catalyst/from_plxpr.py | 73 +++++++++++++++++++++++------ frontend/catalyst/jax_primitives.py | 3 ++ 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 5f38de6324..3b37f6158a 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -168,7 +168,7 @@ def f(x): return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) def from_subroutine(jaxpr): - return jax.make_jaxpr(partial(SubroutineInterpreter().eval, jaxpr.jaxpr, jaxpr.consts)) + return jax.make_jaxpr(partial(SubroutineInterpreter(None, 0).eval, jaxpr.jaxpr, jaxpr.consts)) class WorkflowInterpreter(PlxprInterpreter): @@ -271,7 +271,9 @@ def wrapper(*args): register_transform(pl_transform, pass_name, decomposition) class SubroutineInterpreter(PlxprInterpreter): - def __init__(self): + def __init__(self, device, shots): + self._device = device + self._shots = self._extract_shots_value(shots) self.stateref = None self.actualized = False super().__init__() @@ -398,8 +400,38 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 + # pylint: disable=too-many-branches + def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: + """Evaluate a jaxpr. + + Args: + jaxpr (jax.core.Jaxpr): the jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args (tuple[TensorLike]): The arguments for the jaxpr. + + Returns: + list[TensorLike]: the results of the execution. + + """ -class QFuncPlxprInterpreter(SubroutineInterpreter): + # We assume we have at least one argument (the qreg) + assert len(args) > 0 + + self._parent_qreg = args[-1] + self.stateref = {"qreg": self._parent_qreg, "wire_map": {}} + + # Send the original args (without the qreg) + outvals = super().eval(jaxpr, consts, *args) + + # Add the qreg to the output values + outvals = outvals + + self.stateref = None + + return outvals + + +class QFuncPlxprInterpreter(SubroutineInterpreter, PlxprInterpreter): """An interpreter that converts plxpr into catalyst-variant jaxpr. Args: @@ -408,10 +440,11 @@ class QFuncPlxprInterpreter(SubroutineInterpreter): """ + def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: + return PlxprInterpreter.eval(self, jaxpr, consts, *args) + def __init__(self, device, shots: qml.measurements.Shots | int): - self._device = device - self._shots = self._extract_shots_value(shots) - super().__init__() + super().__init__(device, shots) def setup(self): """Initialize the stateref and bind the device.""" @@ -438,33 +471,43 @@ def handle_subroutine(self, *args, **kwargs): # We need to pass the wire as an argument... # And we somehow need to start another interpreter # but only in case it is not yet already available... - raise NotImplementedError() from jax.experimental.pjit import pjit_p - from catalyst.jax_primitives import AbstractQreg + from catalyst.jax_primitives import AbstractQreg, qextract_p, qinsert_p from jax._src.core import jaxpr_as_fun backup = {orig_wire: wire for orig_wire, wire in self.wire_map.items()} self.actualize_qreg() + plxpr = kwargs["jaxpr"] def wrapper(qreg, *args): - retval = jaxpr_as_fun(kwargs["jaxpr"], *args)() + #qubit = qextract_p.bind(qreg, 0) + retval = jaxpr_as_fun(plxpr, *args)() + #qreg = qinsert_p.bind(qreg, 0, qubit) return qreg, retval - jaxpr = jax.make_jaxpr(wrapper)(AbstractQreg(), *args) + jaxpr_with_parameter_and_return = jax.make_jaxpr(wrapper)(AbstractQreg(), *args) + converted_func = partial( + SubroutineInterpreter(self._device, self._shots).eval, + jaxpr_with_parameter_and_return.jaxpr, + jaxpr_with_parameter_and_return.consts, + ) + converted_jaxpr_branch = jax.make_jaxpr(converted_func)(self.qreg, *args).jaxpr + converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ()) #jaxpr = from_plxpr(jaxpr)(AbstractQreg(), *args) # So, what I need to do here is transform this jaxpr # With `from_plxpr` to but we need to make sure that # the first argument is treated as the qreg... + from jax._src.sharding_impls import UnspecifiedValue # quantum_subroutine_p.bind # is just pjit_p with a different name. vals_out = quantum_subroutine_p.bind(self.qreg, *args, - jaxpr=jaxpr, - in_shardings=kwargs["in_shardings"], - out_shardings=kwargs["out_shardings"], - in_layouts=kwargs["in_layouts"], - out_layouts=kwargs["out_layouts"], + jaxpr=converted_closed_jaxpr_branch, + in_shardings=(UnspecifiedValue(), *kwargs["in_shardings"]), + out_shardings=(UnspecifiedValue(), *kwargs["out_shardings"]), + in_layouts=(None, *kwargs["in_layouts"]), + out_layouts=(None, *kwargs["out_layouts"]), donated_invars=kwargs["donated_invars"], ctx_mesh=kwargs["ctx_mesh"], name=kwargs["name"], diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 4095199738..484e44fa41 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -162,6 +162,9 @@ def __eq__(self, other): def __hash__(self): return self.hash_value + def _add(self, first, second): + return AbstractQreg() + class ConcreteQreg: """Concrete quantum register.""" From 38bc6daca8e54f54e72a40e378010ac524d9e222 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 19:28:20 +0000 Subject: [PATCH 14/16] execution works --- frontend/catalyst/from_plxpr.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 3b37f6158a..351a8de355 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -424,7 +424,11 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: outvals = super().eval(jaxpr, consts, *args) # Add the qreg to the output values - outvals = outvals + self.qreg, retvals = outvals[0], outvals[1:] + + self.actualize_qreg() + + outvals = (self.qreg, *retvals) self.stateref = None @@ -465,6 +469,7 @@ def cleanup(self): device_release_p.bind() self.stateref = None +from catalyst.jax_primitives import qinsert_p @QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) def handle_subroutine(self, *args, **kwargs): From e09a6d8b9dd8eddf69540d7b5a30011ac018c2ba Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 19:31:07 +0000 Subject: [PATCH 15/16] typo --- frontend/catalyst/from_plxpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 351a8de355..f466fa4933 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -520,7 +520,7 @@ def wrapper(qreg, *args): inline=kwargs["inline"], compiler_options_kvs=kwargs["compiler_options_kvs"]) - self.qjit = vals_out[0] + self.qreg = vals_out[0] vals_out = vals_out[1:] for orig_wire, _ in backup: From 401986c0019afdd8aa6ed4d2c5cf087ca9f17e93 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Jun 2025 19:35:07 +0000 Subject: [PATCH 16/16] use first parameter --- frontend/catalyst/from_plxpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index f466fa4933..bb31330a95 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -417,7 +417,7 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: # We assume we have at least one argument (the qreg) assert len(args) > 0 - self._parent_qreg = args[-1] + self._parent_qreg = args[0] self.stateref = {"qreg": self._parent_qreg, "wire_map": {}} # Send the original args (without the qreg)