From ec15d1ff4f5a6105f89544d702a44ef338668000 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Thu, 5 Jun 2025 18:45:43 +0000 Subject: [PATCH 1/5] added subroutine interpreter --- frontend/catalyst/from_plxpr.py | 114 +++++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 30 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 01be1ad22..b3e27999a 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -267,16 +267,16 @@ def wrapper(*args): register_transform(pl_transform, pass_name, decomposition) -class QFuncPlxprInterpreter(PlxprInterpreter): - """An interpreter that converts plxpr into catalyst-variant jaxpr. - - Args: - device (qml.devices.Device) - shots (qml.measurements.Shots) +class SubroutineInterpreter(PlxprInterpreter): + """Base interpreter for quantum operations. + It is a subroutine interpreter because unlike the QFuncPlxprInterpreter it + * does not allocate a new register upon beginning, + * does not deallocate the quantum register upon ending, + * and it does not release the quantum device back to the runtime. """ - def __init__(self, device, shots: qml.measurements.Shots | int): + def __init__(self, device, shots): self._device = device self._shots = self._extract_shots_value(shots) self.stateref = None @@ -298,29 +298,6 @@ def __setattr__(self, __name: str, __value) -> None: else: super().__setattr__(__name, __value) - def setup(self): - """Initialize the stateref and bind the device.""" - if self.stateref is None: - device_init_p.bind( - self._shots, - auto_qubit_management=(self._device.wires is None), - **_get_device_kwargs(self._device), - ) - self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} - - # pylint: disable=attribute-defined-outside-init - def cleanup(self): - """Perform any final steps after processing the plxpr. - - For conversion to calayst, this reinserts extracted qubits and - deallocates the register, and releases the device. - """ - if not self.actualized: - self.actualize_qreg() - qdealloc_p.bind(self.qreg) - device_release_p.bind() - self.stateref = None - def get_wire(self, wire_value) -> AbstractQbit: """Get the ``AbstractQbit`` corresponding to a wire value.""" if wire_value in self.wire_map: @@ -428,6 +405,83 @@ def _extract_shots_value(self, shots: qml.measurements.Shots | int): return shots.total_shots if shots else 0 + def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: + """Evaluate a jaxpr. + Args: + jaxpr (jax.core.Jaxpr): the jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args (tuple[TensorLike]): The arguments for the jaxpr. + Returns: + list[TensorLike]: the results of the execution. + """ + raise NotImplementedError("Unreachable code until we add subroutine feature") + + """ + + # We assume we have at least one argument (the qreg) + assert len(args) > 0 + + self._parent_qreg = args[0] + self.stateref = {"qreg": self._parent_qreg, "wire_map": {}} + + # Send the original args (without the qreg) + outvals = super().eval(jaxpr, consts, *args) + + # Add the qreg to the output values + self.qreg, retvals = outvals[0], outvals[1:] + + self.actualize_qreg() + + outvals = (self.qreg, *retvals) + + self.stateref = None + + return outvals + """ + + +class QFuncPlxprInterpreter(SubroutineInterpreter, PlxprInterpreter): + """An interpreter that converts plxpr into catalyst-variant jaxpr. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) + + """ + + def __init__(self, device, shots: qml.measurements.Shots | int): + super().__init__(device, shots) + + def setup(self): + """Initialize the stateref and bind the device.""" + if self.stateref is None: + device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} + + # pylint: disable=attribute-defined-outside-init + def cleanup(self): + """Perform any final steps after processing the plxpr. + + For conversion to calayst, this reinserts extracted qubits and + deallocates the register, and releases the device. + """ + if not self.actualized: + self.actualize_qreg() + qdealloc_p.bind(self.qreg) + device_release_p.bind() + self.stateref = None + + def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: + """Use the PlxprInterpreter.eval method and not the SubroutineInterpreter + + The Subroutine's eval method expects the first argument to be the qreg. + This will not be the case when first evaluating a QFuncPlxprInterpreter + as the qreg will be available only after the function has started running. + It will be one of the first instructions in the function and it is + added by the setup function. + """ + return PlxprInterpreter.eval(self, jaxpr, consts, *args) + @QFuncPlxprInterpreter.register_primitive(qml.QubitUnitary._primitive) def handle_qubit_unitary(self, *invals, n_wires): From 28abc5885f179a9a60598066f0a37b40b08ac6d9 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Thu, 5 Jun 2025 18:54:35 +0000 Subject: [PATCH 2/5] function signature --- frontend/catalyst/from_plxpr.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index b3e27999a..7aa3586d2 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -274,9 +274,13 @@ class SubroutineInterpreter(PlxprInterpreter): * does not allocate a new register upon beginning, * does not deallocate the quantum register upon ending, * and it does not release the quantum device back to the runtime. + + Args: + device (qml.devices.Device) + shots (qml.measurements.Shots) """ - def __init__(self, device, shots): + def __init__(self, device, shots: qml.measurements.Shots | int): self._device = device self._shots = self._extract_shots_value(shots) self.stateref = None @@ -413,10 +417,6 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: *args (tuple[TensorLike]): The arguments for the jaxpr. Returns: list[TensorLike]: the results of the execution. - """ - raise NotImplementedError("Unreachable code until we add subroutine feature") - - """ # We assume we have at least one argument (the qreg) assert len(args) > 0 @@ -438,9 +438,10 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list: return outvals """ + raise NotImplementedError("Unreachable code until we add subroutine feature") -class QFuncPlxprInterpreter(SubroutineInterpreter, PlxprInterpreter): +class QFuncPlxprInterpreter(SubroutineInterpreter): """An interpreter that converts plxpr into catalyst-variant jaxpr. Args: @@ -449,9 +450,6 @@ class QFuncPlxprInterpreter(SubroutineInterpreter, PlxprInterpreter): """ - def __init__(self, device, shots: qml.measurements.Shots | int): - super().__init__(device, shots) - def setup(self): """Initialize the stateref and bind the device.""" if self.stateref is None: From 3c561e515aa784626243b870c8f5b8da7a680839 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 11 Jun 2025 15:56:00 +0000 Subject: [PATCH 3/5] changelog --- doc/releases/changelog-dev.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index bf311fd38..22fc54c35 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -322,6 +322,10 @@ * The unused helper function `genArgMapFunction` in the `--lower-gradients` pass is removed. [(#1753)](https://github.com/PennyLaneAI/catalyst/pull/1753) +* Base components of QFuncPLxPRInterpreter have been moved into a base class called SubroutineInterpreter. + This is to reduce code duplication once we have support for quantum subroutines. + [(#1787)](https://github.com/PennyLaneAI/catalyst/pull/1787) + * The `qml.measure()` operation for mid-circuit measurements can now be used in QJIT-compiled circuits with program capture enabled. [(#1766)](https://github.com/PennyLaneAI/catalyst/pull/1766) From 0f2060835bfe58e938c609de43c0ad7a309f1398 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 11 Jun 2025 16:04:18 +0000 Subject: [PATCH 4/5] rebase fix --- frontend/catalyst/from_plxpr.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 7aa3586d2..4d94d6240 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -453,7 +453,11 @@ class QFuncPlxprInterpreter(SubroutineInterpreter): def setup(self): """Initialize the stateref and bind the device.""" if self.stateref is None: - device_init_p.bind(self._shots, **_get_device_kwargs(self._device)) + device_init_p.bind( + self._shots, + auto_qubit_management=(self._device.wires is None), + **_get_device_kwargs(self._device), + ) self.stateref = {"qreg": qalloc_p.bind(len(self._device.wires)), "wire_map": {}} # pylint: disable=attribute-defined-outside-init From 9e9e85256ad763b2cc4872a1582a72a23d49ebcc Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 11 Jun 2025 17:25:30 +0000 Subject: [PATCH 5/5] Add comment --- frontend/catalyst/from_plxpr.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frontend/catalyst/from_plxpr.py b/frontend/catalyst/from_plxpr.py index 4d94d6240..5b49a2394 100644 --- a/frontend/catalyst/from_plxpr.py +++ b/frontend/catalyst/from_plxpr.py @@ -316,6 +316,9 @@ def actualize_qreg(self): """ self.actualized = True for orig_wire, wire in self.wire_map.items(): + # Note: since `getattr` checks specifically for qreg, we can't + # define qreg inside the init function. + # pylint: disable-next=attribute-defined-outside-init self.qreg = qinsert_p.bind(self.qreg, orig_wire, wire) def interpret_operation(self, op):