Skip to content

🚧 Quantum Subroutine #1785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@
This runtime stub is currently for mock execution only and should be treated as a placeholder
operation. Internally, it functions just as a computational-basis measurement instruction.

* Quantum subroutine stub.
[(#1774)](https://github.com/PennyLaneAI/catalyst/pull/1774)

* PennyLane's arbitrary-basis measurement operations, such as
:func:`qml.ftqc.measure_arbitrary_basis() <pennylane.ftqc.measure_arbitrary_basis>`, are now
QJIT-compatible with program capture enabled.
Expand Down
163 changes: 133 additions & 30 deletions frontend/catalyst/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
qinsert_p,
qinst_p,
quantum_kernel_p,
quantum_subroutine_p,
sample_p,
set_basis_state_p,
set_state_p,
Expand Down Expand Up @@ -166,6 +167,9 @@
"""
return jax.make_jaxpr(partial(WorkflowInterpreter().eval, plxpr.jaxpr, plxpr.consts))

def from_subroutine(jaxpr):

Check notice on line 170 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L170

Missing function or method docstring (missing-function-docstring)
return jax.make_jaxpr(partial(SubroutineInterpreter(None, 0).eval, jaxpr.jaxpr, jaxpr.consts))


class WorkflowInterpreter(PlxprInterpreter):
"""An interpreter that converts a qnode primitive from a plxpr variant to a catalxpr variant."""
Expand Down Expand Up @@ -266,17 +270,8 @@
for pl_transform, (pass_name, decomposition) in transforms_to_passes.items():
register_transform(pl_transform, pass_name, decomposition)


class QFuncPlxprInterpreter(PlxprInterpreter):
"""An interpreter that converts plxpr into catalyst-variant jaxpr.

Args:
device (qml.devices.Device)
shots (qml.measurements.Shots)

"""

def __init__(self, device, shots: qml.measurements.Shots | int):
class SubroutineInterpreter(PlxprInterpreter):

Check notice on line 273 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L273

Missing class docstring (missing-class-docstring)
def __init__(self, device, shots):
self._device = device
self._shots = self._extract_shots_value(shots)
self.stateref = None
Expand All @@ -298,25 +293,6 @@
else:
super().__setattr__(__name, __value)

def setup(self):
"""Initialize the stateref and bind the device."""
if self.stateref is None:
device_init_p.bind(self._shots, **_get_device_kwargs(self._device))
self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}}

# pylint: disable=attribute-defined-outside-init
def cleanup(self):
"""Perform any final steps after processing the plxpr.

For conversion to calayst, this reinserts extracted qubits and
deallocates the register, and releases the device.
"""
if not self.actualized:
self.actualize_qreg()
qdealloc_p.bind(self.qreg)
device_release_p.bind()
self.stateref = None

def get_wire(self, wire_value) -> AbstractQbit:
"""Get the ``AbstractQbit`` corresponding to a wire value."""
if wire_value in self.wire_map:
Expand All @@ -331,7 +307,7 @@
"""
self.actualized = True
for orig_wire, wire in self.wire_map.items():
self.qreg = qinsert_p.bind(self.qreg, orig_wire, wire)

Check notice on line 310 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L310

Attribute 'qreg' defined outside __init__ (attribute-defined-outside-init)

def interpret_operation(self, op):
"""Re-bind a pennylane operation as a catalyst instruction."""
Expand Down Expand Up @@ -424,6 +400,133 @@

return shots.total_shots if shots else 0

# pylint: disable=too-many-branches
def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
"""Evaluate a jaxpr.

Args:
jaxpr (jax.core.Jaxpr): the jaxpr to evaluate
consts (list[TensorLike]): the constant variables for the jaxpr
*args (tuple[TensorLike]): The arguments for the jaxpr.

Returns:
list[TensorLike]: the results of the execution.

"""

# We assume we have at least one argument (the qreg)
assert len(args) > 0

self._parent_qreg = args[0]

Check notice on line 420 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L420

Attribute '_parent_qreg' defined outside __init__ (attribute-defined-outside-init)
self.stateref = {"qreg": self._parent_qreg, "wire_map": {}}

# Send the original args (without the qreg)
outvals = super().eval(jaxpr, consts, *args)

# Add the qreg to the output values
self.qreg, retvals = outvals[0], outvals[1:]

Check notice on line 427 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L427

Attribute 'qreg' defined outside __init__ (attribute-defined-outside-init)

self.actualize_qreg()

outvals = (self.qreg, *retvals)

self.stateref = None

return outvals


class QFuncPlxprInterpreter(SubroutineInterpreter, PlxprInterpreter):
"""An interpreter that converts plxpr into catalyst-variant jaxpr.

Args:
device (qml.devices.Device)
shots (qml.measurements.Shots)

"""

def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
return PlxprInterpreter.eval(self, jaxpr, consts, *args)

def __init__(self, device, shots: qml.measurements.Shots | int):

Check notice on line 450 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L450

Useless parent or super() delegation in method '__init__' (useless-parent-delegation)
super().__init__(device, shots)

def setup(self):
"""Initialize the stateref and bind the device."""
if self.stateref is None:
device_init_p.bind(self._shots, **_get_device_kwargs(self._device))
self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}}

