You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
print(workflow(1.0, 0.25))
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
requires_promotion = self.jit_compile(args, **kwargs)
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 602, in jit_compile
self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 675, in capture
jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 550, in trace_to_jaxpr
jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 533, in make_jaxpr_f
jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
File "/Users/mehrdad.malek/tmp/issue1339.py", line 70, in workflow
return g(x)
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 688, in __call__
results = grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
print(workflow(1.0, 0.25))
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
requires_promotion = self.jit_compile(args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 604, in jit_compile
self.mlir_module = self.generate_ir()
^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 696, in generate_ir
mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 572, in lower_jaxpr_to_mlir
mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 75, in jaxpr_to_mlir
module, context = custom_lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 144, in custom_lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 526, in _grad_lowering
nparray = np.asarray(const)
^^^^^^^^^^^^^^^^^
File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in __array__
raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Note that the failure only occurs when computing the gradient of the quantum function; the same code without the gradient works fine:
The text was updated successfully, but these errors were encountered:
mehrdad2m
changed the title
crash when computing gradients of QNode that uses outer scope arguments
Crash when computing gradients of QNode that uses outer scope arguments
Apr 3, 2025
The following test fails because an argument of the main `qjit`-compiled function is used inside the QNode.
Traceback:
Note that the failure only occurs when computing the gradient of the quantum function; the same code without the gradient works fine:
returns
The text was updated successfully, but these errors were encountered: