From 5d121d397adb8e75ee2686e66c291fc2a7e22a82 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Fri, 30 May 2025 16:25:20 +0000 Subject: [PATCH 1/2] make aot compilation more general --- frontend/catalyst/tracing/type_signatures.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 137154a9e6..ac077ace79 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -22,6 +22,7 @@ from typing import Callable import jax +from jax.core import AbstractValue from jax._src.core import shaped_abstractify from jax._src.interpreters.partial_eval import infer_lambda_input_type from jax._src.pjit import _flat_axes_specs @@ -56,7 +57,7 @@ def params_are_annotated(fn: Callable): are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations) if not are_annotated: return False - return all(isinstance(annotation, (type, jax.core.ShapedArray)) for annotation in annotations) + return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations) def get_type_annotations(fn: Callable): From dc42cb88a4e064faf0eecd78075decf44f04ab4b Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Fri, 30 May 2025 16:33:50 +0000 Subject: [PATCH 2/2] style + changelog --- doc/releases/changelog-dev.md | 3 +++ frontend/catalyst/tracing/type_signatures.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f8b7311b9e..51824f759d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -241,6 +241,9 @@ This runtime stub is currently for mock execution only and should be treated as a placeholder operation. Internally, it functions just as a computational-basis measurement instruction. +* Quantum subroutine stub. + [(#1774)](https://github.com/PennyLaneAI/catalyst/pull/1774) + * PennyLane's arbitrary-basis measurement operations, such as :func:`qml.ftqc.measure_arbitrary_basis() `, are now QJIT-compatible with program capture enabled. diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index ac077ace79..b4090001c4 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -22,10 +22,10 @@ from typing import Callable import jax -from jax.core import AbstractValue from jax._src.core import shaped_abstractify from jax._src.interpreters.partial_eval import infer_lambda_input_type from jax._src.pjit import _flat_axes_specs +from jax.core import AbstractValue from jax.tree_util import tree_flatten, tree_unflatten from catalyst.jax_extras import get_aval2