Skip to content

[FTQC] Add support for parametric mid-circuit measurements #1645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ff14c48
Initial commit
joeycarter Apr 11, 2025
73b000c
Merge branch 'main' into joeycarter/arb-basis-measurements
joeycarter Apr 21, 2025
20df143
[WIP] Add lowering rules for measure_in_basis
joeycarter Apr 22, 2025
8e91640
Merge branch 'main' into joeycarter/arb-basis-measurements
joeycarter Apr 22, 2025
6a5eabc
Merge branch 'main' into joeycarter/arb-basis-measurements
joeycarter Apr 25, 2025
5ab23d1
xfail test_measure_z
joeycarter Apr 28, 2025
9e7b3bd
Merge branch 'main' into joeycarter/arb-basis-measurements
joeycarter Apr 28, 2025
6535b0a
Get MeasurementPlane enum working in capture
joeycarter Apr 28, 2025
2a84738
Update tests
joeycarter Apr 28, 2025
39e1b1b
Fix Codefactor warnings
joeycarter Apr 28, 2025
314d7f4
Clean up
joeycarter Apr 28, 2025
caa578e
Add changelog entry
joeycarter Apr 28, 2025
b0ad8d3
Fix docs build
joeycarter Apr 28, 2025
f964c9e
Disable capture after calling workloads
joeycarter Apr 29, 2025
04d069a
Merge branch 'main' into joeycarter/arb-basis-measurements
joeycarter Apr 29, 2025
664a2e7
Relax test assert to be -1.0 <= result <= 1.0
joeycarter Apr 29, 2025
063a092
Change measure_in_basis input arg order
joeycarter May 1, 2025
b7f98e9
Add more tests
joeycarter May 1, 2025
23ba914
Merge branch 'main' into joeycarter/arb-basis-measurements
joeycarter May 1, 2025
b0c8d0e
Ignore too-many-positional-arguments pylint warning
joeycarter May 1, 2025
a8a2d43
Add pseudo MBQC workload lit test
joeycarter May 1, 2025
a1d16e7
Bump PL dependency to 0.42.0-dev19
joeycarter May 1, 2025
2f0fb0c
Merge branch 'main' into joeycarter/arb-basis-measurements
joeycarter May 2, 2025
1c20cec
Update PL dep to 0.42.0-dev19 in docs requirements
joeycarter May 2, 2025
72ed935
Improve testing coverage
joeycarter May 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ enzyme=v0.0.149

# For a custom PL version, update the package version here and at
# 'doc/requirements.txt
pennylane=0.42.0-dev15
pennylane=0.42.0-dev19

# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __getattr__(cls, name):
"mlir_quantum.dialects.quantum",
"mlir_quantum.dialects.gradient",
"mlir_quantum.dialects.catalyst",
"mlir_quantum.dialects.mbqc",
"mlir_quantum.dialects.mitigation",
"mlir_quantum.dialects._transform_ops_gen",
"pybind11",
Expand Down
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@
This runtime stub is currently for mock execution only and should be treated as a placeholder
operation. Internally, it functions just as a computational-basis measurement instruction.