# pylint: disable=attribute-defined-outside-init
def cleanup(self):
"""Perform any final steps after processing the plxpr.

For conversion to calayst, this reinserts extracted qubits and
deallocates the register, and releases the device.
"""
if not self.actualized:
self.actualize_qreg()
qdealloc_p.bind(self.qreg)
device_release_p.bind()
self.stateref = None

from catalyst.jax_primitives import qinsert_p

Check notice on line 472 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L472

Import "from catalyst.jax_primitives import qinsert_p" should be placed at the top of the module (wrong-import-position)

@QFuncPlxprInterpreter.register_primitive(quantum_subroutine_p)
def handle_subroutine(self, *args, **kwargs):

Check notice on line 475 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L475

Missing function or method docstring (missing-function-docstring)
# We need to pass the wire as an argument...
# And we somehow need to start another interpreter
# but only in case it is not yet already available...
from jax.experimental.pjit import pjit_p

Check notice on line 479 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L479

Import outside toplevel (jax.experimental.pjit.pjit_p) (import-outside-toplevel)

Check notice on line 479 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L479

Unused pjit_p imported from jax.experimental.pjit (unused-import)
from catalyst.jax_primitives import AbstractQreg, qextract_p, qinsert_p

Check notice on line 480 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L480

Redefining name 'AbstractQreg' from outer scope (line 49) (redefined-outer-name)

Check notice on line 480 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L480

Redefining name 'qinsert_p' from outer scope (line 49) (redefined-outer-name)

Check notice on line 480 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L480

Import outside toplevel (catalyst.jax_primitives.AbstractQreg, catalyst.jax_primitives.qextract_p, catalyst.jax_primitives.qinsert_p) (import-outside-toplevel)

Check notice on line 480 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L480

Redefining name 'qextract_p' from outer scope (line 49) (redefined-outer-name)

Check notice on line 480 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L480

Unused qinsert_p imported from catalyst.jax_primitives (unused-import)
from jax._src.core import jaxpr_as_fun

Check notice on line 481 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L481

Import outside toplevel (jax._src.core.jaxpr_as_fun) (import-outside-toplevel)

backup = {orig_wire: wire for orig_wire, wire in self.wire_map.items()}

Check notice on line 483 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L483

Unnecessary use of a comprehension, use dict(self.wire_map.items()) instead. (unnecessary-comprehension)
self.actualize_qreg()
plxpr = kwargs["jaxpr"]

def wrapper(qreg, *args):
#qubit = qextract_p.bind(qreg, 0)
retval = jaxpr_as_fun(plxpr, *args)()
#qreg = qinsert_p.bind(qreg, 0, qubit)
return qreg, retval

jaxpr_with_parameter_and_return = jax.make_jaxpr(wrapper)(AbstractQreg(), *args)
converted_func = partial(
SubroutineInterpreter(self._device, self._shots).eval,
jaxpr_with_parameter_and_return.jaxpr,
jaxpr_with_parameter_and_return.consts,
)
converted_jaxpr_branch = jax.make_jaxpr(converted_func)(self.qreg, *args).jaxpr
converted_closed_jaxpr_branch = ClosedJaxpr(convert_constvars_jaxpr(converted_jaxpr_branch), ())
#jaxpr = from_plxpr(jaxpr)(AbstractQreg(), *args)
# So, what I need to do here is transform this jaxpr
# With `from_plxpr` to but we need to make sure that
# the first argument is treated as the qreg...


from jax._src.sharding_impls import UnspecifiedValue

Check notice on line 507 in frontend/catalyst/from_plxpr.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr.py#L507

Import outside toplevel (jax._src.sharding_impls.UnspecifiedValue) (import-outside-toplevel)
# quantum_subroutine_p.bind
# is just pjit_p with a different name.
vals_out = quantum_subroutine_p.bind(self.qreg, *args,
jaxpr=converted_closed_jaxpr_branch,
in_shardings=(UnspecifiedValue(), *kwargs["in_shardings"]),
out_shardings=(UnspecifiedValue(), *kwargs["out_shardings"]),
in_layouts=(None, *kwargs["in_layouts"]),
out_layouts=(None, *kwargs["out_layouts"]),
donated_invars=kwargs["donated_invars"],
ctx_mesh=kwargs["ctx_mesh"],
name=kwargs["name"],
keep_unused=kwargs["keep_unused"],
inline=kwargs["inline"],
compiler_options_kvs=kwargs["compiler_options_kvs"])

