Closed
Description
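The following test fails. The cause is that the QNode `circuit` uses `y`, an argument of the enclosing qjitted function `workflow`, as a closure variable (it is not passed to `circuit` as a parameter) while `circuit` is being differentiated with `grad`.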
@qjit
def workflow(x, y):
dev = qml.device("lightning.qubit", wires=1)
@qml.qnode(dev)
def circuit(x):
qml.RX(jnp.pi * x , wires=0)
qml.RX(jnp.pi * y , wires=0)
return qml.expval(qml.PauliY(0))
g = grad(circuit)
return g(x)
print(workflow(1.0, 0.25))
Traceback:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
print(workflow(1.0, 0.25))
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
requires_promotion = self.jit_compile(args, **kwargs)
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 602, in jit_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 675, in capture
jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 550, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 533, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
File "/Users/mehrdad.malek/tmp/issue1339.py", line 70, in workflow
return g(x)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 688, in __call__
results = grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
print(workflow(1.0, 0.25))
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
requires_promotion = self.jit_compile(args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 604, in jit_compile
self.mlir_module = self.generate_ir()
^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 696, in generate_ir
mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 572, in lower_jaxpr_to_mlir
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 75, in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 144, in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 526, in _grad_lowering
nparray = np.asarray(const)
^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in __array__
raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Note that the failure only happens when taking the gradient of the quantum function. For example, the following code, which is identical except that it calls `circuit` directly instead of `grad(circuit)`, works fine:
@qjit
def workflow(x, y):
dev = qml.device("lightning.qubit", wires=1)
@qml.qnode(dev)
def circuit(x):
qml.RX(jnp.pi * x , wires=0)
qml.RX(jnp.pi * y , wires=0)
return qml.expval(qml.PauliY(0))
g = circuit
return g(x)
print(workflow(1.0, 0.25))
It returns:
0.7071067811865475