* PennyLane's arbitrary-basis measurement operations, such as [`qml.ftqc.measure_arbitrary_basis()`
](https://docs.pennylane.ai/en/stable/code/api/pennylane.ftqc.measure_arbitrary_basis.html), are
now QJIT-compatible with program capture enabled.
[(#1645)](https://github.com/PennyLaneAI/catalyst/pull/1645)

* The utility function `EnsureFunctionDeclaration` is refactored into the `Utils` of the `Catalyst`
dialect, instead of being duplicated in each individual dialect.
[(#1683)](https://github.com/PennyLaneAI/catalyst/pull/1683)
Expand Down
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ lxml_html_clean
--extra-index-url https://test.pypi.org/simple/
pennylane-lightning-kokkos==0.42.0-dev11
pennylane-lightning==0.42.0-dev11
pennylane==0.42.0-dev15
pennylane==0.42.0-dev19
24 changes: 24 additions & 0 deletions frontend/catalyst/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pennylane.capture.primitives import cond_prim as plxpr_cond_prim
from pennylane.capture.primitives import for_loop_prim as plxpr_for_loop_prim
from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.ops.functions.map_wires import _map_wires_transform as pl_map_wires
from pennylane.transforms import cancel_inverses as pl_cancel_inverses
from pennylane.transforms import commute_controlled as pl_commute_controlled
Expand All @@ -47,12 +48,14 @@
from catalyst.jax_primitives import (
AbstractQbit,
AbstractQreg,
MeasurementPlane,
compbasis_p,
cond_p,
counts_p,
expval_p,
for_p,
gphase_p,
measure_in_basis_p,
namedobs_p,
probs_p,
qalloc_p,
Expand Down Expand Up @@ -640,6 +643,27 @@ def handle_while_loop(
return outvals


# pylint: disable=unused-argument, too-many-positional-arguments
@QFuncPlxprInterpreter.register_primitive(plxpr_measure_in_basis_prim)
def handle_measure_in_basis(self, angle, wire, plane, reset, postselect):
"""Handle the conversion from plxpr to Catalyst jaxpr for the measure_in_basis primitive"""
_angle = jax.lax.convert_element_type(angle, jnp.dtype(jnp.float64))

try:
_plane = MeasurementPlane(plane)
except ValueError as e:
raise ValueError(
f"Measurement plane must be one of {[plane.value for plane in MeasurementPlane]}"
) from e

in_wire = self.get_wire(wire)
result, out_wire = measure_in_basis_p.bind(_angle, in_wire, plane=_plane, postselect=postselect)

self.wire_map[wire] = out_wire

return result


# Derived interpreters must be declared after the primitive registrations of their
# parents or be placed in a separate file, in order to access those registrations.
# This is due to the registrations being done outside the parent class definition.
Expand Down
87 changes: 87 additions & 0 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
ValueAndGradOp,
VJPOp,
)
from mlir_quantum.dialects.mbqc import MeasureInBasisOp
from mlir_quantum.dialects.mitigation import ZneOp
from mlir_quantum.dialects.quantum import (
AdjointOp,
Expand Down Expand Up @@ -230,6 +231,16 @@ class Folding(Enum):
ALL = "local-all"


class MeasurementPlane(Enum):
"""
Measurement planes for arbitrary-basis measurements in MBQC
"""

XY = "XY"
YZ = "YZ"
ZX = "ZX"


##############
# Primitives #
##############
Expand Down Expand Up @@ -292,6 +303,8 @@ class Folding(Enum):
set_basis_state_p.multiple_results = True
quantum_kernel_p = core.CallPrimitive("quantum_kernel")
quantum_kernel_p.multiple_results = True
measure_in_basis_p = core.Primitive("measure_in_basis")
measure_in_basis_p.multiple_results = True


def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr):
Expand Down Expand Up @@ -1214,6 +1227,79 @@ def _qmeasure_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value, posts
)


#
# arbitrary-basis measurements
#
@measure_in_basis_p.def_abstract_eval
def _measure_in_basis_abstract_eval(
angle: float, qubit: AbstractQbit, plane: MeasurementPlane, postselect: int = None
):
assert isinstance(qubit, AbstractQbit)
return core.ShapedArray((), bool), qubit


@measure_in_basis_p.def_impl
def _measure_in_basis_def_impl(
ctx, angle: float, qubit: AbstractQbit, plane: MeasurementPlane, postselect: int = None
): # pragma: no cover
raise NotImplementedError()


def _measurement_plane_attribute(ctx, plane: MeasurementPlane):
return ir.OpaqueAttr.get(
"mbqc",
("measurement_plane " + MeasurementPlane(plane).name).encode("utf-8"),
ir.NoneType.get(ctx),
ctx,
)


def _measure_in_basis_lowering(
jax_ctx: mlir.LoweringRuleContext,
angle: float,
qubit: ir.Value,
plane: MeasurementPlane,
postselect: int = None,
):
ctx = jax_ctx.module_context.context
ctx.allow_unregistered_dialects = True

assert ir.OpaqueType.isinstance(qubit.type)
assert ir.OpaqueType(qubit.type).dialect_namespace == "quantum"
assert ir.OpaqueType(qubit.type).data == "bit"

angle = safe_cast_to_f64(angle, "angle")
angle = extract_scalar(angle, "angle")

assert ir.F64Type.isinstance(
angle.type
), "Only scalar double parameters are allowed for quantum gates!"

# Prepare postselect attribute
if postselect is not None:
i32_type = ir.IntegerType.get_signless(32, ctx)
postselect = ir.IntegerAttr.get(i32_type, postselect)

result_type = ir.IntegerType.get_signless(1)

result, new_qubit = MeasureInBasisOp(
result_type,
qubit.type,
qubit,
plane=_measurement_plane_attribute(ctx, plane),
angle=angle,
postselect=postselect,
).results

result_from_elements_op = ir.RankedTensorType.get((), result.type)
from_elements_op = FromElementsOp(result_from_elements_op, result)

return (
from_elements_op.results[0],
new_qubit,
)


#
# compbasis observable
#
Expand Down Expand Up @@ -2250,6 +2336,7 @@ def _cos_lowering2(ctx, x):
(sin_p, _sin_lowering2),
(cos_p, _cos_lowering2),
(quantum_kernel_p, _quantum_kernel_lowering),
(measure_in_basis_p, _measure_in_basis_lowering),
)


Expand Down
Loading