self.qreg = vals_out[0]
vals_out = vals_out[1:]

for orig_wire, _ in backup:
self.wire_map[orig_wire] = qextract_p.bind(self.qreg, orig_wire)

return vals_out

@QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive)
def handle_qubit_unitary(self, *invals, n_wires):
Expand Down
39 changes: 35 additions & 4 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sys
from dataclasses import dataclass
from enum import Enum
import functools
from itertools import chain
from typing import Iterable, List, Union

Expand Down Expand Up @@ -138,13 +139,13 @@
return self.hash_value


class ConcreteQbit(AbstractQbit):
class ConcreteQbit:
"""Concrete Qbit."""


def _qbit_lowering(aval):
assert isinstance(aval, AbstractQbit)
return (ir.OpaqueType.get("quantum", "bit"),)
return ir.OpaqueType.get("quantum", "bit")


#
Expand All @@ -161,14 +162,17 @@
def __hash__(self):
return self.hash_value

def _add(self, first, second):
return AbstractQreg()

class ConcreteQreg(AbstractQreg):

class ConcreteQreg:
"""Concrete quantum register."""


def _qreg_lowering(aval):
assert isinstance(aval, AbstractQreg)
return (ir.OpaqueType.get("quantum", "reg"),)
return ir.OpaqueType.get("quantum", "reg")


#
Expand Down Expand Up @@ -305,6 +309,27 @@
measure_in_basis_p = Primitive("measure_in_basis")
measure_in_basis_p.multiple_results = True

from jax.experimental.pjit import pjit_p

Check notice on line 312 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L312

Imports from package jax are not grouped (ungrouped-imports)

Check notice on line 312 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L312

Import "from jax.experimental.pjit import pjit_p" should be placed at the top of the module (wrong-import-position)
from jax._src.pjit import _pjit_lowering

Check notice on line 313 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L313

Import "from jax._src.pjit import _pjit_lowering" should be placed at the top of the module (wrong-import-position)
import copy

Check notice on line 314 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L314

Import "import copy" should be placed at the top of the module (wrong-import-position)
from catalyst.utils.patching import Patcher

Check notice on line 315 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L315

Import "from catalyst.utils.patching import Patcher" should be placed at the top of the module (wrong-import-position)

Check notice on line 315 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L315

Imports from package catalyst are not grouped (ungrouped-imports)
quantum_subroutine_p = copy.deepcopy(pjit_p)
quantum_subroutine_p.name = "quantum_subroutine_p"

def subroutine(func):

Check notice on line 319 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L319

Missing function or method docstring (missing-function-docstring)

@functools.wraps(func)
def wrapper(*args, **kwargs):
with Patcher(
(
jax._src.pjit,
"pjit_p",
quantum_subroutine_p,
),
):
return jax.jit(func)(*args, **kwargs)
return wrapper


def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr):
assert len(jaxpr.consts) == 0, (
Expand Down Expand Up @@ -2303,6 +2328,9 @@
"""Use hlo.cosine lowering instead of the new cosine lowering from jax 0.4.28"""
return _nary_lower_hlo(hlo.cosine, ctx, x, accuracy=accuracy)

def subroutine_lowering(*args, **kwargs):

Check notice on line 2331 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L2331

Missing function or method docstring (missing-function-docstring)
retval = _pjit_lowering(*args, **kwargs)
return retval

CUSTOM_LOWERING_RULES = (
(zne_p, _zne_lowering),
Expand Down Expand Up @@ -2344,6 +2372,7 @@
(sin_p, _sin_lowering2),
(cos_p, _cos_lowering2),
(quantum_kernel_p, _quantum_kernel_lowering),
(quantum_subroutine_p, subroutine_lowering),
(measure_in_basis_p, _measure_in_basis_lowering),
)

Expand All @@ -2357,3 +2386,5 @@

pytype_aval_mappings[type] = _scalar_abstractify
pytype_aval_mappings[jax._src.numpy.scalar_types._ScalarMeta] = _scalar_abstractify
pytype_aval_mappings[ConcreteQbit] = lambda _: AbstractQbit()
pytype_aval_mappings[ConcreteQreg] = lambda _: AbstractQreg()
3 changes: 2 additions & 1 deletion frontend/catalyst/tracing/type_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from jax._src.core import shaped_abstractify
from jax._src.interpreters.partial_eval import infer_lambda_input_type
from jax._src.pjit import _flat_axes_specs
from jax.core import AbstractValue
from jax.tree_util import tree_flatten, tree_unflatten

from catalyst.jax_extras import get_aval2
Expand Down Expand Up @@ -56,7 +57,7 @@ def params_are_annotated(fn: Callable):
are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations)
if not are_annotated:
return False
return all(isinstance(annotation, (type, jax.core.ShapedArray)) for annotation in annotations)
return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations)


def get_type_annotations(fn: Callable):
Expand Down
Loading