diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f8b7311b9e..51824f759d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -241,6 +241,9 @@ This runtime stub is currently for mock execution only and should be treated as a placeholder operation. Internally, it functions just as a computational-basis measurement instruction. +* Quantum subroutine stub. + [(#1774)](https://github.com/PennyLaneAI/catalyst/pull/1774) + * PennyLane's arbitrary-basis measurement operations, such as :func:`qml.ftqc.measure_arbitrary_basis() `, are now QJIT-compatible with program capture enabled. diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 137154a9e6..b4090001c4 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -25,6 +25,7 @@ from jax._src.core import shaped_abstractify from jax._src.interpreters.partial_eval import infer_lambda_input_type from jax._src.pjit import _flat_axes_specs +from jax.core import AbstractValue from jax.tree_util import tree_flatten, tree_unflatten from catalyst.jax_extras import get_aval2 @@ -56,7 +57,7 @@ def params_are_annotated(fn: Callable): are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations) if not are_annotated: return False - return all(isinstance(annotation, (type, jax.core.ShapedArray)) for annotation in annotations) + return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations) def get_type_annotations(fn: Callable):