
Crash when computing gradients of QNode that uses outer scope arguments #1608


Open
mehrdad2m opened this issue Apr 3, 2025 · 0 comments
Labels
bug Something isn't working

@mehrdad2m
Contributor

The following test fails. The cause is that the QNode `circuit` uses `y`, an argument of the enclosing `@qjit`-compiled function, without receiving it as one of its own arguments:

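```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import grad, qjit

@qjit
def workflow(x, y):
    dev = qml.device("lightning.qubit", wires=1)

    @qml.qnode(dev)
    def circuit(x):
        qml.RX(jnp.pi * x, wires=0)
        # `y` is not a QNode argument; it is captured from the enclosing qjit function
        qml.RX(jnp.pi * y, wires=0)
        return qml.expval(qml.PauliY(0))

    g = grad(circuit)
    return g(x)

print(workflow(1.0, 0.25))
```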

Traceback:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
    print(workflow(1.0, 0.25))
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 602, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 675, in capture
    jaxpr, out_type, treedef, plugins = trace_to_jaxpr(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 550, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py", line 533, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/Users/mehrdad.malek/tmp/issue1339.py", line 70, in workflow
    return g(x)
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 688, in __call__
    results = grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mehrdad.malek/tmp/issue1339.py", line 72, in <module>
    print(workflow(1.0, 0.25))
          ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 529, in __call__
    requires_promotion = self.jit_compile(args, **kwargs)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 604, in jit_compile
    self.mlir_module = self.generate_ir()
                       ^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/debug/instruments.py", line 145, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jit.py", line 696, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_tracer.py", line 572, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 75, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/lowering.py", line 144, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1438, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
                           ^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1622, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 1730, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/frontend/catalyst/jax_primitives.py", line 526, in _grad_lowering
    nparray = np.asarray(const)
              ^^^^^^^^^^^^^^^^^
  File "/Users/mehrdad.malek/catalyst/venv/lib/python3.12/site-packages/jax/_src/core.py", line 650, in __array__
    raise TracerArrayConversionError(self)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
This DynamicJaxprTracer was created on line /Users/mehrdad.malek/catalyst/frontend/catalyst/jax_extras/tracing.py:533:38 (make_jaxpr2.<locals>.make_jaxpr_f)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
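The last frames show the mechanism: `_grad_lowering` calls `np.asarray(const)` on a value that is still a tracer, and a tracer has no concrete value to hand to NumPy. The same error type can be reproduced in plain JAX, independent of Catalyst, by converting a traced argument to a NumPy array inside a jitted function. This is a minimal illustrative sketch; `to_numpy` and `conversion_fails` are hypothetical helper names, not part of Catalyst.

```python
import jax
import numpy as np


@jax.jit
def to_numpy(y):
    # Under jax.jit, `y` is an abstract tracer with no concrete value yet,
    # so asking NumPy for a concrete array fails while the function is traced.
    return np.asarray(y)


def conversion_fails(value):
    """Return True if tracing `to_numpy` raises TracerArrayConversionError."""
    try:
        to_numpy(value)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


print(conversion_fails(0.25))  # True
```

In the issue, the traced value reaching `np.asarray` is a float64 scalar created during tracing of `workflow`, which is consistent with it being the captured `y`.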

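Note that the failure occurs only when the QNode is differentiated. The final frames of the traceback show `_grad_lowering` in `catalyst/jax_primitives.py` calling `np.asarray(const)` on a value that is still a tracer during lowering (a float64 scalar, presumably the captured `y`). Without `grad`, the same code runs fine: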

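```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

@qjit
def workflow(x, y):
    dev = qml.device("lightning.qubit", wires=1)

    @qml.qnode(dev)
    def circuit(x):
        qml.RX(jnp.pi * x, wires=0)
        qml.RX(jnp.pi * y, wires=0)  # `y` is still captured from the outer scope
        return qml.expval(qml.PauliY(0))

    g = circuit  # no differentiation
    return g(x)

print(workflow(1.0, 0.25))
```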

and prints:

0.7071067811865475
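For comparison, the value the failing gradient example should produce can be checked in closed form. The two RX rotations compose into RX(π(x + y)), and for RX(θ)|0⟩ the expectation of PauliY is −sin(θ) (assuming PennyLane's convention RX(θ) = exp(−iθX/2)), which gives 0.7071 at x = 1.0, y = 0.25, matching the output above. The sketch below differentiates that closed form with plain `jax.grad` inside `jax.jit`, with `y` captured from the outer scope just as in the issue; `expected_grad` is a hypothetical name and the closed form stands in for the QNode.

```python
import jax
import jax.numpy as jnp

# Use float64 so the result is comparable with the float64 output shown above.
jax.config.update("jax_enable_x64", True)


@jax.jit
def expected_grad(x, y):
    # Closed-form stand-in for `circuit`: RX(pi*x) followed by RX(pi*y) equals
    # RX(pi*(x + y)), and <PauliY> after RX(theta)|0> is -sin(theta).
    def expval_y(x):
        # `y` is captured from the enclosing jitted function, as in the issue.
        return -jnp.sin(jnp.pi * x + jnp.pi * y)

    return jax.grad(expval_y)(x)


print(expected_grad(1.0, 0.25))  # about 2.2214, i.e. pi * sqrt(2) / 2
```

Plain JAX handles the captured tracer here, which suggests the failure is specific to how Catalyst lowers `grad` for closed-over values.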
mehrdad2m changed the title from "crash when computing gradients of QNode that uses outer scope arguments" to "Crash when computing gradients of QNode that uses outer scope arguments" on Apr 3, 2025
paul0403 added the bug label on Apr 3, 2025