Skip to content

Commit 8f86430

Browse files
committed
proof of concept for capturing constants inside qnode
1 parent 07369d6 commit 8f86430

File tree

7 files changed

+35
-710
lines changed

7 files changed

+35
-710
lines changed

frontend/catalyst/api_extensions/differentiation.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -888,16 +888,3 @@ def _check_qnode_against_grad_method(f: QNode, method: str, jaxpr: Jaxpr):
888888
"Cannot differentiate a QNode explicitly marked non-differentiable (with "
889889
"diff_method=None)."
890890
)
891-
892-
if f.diff_method == "parameter-shift" and any(
893-
prim not in [expval_p, probs_p] for prim in return_ops
894-
):
895-
raise DifferentiableCompileError(
896-
"The parameter-shift method can only be used for QNodes "
897-
"which return either qml.expval or qml.probs."
898-
)
899-
900-
if f.diff_method == "adjoint" and any(prim not in [expval_p] for prim in return_ops):
901-
raise DifferentiableCompileError(
902-
"The adjoint method can only be used for QNodes which return qml.expval."
903-
)

frontend/catalyst/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def get_enforce_runtime_invariants_stage(_options: CompileOptions) -> List[str]:
173173
# keep inlining modules targeting the Catalyst runtime.
174174
# But qnodes targeting other backends may choose to lower
175175
# this into something else.
176+
"builtin.module(inline)",
176177
"inline-nested-module",
177178
]
178179
return enforce_runtime_invariants
@@ -217,7 +218,6 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
217218
"""Returns the list of passes that performs bufferization"""
218219
bufferization = [
219220
"one-shot-bufferize{dialect-filter=memref}",
220-
"inline",
221221
"gradient-preprocess",
222222
"gradient-bufferize",
223223
"scf-bufferize",

frontend/catalyst/qfunc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def __call__(self, *args, **kwargs):
131131
out_tree_expected = kwargs.pop("_out_tree_expected", [])
132132
debug_info = kwargs.pop("debug_info", None)
133133

134+
import jax
135+
136+
@jax.jit
134137
def _eval_quantum(*args, **kwargs):
135138
closed_jaxpr, out_type, out_tree, out_tree_exp = trace_quantum_function(
136139
self.func,

0 commit comments

Comments
 (0)