diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f486fb6824..f4ea8bbe4a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -249,6 +249,9 @@ This runtime stub is currently for mock execution only and should be treated as a placeholder operation. Internally, it functions just as a computational-basis measurement instruction. +* Quantum subroutine stub. + [(#1774)](https://github.com/PennyLaneAI/catalyst/pull/1774) + * PennyLane's arbitrary-basis measurement operations, such as :func:`qml.ftqc.measure_arbitrary_basis() `, are now QJIT-compatible with program capture enabled. diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index ae3c67f730..bb31330a95 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -68,6 +68,7 @@ qinsert_p, qinst_p, quantum_kernel_p, + quantum_subroutine_p, sample_p, set_basis_state_p, set_state_p, @@ -166,6 +167,9 @@ def f(x): """ return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts)) +def from_subroutine(jaxpr): + return jax.make_jaxpr(partial(SubroutineInterpreter(None, 0).eval, jaxpr.jaxpr, jaxpr.consts)) + class WorkflowInterpreter(PlxprInterpreter): """An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant.""" @@ -266,17 +270,8 @@ def wrapper(*args): for pl_transform, (pass_name, decomposition) in transforms_to_passes.items(): register_transform(pl_transform, pass_name, decomposition) - -class QFuncPlxprInterpreter(PlxprInterpreter): - """An interpreter that converts plxpr into catalyst-variant jaxpr. - - Args: - device (qml.devices.Device) - shots (qml.measurements.Shots) - - """ - - def __init__(self, device, shots: qml.measurements.Shots | int): +class SubroutineInterpreter(PlxprInterpreter): + def __init__(self, device, shots): self._device = device self._shots = self._extract_shots_value(shots) self.stateref = None @@ -298,25 +293,6 @@ def __setattr__(self, __name: str, __value) -> None: else: super().__setattr__(__name, __value) - def setup(self): - """Initialize the stateref and bind the device.""" - if self.stateref is None: - device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) - self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} - - # pylint: disable=attribute-defined-outside-init - def cleanup(self): - """Perform any final steps after processing the plxpr. - - For conversion to calayst, this reinserts extracted qubits and - deallocates the register, and releases the device. - """ - if not self.actualized: - self.actualize_qreg() - qdealloc_p.bind(self.qreg) - device_release_p.bind() - self.stateref = None - def get_wire(self, wire_value) -> AbstractQbit: """Get the ``AbstractQbit`` corresponding to a wire value.""" if wire_value in self.wire_map: @@ -424,6 +400,133 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 + # pylint: disable=too-many-branches + def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: + """Evaluate a jaxpr. + + Args: + jaxpr (jax.core.Jaxpr): the jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args (tuple[TensorLike]): The arguments for the jaxpr. + + Returns: + list[TensorLike]: the results of the execution. + + """ + + # We assume we have at least one argument (the qreg) + assert len(args) > 0 + + self._parent_qreg = args[0] + self.stateref = {"qreg": self._parent_qreg, "wire_map": {}} + + # Send the original args (without the qreg) + outvals = super().eval(jaxpr, consts, *args) + + # Add the qreg to the output values + self.qreg, retvals = outvals[0], outvals[1:] + + self.actualize_qreg() + + outvals = (self.qreg, *retvals) + + self.stateref = None + + return outvals + + +class QFuncPlxprInterpreter(SubroutineInterpreter, PlxprInterpreter): + """An interpreter that converts plxpr into catalyst-variant jaxpr. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) + + """ + + def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: + return PlxprInterpreter.eval(self, jaxpr, consts, *args) + + def __init__(self, device, shots: qml.measurements.Shots | int): + super().__init__(device, shots) + + def setup(self): + """Initialize the stateref and bind the device.""" + if self.stateref is None: + device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} + + # pylint: disable=attribute-defined-outside-init + def cleanup(self): + """Perform any final steps after processing the plxpr. + + For conversion to calayst, this reinserts extracted qubits and + deallocates the register, and releases the device. + """ + if not self.actualized: + self.actualize_qreg() + qdealloc_p.bind(self.qreg) + device_release_p.bind() + self.stateref = None + +from catalyst.jax_primitives import qinsert_p + +@QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p) +def handle_subroutine(self, *args, **kwargs): + # We need to pass the wire as an argument... + # And we somehow need to start another interpreter + # but only in case it is not yet already available... + from jax.experimental.pjit import pjit_p + from catalyst.jax_primitives import AbstractQreg, qextract_p, qinsert_p + from jax._src.core import jaxpr_as_fun + + backup = {orig_wire: wire for orig_wire, wire in self.wire_map.items()} + self.actualize_qreg() + plxpr = kwargs["jaxpr"] + + def wrapper(qreg, *args): + #qubit = qextract_p.bind(qreg, 0) + retval = jaxpr_as_fun(plxpr, *args)() + #qreg = qinsert_p.bind(qreg, 0, qubit) + return qreg, retval + + jaxpr_with_parameter_and_return = jax.make_jaxpr(wrapper)(AbstractQreg(), *args) + converted_func = partial( + SubroutineInterpreter(self._device, self._shots).eval, + jaxpr_with_parameter_and_return.jaxpr, + jaxpr_with_parameter_and_return.consts, + ) + converted_jaxpr_branch = jax.make_jaxpr(converted_func)(self.qreg, *args).jaxpr + converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ()) + #jaxpr = from_plxpr(jaxpr)(AbstractQreg(), *args) + # So, what I need to do here is transform this jaxpr + # With `from_plxpr` to but we need to make sure that + # the first argument is treated as the qreg... + + + from jax._src.sharding_impls import UnspecifiedValue + # quantum_subroutine_p.bind + # is just pjit_p with a different name. + vals_out = quantum_subroutine_p.bind(self.qreg, *args, + jaxpr=converted_closed_jaxpr_branch, + in_shardings=(UnspecifiedValue(), *kwargs["in_shardings"]), + out_shardings=(UnspecifiedValue(), *kwargs["out_shardings"]), + in_layouts=(None, *kwargs["in_layouts"]), + out_layouts=(None, *kwargs["out_layouts"]), + donated_invars=kwargs["donated_invars"], + ctx_mesh=kwargs["ctx_mesh"], + name=kwargs["name"], + keep_unused=kwargs["keep_unused"], + inline=kwargs["inline"], + compiler_options_kvs=kwargs["compiler_options_kvs"]) + + self.qreg = vals_out[0] + vals_out = vals_out[1:] + + for orig_wire, _ in backup: + self.wire_map[orig_wire] = qextract_p.bind(self.qreg, orig_wire) + + return vals_out @QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive) def handle_qubit_unitary(self, *invals, n_wires): diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 31bc822203..484e44fa41 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -18,6 +18,7 @@ import sys from dataclasses import dataclass from enum import Enum +import functools from itertools import chain from typing import Iterable, List, Union @@ -138,13 +139,13 @@ def __hash__(self): # pragma: nocover return self.hash_value -class ConcreteQbit(AbstractQbit): +class ConcreteQbit: """Concrete Qbit.""" def _qbit_lowering(aval): assert isinstance(aval, AbstractQbit) - return (ir.OpaqueType.get("quantum", "bit"),) + return ir.OpaqueType.get("quantum", "bit") # @@ -161,14 +162,17 @@ def __eq__(self, other): def __hash__(self): return self.hash_value + def _add(self, first, second): + return AbstractQreg() -class ConcreteQreg(AbstractQreg): + +class ConcreteQreg: """Concrete quantum register.""" def _qreg_lowering(aval): assert isinstance(aval, AbstractQreg) - return (ir.OpaqueType.get("quantum", "reg"),) + return ir.OpaqueType.get("quantum", "reg") # @@ -305,6 +309,27 @@ class MeasurementPlane(Enum): measure_in_basis_p = Primitive("measure_in_basis") measure_in_basis_p.multiple_results = True +from jax.experimental.pjit import pjit_p +from jax._src.pjit import _pjit_lowering +import copy +from catalyst.utils.patching import Patcher +quantum_subroutine_p = copy.deepcopy(pjit_p) +quantum_subroutine_p.name = "quantum_subroutine_p" + +def subroutine(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with Patcher( + ( + jax._src.pjit, + "pjit_p", + quantum_subroutine_p, + ), + ): + return jax.jit(func)(*args, **kwargs) + return wrapper + def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr): assert len(jaxpr.consts) == 0, ( @@ -2303,6 +2328,9 @@ def _cos_lowering2(ctx, x, accuracy): """Use hlo.cosine lowering instead of the new cosine lowering from jax 0.4.28""" return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy) +def subroutine_lowering(*args, **kwargs): + retval = _pjit_lowering(*args, **kwargs) + return retval CUSTOM_LOWERING_RULES = ( (zne_p, _zne_lowering), @@ -2344,6 +2372,7 @@ def _cos_lowering2(ctx, x, accuracy): (sin_p, _sin_lowering2), (cos_p, _cos_lowering2), (quantum_kernel_p, _quantum_kernel_lowering), + (quantum_subroutine_p, subroutine_lowering), (measure_in_basis_p, _measure_in_basis_lowering), ) @@ -2357,3 +2386,5 @@ def _scalar_abstractify(t): pytype_aval_mappings[type] = _scalar_abstractify pytype_aval_mappings[jax._src.numpy.scalar_types._ScalarMeta] = _scalar_abstractify +pytype_aval_mappings[ConcreteQbit] = lambda _: AbstractQbit() +pytype_aval_mappings[ConcreteQreg] = lambda _: AbstractQreg() diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 137154a9e6..b4090001c4 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -25,6 +25,7 @@ from jax._src.core import shaped_abstractify from jax._src.interpreters.partial_eval import infer_lambda_input_type from jax._src.pjit import _flat_axes_specs +from jax.core import AbstractValue from jax.tree_util import tree_flatten, tree_unflatten from catalyst.jax_extras import get_aval2 @@ -56,7 +57,7 @@ def params_are_annotated(fn: Callable): are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations) if not are_annotated: return False - return all(isinstance(annotation, (type, jax.core.ShapedArray)) for annotation in annotations) + return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations) def get_type_annotations(fn: Callable):